// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/attribution_reporting/aggregatable_values.h"

#include <stdint.h>

#include <algorithm>
#include <optional>
#include <utility>

#include "base/check.h"
#include "base/containers/flat_tree.h"
#include "base/numerics/safe_conversions.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/values.h"
#include "components/attribution_reporting/aggregatable_utils.h"
#include "components/attribution_reporting/constants.h"
#include "components/attribution_reporting/filters.h"
#include "components/attribution_reporting/parsing_utils.h"
#include "components/attribution_reporting/trigger_registration_error.mojom.h"

namespace attribution_reporting {

namespace {

using ::attribution_reporting::mojom::TriggerRegistrationError;

bool IsValid(const AggregatableValues::Values& values) {
  return std::ranges::all_of(values, [](const auto& value) {
    return IsAggregatableValueInRange(value.second.value());
  });
}

base::expected<AggregatableValues::Values, TriggerRegistrationError>
ParseValues(const base::DictValue& dict,
            TriggerRegistrationError key_error,
            TriggerRegistrationError value_error) {
  AggregatableValues::Values::container_type container;

  for (auto [id, key_value] : dict) {
    ASSIGN_OR_RETURN(AggregatableValuesValue value,
                     AggregatableValuesValue::FromJSON(key_value, value_error));
    container.emplace_back(id, std::move(value));
  }
  return AggregatableValues::Values(base::sorted_unique, std::move(container));
}

}  // namespace

// static
std::optional<AggregatableValuesValue>
AggregatableValuesValue::AggregatableValuesValue::Create(
    int value,
    uint64_t filtering_id) {
  if (!IsAggregatableValueInRange(value)) {
    return std::nullopt;
  }
  return AggregatableValuesValue(value, filtering_id);
}

// static
base::expected<AggregatableValuesValue, TriggerRegistrationError>
AggregatableValuesValue::FromJSON(const base::Value& json,
                                  TriggerRegistrationError value_error) {
  int value;
  std::optional<uint64_t> filtering_id;

  if (const base::DictValue* dict = json.GetIfDict()) {
    const base::Value* value_v = dict->Find(kValue);
    if (!value_v) {
      return base::unexpected(value_error);
    }
    ASSIGN_OR_RETURN(value, ParseAggregatableValue(*value_v),
                     [value_error](std::monostate) { return value_error; });

    ASSIGN_OR_RETURN(filtering_id, ParseUint64(*dict, kFilteringId),
                     [value_error](std::monostate) { return value_error; });
  } else {
    ASSIGN_OR_RETURN(value, ParseAggregatableValue(json),
                     [value_error](std::monostate) { return value_error; });
  }
  return AggregatableValuesValue(value,
                                 filtering_id.value_or(kDefaultFilteringId));
}

AggregatableValuesValue::AggregatableValuesValue(uint32_t value,
                                                 uint64_t filtering_id)
    : value_(value), filtering_id_(filtering_id) {}

// static
std::optional<AggregatableValues> AggregatableValues::Create(
    Values values,
    FilterPair filters) {
  if (!IsValid(values)) {
    return std::nullopt;
  }

  return AggregatableValues(std::move(values), std::move(filters));
}

// static
base::expected<std::vector<AggregatableValues>, TriggerRegistrationError>
AggregatableValues::FromJSON(base::Value* input_value) {
  std::vector<AggregatableValues> configs;
  if (!input_value) {
    return configs;
  }

  if (base::DictValue* dict = input_value->GetIfDict()) {
    ASSIGN_OR_RETURN(
        Values values,
        ParseValues(*dict,
                    TriggerRegistrationError::kAggregatableValuesKeyTooLong,
                    TriggerRegistrationError::kAggregatableValuesValueInvalid));
    if (!values.empty()) {
      configs.push_back(AggregatableValues(std::move(values), FilterPair()));
    }
  } else if (base::ListValue* list = input_value->GetIfList()) {
    configs.reserve(list->size());
    for (auto& maybe_dict_value : *list) {
      base::DictValue* dict_value = maybe_dict_value.GetIfDict();
      if (!dict_value) {
        return base::unexpected(
            TriggerRegistrationError::kAggregatableValuesWrongType);
      }

      const base::DictValue* agg_values_dict = dict_value->FindDict(kValues);
      if (!agg_values_dict) {
        return base::unexpected(TriggerRegistrationError::
                                    kAggregatableValuesListValuesFieldMissing);
      }

      ASSIGN_OR_RETURN(
          Values values,
          ParseValues(
              *agg_values_dict,
              TriggerRegistrationError::kAggregatableValuesListKeyTooLong,
              TriggerRegistrationError::kAggregatableValuesListValueInvalid));
      ASSIGN_OR_RETURN(FilterPair filters, FilterPair::FromJSON(*dict_value));

      configs.push_back(
          AggregatableValues(std::move(values), std::move(filters)));
    }
  } else {
    return base::unexpected(
        TriggerRegistrationError::kAggregatableValuesWrongType);
  }
  return configs;
}

base::DictValue AggregatableValuesValue::ToJson() const {
  CHECK(base::IsValueInRangeForNumericType<int>(value_));

  base::DictValue dict;

  dict.Set(kValue, static_cast<int>(value_));
  SerializeUint64(dict, kFilteringId, filtering_id_);

  return dict;
}

AggregatableValues::AggregatableValues() = default;

AggregatableValues::AggregatableValues(Values values, FilterPair filters)
    : values_(std::move(values)), filters_(std::move(filters)) {
  CHECK(IsValid(values_));
}

AggregatableValues::~AggregatableValues() = default;

AggregatableValues::AggregatableValues(const AggregatableValues&) = default;

AggregatableValues& AggregatableValues::operator=(const AggregatableValues&) =
    default;

AggregatableValues::AggregatableValues(AggregatableValues&&) = default;

AggregatableValues& AggregatableValues::operator=(AggregatableValues&&) =
    default;

base::DictValue AggregatableValues::ToJson() const {
  base::DictValue values_dict;
  for (const auto& [key, value] : values_) {
    values_dict.Set(key, value.ToJson());
  }

  base::DictValue dict;
  dict.Set(kValues, std::move(values_dict));
  filters_.SerializeIfNotEmpty(dict);
  return dict;
}

}  // namespace attribution_reporting
