blob: c1a252f3b593c74d7ff243f1bfd664592faec2e3 [file] [log] [blame]
#include <fcntl.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include <glog/logging.h>

#include "absl/container/btree_set.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
#include "absl/numeric/int128.h"
#include "absl/random/random.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "benchmark/benchmark.h" // third_party/benchmark
#include "dpf/distributed_point_function.h"
#include "dpf/distributed_point_function.pb.h"
#include "imap.hpp"
#include "riegeli/bytes/fd_reader.h"
#include "riegeli/lines/line_reading.h"
// Command-line flags of this benchmark. Cross-flag constraints are checked in
// ValidateFlags().
//
// --input: path of a CSV file whose first column holds the nonzero bucket IDs.
// When empty, no prefixes are computed and main() evaluates without any input
// prefixes (see Usage()). Required when --only_nonzeros is set.
ABSL_FLAG(std::string, input, "",
"CSV file containing non-zero buckets in the first column.");
// --log_domain_size: log2 of the DPF domain size; must be non-negative. It is
// the log domain size of the last hierarchy level when levels are chosen
// automatically, and bounds the random point `alpha` drawn in main().
ABSL_FLAG(int, log_domain_size, 20,
"Logarithm of the domain size. All non-zeros in `input` must be in "
"[0, 2^log_domain_size).");
// --num_iterations: how many times the evaluation is repeated inside the timed
// region; main() reports the wallclock time divided by this value.
ABSL_FLAG(int, num_iterations, 20, "Number of iterations to benchmark.");
// --max_expansion_factor: only read by ComputeLevelsToEvaluate(), i.e. when
// hierarchy levels are chosen automatically.
ABSL_FLAG(int, max_expansion_factor, 2,
"Limits the maximum number of elements the expansion at any "
"hierarchy level can have to a multiple of the number of unique "
"buckets in the input file. Must be at least 2.");
// --only_nonzeros: selects RunBatchedSinglePointEvaluation() instead of
// RunHierarchicalEvaluation().
ABSL_FLAG(bool, only_nonzeros, false,
"Only evaluates at the nonzero indices of the input file passed via "
"--input, instead of performing hierarchical evaluation. If true, "
"all flags related to hierarchy levels will be ignored");
// --levels_to_evaluate: comma-separated integers, each validated to lie in
// [1, log_domain_size]. When empty (and --only_nonzeros is false), main()
// computes the levels via ComputeLevelsToEvaluate().
ABSL_FLAG(std::vector<std::string>, levels_to_evaluate, {},
"List of integers specifying the log domain sizes at which to insert "
"hierarchy levels.");
// Fallback for logging libraries that do not define QCHECK (presumably the
// open-source glog build): map it to CHECK, so failed checks still terminate
// the program and accept a streamed message.
#ifndef QCHECK
#define QCHECK(x) CHECK(x)
#endif
namespace {
// Help text for this binary; main() registers it via
// absl::SetProgramUsageMessage() before parsing flags.
const char* Usage() {
  static constexpr char kUsageMessage[] =
      "synthetic_data_benchmarks [OPTIONS]\n\n"
      "Runs a single DPF key evaluation on the specified domain. If an "
      "input file is specified with --input, it is read as a CSV file "
      "containing the bucket IDs to expand in the first column. Otherwise, "
      "the full domain will be expanded.";
  return kUsageMessage;
}
void ValidateFlags() {
int log_domain_size = absl::GetFlag(FLAGS_log_domain_size);
QCHECK(log_domain_size >= 0) << "--log_domain_size must be non-negative";
int num_iterations = absl::GetFlag(FLAGS_num_iterations);
QCHECK(num_iterations > 0) << "--num_iterations must be positive";
if (absl::GetFlag(FLAGS_only_nonzeros)) {
QCHECK(!absl::GetFlag(FLAGS_input).empty())
<< "--input is required when --only_nonzeros is true";
}
int max_expansion_factor = absl::GetFlag(FLAGS_max_expansion_factor);
QCHECK(max_expansion_factor >= 2)
<< "--max_expansion_factor must be at least 2";
std::vector<std::string> levels_to_evaluate =
absl::GetFlag(FLAGS_levels_to_evaluate);
for (absl::string_view level_str : levels_to_evaluate) {
int level;
QCHECK(absl::SimpleAtoi(level_str, &level));
QCHECK(level > 0 && level <= log_domain_size)
<< "--levels_to_evaluate must be in [1, log_domain_size]";
}
}
// Returns, for every bit length `b` in [0, log_domain_size], the sorted,
// duplicate-free prefixes of length `b` of the elements of
// `last_level_prefixes`. That is, `result[b]` holds the distinct values
// `x >> (log_domain_size - b)`.
//
// `last_level_prefixes` is copied unchanged into `result[log_domain_size]`. If
// it is non-empty, `result[0]` is {0}, the single empty prefix.
// Expects `log_domain_size >= 0`.
std::vector<std::vector<absl::uint128>> ComputePrefixes(
    const absl::btree_set<absl::uint128>& last_level_prefixes,
    int log_domain_size) {
  std::vector<std::vector<absl::uint128>> result(log_domain_size + 1);
  result.back().assign(last_level_prefixes.begin(), last_level_prefixes.end());
  // Derive each shorter level from the next longer one, down to and including
  // bit length 0. Each level is sorted and `x >> 1` is monotone, so the shifted
  // values are sorted too. Duplicates are therefore adjacent and can be
  // skipped without a set.
  for (int i = log_domain_size; i > 0; --i) {
    std::vector<absl::uint128>& shorter = result[i - 1];
    for (const absl::uint128& x : result[i]) {
      const absl::uint128 parent = x >> 1;
      if (shorter.empty() || shorter.back() != parent) {
        shorter.push_back(parent);
      }
    }
  }
  return result;
}
// Parses `input_file` as a CSV file and returns the unique integers in the
// first column as a sorted set.
//
// Terminates the program (via QCHECK) in three cases:
// - a line's first column is empty;
// - the first column is not a valid absl::uint128;
// - the file could not be read completely.
// Line numbers in error messages are 1-based.
absl::btree_set<absl::uint128> ReadUniqueValuesFromFile(
    absl::string_view input_file) {
  absl::btree_set<absl::uint128> nonzeros;
  LOG(INFO) << "Reading input file...";
  int line_number = 0;
  riegeli::FdReader reader(input_file, O_RDONLY);
  absl::string_view line;
  while (riegeli::ReadLine(reader, line)) {
    ++line_number;
    // Only the first column is used. Cutting at the first comma, instead of
    // splitting and dropping empty fields, keeps an empty first column from
    // silently shifting the second column into its place.
    const absl::string_view first_column = line.substr(0, line.find(','));
    QCHECK(first_column.find_first_not_of(" \t\r\n") !=
           absl::string_view::npos)
        << "Line " << line_number << " has an empty first column";
    absl::uint128 nonzero;
    QCHECK(absl::SimpleAtoi(first_column, &nonzero))
        << "Invalid bucket ID on line " << line_number << ": " << first_column;
    nonzeros.insert(nonzero);
  }
  // The loop ends both at end of file and on a read failure. Close() tells
  // the two apart, and status() explains a failure.
  QCHECK(reader.Close()) << reader.status();
  LOG(INFO) << "Read " << nonzeros.size() << " nonzeros from " << line_number
            << " lines";
  return nonzeros;
}
// Chooses the log domain sizes of the hierarchy levels to evaluate.
//
// `prefixes` is expected to be the output of ComputePrefixes(): `prefixes[b]`
// holds the distinct b-bit prefixes and `prefixes.back()` the nonzeros
// themselves.
//
// Levels are picked greedily so that the number of outputs per level stays
// bounded by roughly --max_expansion_factor times the number of nonzeros. The
// returned levels are non-negative, increasing, and end at `log_domain_size`.
// Without nonzeros, a single level of size `log_domain_size` is returned.
std::vector<int> ComputeLevelsToEvaluate(
    absl::Span<const std::vector<absl::uint128>> prefixes,
    int log_domain_size) {
  if (prefixes.empty() || prefixes.back().empty()) {
    return {log_domain_size};
  }
  const int num_nonzeros = static_cast<int>(prefixes.back().size());
  const double log_num_nonzeros = std::log2(num_nonzeros);
  const double log_max_expansion_factor =
      std::log2(absl::GetFlag(FLAGS_max_expansion_factor));

  // The first level has size at most max_expansion_factor * num_nonzeros.
  // Clamp at 0: with log_domain_size == 0 the formula would yield -1, which is
  // not a valid log domain size.
  const int first_level = std::max(
      0, std::min(log_domain_size,
                  static_cast<int>(log_num_nonzeros +
                                   log_max_expansion_factor)) -
             1);
  std::vector<int> levels_to_evaluate = {first_level};
  while (levels_to_evaluate.back() < log_domain_size) {
    const int last_level = levels_to_evaluate.back();
    const int nonzeros_at_last_level =
        static_cast<int>(prefixes[last_level + 1].size());
    // Evaluate as many bits as possible so that we get no more than
    // max_expansion_factor * num_nonzeros outputs, i.e.
    // 2^bits_to_next_level * nonzeros_at_last_level <=
    // max_expansion_factor * num_nonzeros.
    // Since max_expansion_factor >= 2 (checked in ValidateFlags) and
    // nonzeros_at_last_level <= num_nonzeros, every step advances by >= 1 bit.
    levels_to_evaluate.push_back(std::min(
        log_domain_size,
        static_cast<int>(last_level + log_num_nonzeros +
                         log_max_expansion_factor -
                         std::log2(nonzeros_at_last_level))));
  }
  return levels_to_evaluate;
}
// Benchmarks hierarchical evaluation of `key`.
//
// Each of the `num_iterations` rounds starts from a fresh copy of the initial
// evaluation context. It then calls EvaluateUntil once per hierarchy level,
// passing `prefixes[level]` for that level. `prefixes` must have one entry per
// hierarchy level of `dpf` (CHECK-enforced). Output sizes and per-level
// domain sizes are logged during the first round only.
template <typename T>
void RunHierarchicalEvaluation(
    const distributed_point_functions::DistributedPointFunction& dpf,
    const distributed_point_functions::DpfKey& key,
    absl::Span<const std::vector<absl::uint128>> prefixes, int num_iterations) {
  const distributed_point_functions::EvaluationContext initial_ctx =
      dpf.CreateEvaluationContext(key).value();
  CHECK_EQ(prefixes.size(), initial_ctx.parameters_size());
  const int num_levels = static_cast<int>(prefixes.size());
  for (int round = 0; round < num_iterations; ++round) {
    const bool is_first_round = (round == 0);
    // EvaluateUntil takes a mutable context (presumably advancing it), so
    // every round works on its own copy.
    distributed_point_functions::EvaluationContext round_ctx = initial_ctx;
    for (int level = 0; level < num_levels; ++level) {
      std::vector<T> outputs =
          dpf.EvaluateUntil<T>(level, prefixes[level], round_ctx).value();
      if (is_first_round) {
        LOG(INFO) << "Number of outputs at " << level
                  << "-th level: " << outputs.size();
        LOG(INFO) << "log_domain_size="
                  << initial_ctx.parameters(level).log_domain_size();
      }
      benchmark::DoNotOptimize(outputs);
    }
  }
}
// Benchmarks batched point evaluation.
//
// Each of the `num_iterations` rounds evaluates `key` at all indices in
// `nonzeros` with a single EvaluateAt call, using hierarchy level index 0. It
// CHECKs that one output is returned per index. `dpf` must have exactly one
// hierarchy level (CHECK-enforced).
template <typename T>
void RunBatchedSinglePointEvaluation(
    const distributed_point_functions::DistributedPointFunction& dpf,
    const distributed_point_functions::DpfKey& key,
    absl::Span<const absl::uint128> nonzeros, int num_iterations) {
  // The benchmark only supports a single hierarchy level.
  CHECK_EQ(dpf.parameters().size(), 1);
  for (int round = 0; round < num_iterations; ++round) {
    std::vector<T> values = dpf.EvaluateAt<T>(key, 0, nonzeros).value();
    CHECK_EQ(values.size(), nonzeros.size());
    benchmark::DoNotOptimize(values);
  }
}
} // namespace
int main(int argc, char* argv[]) {
google::InitGoogleLogging(argv[0]);
absl::SetProgramUsageMessage(Usage());
absl::ParseCommandLine(argc, argv);
FLAGS_logtostderr = 1;
ValidateFlags();
// Read nonzeros from input file, compute prefixes,
std::string input_file = absl::GetFlag(FLAGS_input);
const int log_domain_size = absl::GetFlag(FLAGS_log_domain_size);
std::vector<std::vector<absl::uint128>> prefixes(1);
if (!input_file.empty()) {
absl::btree_set<absl::uint128> nonzeros =
ReadUniqueValuesFromFile(input_file);
prefixes = ComputePrefixes(nonzeros, log_domain_size);
}
int num_nonzeros = prefixes.back().size();
LOG(INFO) << "Number of nonzeros: " << num_nonzeros;
// Compute levels to evaluate and choose the correct prefixes.
std::vector<std::string> levels_to_evaluate_str =
absl::GetFlag(FLAGS_levels_to_evaluate);
std::vector<int> levels_to_evaluate(levels_to_evaluate_str.size());
bool only_nonzeros = absl::GetFlag(FLAGS_only_nonzeros);
for (int i = 0; i < static_cast<int>(levels_to_evaluate.size()); ++i) {
CHECK(absl::SimpleAtoi(levels_to_evaluate_str[i], &levels_to_evaluate[i]));
}
if (levels_to_evaluate.empty()) {
if (!only_nonzeros) {
levels_to_evaluate = ComputeLevelsToEvaluate(prefixes, log_domain_size);
} else {
levels_to_evaluate = {log_domain_size};
}
}
LOG(INFO) << "Levels to evaluate: " << absl::StrJoin(levels_to_evaluate, ",");
std::vector<std::vector<absl::uint128>> prefixes_to_evaluate(1);
prefixes_to_evaluate.reserve(levels_to_evaluate.size());
for (int i = 1; i < levels_to_evaluate.size(); ++i) {
prefixes_to_evaluate.push_back(prefixes[levels_to_evaluate[i - 1]]);
}
LOG(INFO) << "Numbers of prefixes per level: "
<< absl::StrJoin(iter::imap([](auto& c) { return c.size(); },
prefixes_to_evaluate),
",");
LOG(INFO) << "Numbers of prefixes per bit: "
<< absl::StrJoin(
iter::imap([](auto& c) { return c.size(); }, prefixes), ",");
// Set up parameters and create DPF instance.
std::vector<distributed_point_functions::DpfParameters> parameters(
levels_to_evaluate.size());
const int element_bitsize = 32; // TODO(schoppmann): Make this a flag?
for (int i = 0; i < static_cast<int>(parameters.size()); ++i) {
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
element_bitsize);
parameters[i].set_log_domain_size(levels_to_evaluate[i]);
}
std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf =
distributed_point_functions::DistributedPointFunction::CreateIncremental(
parameters)
.value();
// Generate DPF key.
absl::BitGen rng;
absl::uint128 alpha = absl::MakeUint128(absl::Uniform<uint64_t>(rng),
absl::Uniform<uint64_t>(rng));
if (log_domain_size < 128) {
alpha %= absl::uint128{1} << log_domain_size;
}
std::vector<absl::uint128> beta(parameters.size(), 1);
distributed_point_functions::DpfKey key;
std::tie(key, std::ignore) =
dpf->GenerateKeysIncremental(alpha, beta).value();
// Run the experiment and measure time.
int num_iterations = absl::GetFlag(FLAGS_num_iterations);
using T = uint32_t;
absl::Time start = absl::Now();
if (only_nonzeros) {
RunBatchedSinglePointEvaluation<T>(*dpf, key, prefixes.back(),
num_iterations);
} else {
RunHierarchicalEvaluation<T>(*dpf, key, prefixes_to_evaluate,
num_iterations);
}
absl::Duration wallclock = absl::Now() - start;
LOG(INFO) << "Wallclock time per iteration: "
<< wallclock / absl::GetFlag(FLAGS_num_iterations);
}