Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions willow/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,15 @@ cc_library(
":client_cxx",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:string_view",
"@cxx.rs//:core",
"//ffi_utils:cxx_utils",
"//ffi_utils:status_macros",
"//willow/input_encoding:codec",
"//willow/proto/shell:shell_ciphertexts_cc_proto",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:input_spec_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
],
Expand Down
43 changes: 42 additions & 1 deletion willow/api/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,62 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "ffi_utils/cxx_utils.h"
#include "ffi_utils/status_macros.h"
#include "include/cxx.h"
#include "willow/api/client.rs.h"
#include "willow/input_encoding/codec.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/input_spec.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"

namespace secure_aggregation {

absl::StatusOr<willow::AggregationConfigProto> CreateAggregationConfig(
const willow::InputSpec& input_spec_proto, absl::string_view key_id,
int64_t max_number_of_clients, int64_t max_number_of_decryptors,
int64_t max_decryptor_dropouts, int64_t default_max_metric_value) {
willow::AggregationConfigProto config_proto;
config_proto.set_max_number_of_clients(max_number_of_clients);
config_proto.set_max_number_of_decryptors(max_number_of_decryptors);
config_proto.set_max_decryptor_dropouts(max_decryptor_dropouts);
config_proto.set_key_id(std::string(key_id));
// All metrics have same vector length, corresponding to the Cartesian product
// of group-by domains.
int64_t flattened_domain_size = 1;
for (const auto& group_by_spec : input_spec_proto.group_by_vector_specs()) {
if (group_by_spec.domain_spec().string_values().values_size() == 0) {
return absl::InvalidArgumentError(absl::StrCat(
"Missing domain, invalid domain type (must be StringValues), or "
"empty string_values for group by vector: ",
group_by_spec.vector_name()));
}
flattened_domain_size *=
group_by_spec.domain_spec().string_values().values_size();
}
// Build VectorConfig (length and bound) for each metric.
for (const auto& metric_spec : input_spec_proto.metric_vector_specs()) {
auto& vector_config =
(*config_proto.mutable_vector_configs())[metric_spec.vector_name()];
vector_config.set_length(flattened_domain_size);
if (metric_spec.has_domain_spec() &&
metric_spec.domain_spec().has_interval()) {
vector_config.set_bound(
static_cast<int64_t>(metric_spec.domain_spec().interval().max()));
} else {
vector_config.set_bound(default_max_metric_value);
}
}
return config_proto;
}

absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
const willow::AggregationConfigProto& aggregation_config,
const willow::EncodedData& encoded_data,
const willow::ShellAhePublicKey& public_key, const std::string& nonce) {
const willow::ShellAhePublicKey& public_key, absl::string_view nonce) {
// Initialize client.
std::string config_str = aggregation_config.SerializeAsString();
auto config_ptr = std::make_unique<std::string>(std::move(config_str));
Expand Down
27 changes: 25 additions & 2 deletions willow/api/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,46 @@
#ifndef SECURE_AGGREGATION_WILLOW_API_CLIENT_H_
#define SECURE_AGGREGATION_WILLOW_API_CLIENT_H_

#include <string>
#include <cstdint>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "willow/input_encoding/codec.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/input_spec.pb.h"
#include "willow/proto/willow/messages.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"

namespace secure_aggregation {

// Default maximum bound for metric vector elements if not specified in the
// InputSpec and not passed to CreateAggregationConfig.
inline constexpr int64_t kDefaultMaxMetricValue = 1LL << 30;

// Default configuration for Willow clients and decryptors.
inline constexpr int64_t kDefaultMaxNumberOfClients = 10000000;
inline constexpr int64_t kDefaultMaxDecryptors = 1;
inline constexpr int64_t kDefaultMaxDecryptorDropouts = 0;

// Creates an AggregationConfigProto from the given InputSpec and other
// parameters. For each metric, it builds a willow.VectorConfig proto, where
// the length is the Cartesian product of the group-by vector domains, and the
// bound is the interval max if specified, or default_max_metric_value
// otherwise.
absl::StatusOr<willow::AggregationConfigProto> CreateAggregationConfig(
const willow::InputSpec& input_spec, absl::string_view key_id,
int64_t max_number_of_clients = kDefaultMaxNumberOfClients,
int64_t max_number_of_decryptors = kDefaultMaxDecryptors,
int64_t max_decryptor_dropouts = kDefaultMaxDecryptorDropouts,
int64_t default_max_metric_value = kDefaultMaxMetricValue);

// Generates a client contribution by encrypting the encoded data with the
// provided AHE public key.
absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
const willow::AggregationConfigProto& aggregation_config,
const willow::EncodedData& encoded_data,
const willow::ShellAhePublicKey& public_key, const std::string& nonce);
const willow::ShellAhePublicKey& public_key, absl::string_view nonce);

} // namespace secure_aggregation

Expand Down
65 changes: 65 additions & 0 deletions willow/api/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "willow/api/client.h"

#include <cstdint>
#include <memory>
#include <string>

Expand Down Expand Up @@ -139,6 +140,70 @@ TEST(WillowShellClientTest, InvalidAggregationConfig) {
StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(WillowShellClientTest, CreateAggregationConfigSuccess) {
InputSpec input_spec = CreateTestInputSpecProto();
SECAGG_ASSERT_OK_AND_ASSIGN(
AggregationConfigProto config,
CreateAggregationConfig(input_spec, /*key_id=*/"test",
/*max_number_of_clients=*/10));

AggregationConfigProto expected_config = CreateTestAggregationConfigProto();
EXPECT_EQ(config.max_number_of_decryptors(),
expected_config.max_number_of_decryptors());
EXPECT_EQ(config.max_number_of_clients(),
expected_config.max_number_of_clients());
EXPECT_EQ(config.key_id(), expected_config.key_id());

const auto& vector_configs = config.vector_configs();
const auto& expected_vector_configs = expected_config.vector_configs();
EXPECT_EQ(vector_configs.size(), expected_vector_configs.size());
for (const auto& [key, value] : expected_vector_configs) {
ASSERT_TRUE(vector_configs.contains(key));
EXPECT_EQ(vector_configs.at(key).length(), value.length());
EXPECT_EQ(vector_configs.at(key).bound(), value.bound());
}
}

TEST(WillowShellClientTest, CreateAggregationConfigDefaultBound) {
InputSpec input_spec = CreateTestInputSpecProto();
// Clear the interval to verify that the default bound is used.
input_spec.mutable_metric_vector_specs(0)
->mutable_domain_spec()
->clear_interval();

int64_t default_bound = 12345;
SECAGG_ASSERT_OK_AND_ASSIGN(
AggregationConfigProto config,
CreateAggregationConfig(input_spec, /*key_id=*/"test",
/*max_number_of_clients=*/10,
/*max_number_of_decryptors=*/1,
/*max_decryptor_dropouts=*/0,
/*default_max_metric_value=*/default_bound));

const auto& vector_configs = config.vector_configs();
ASSERT_TRUE(vector_configs.contains("metric1"));
EXPECT_EQ(vector_configs.at("metric1").bound(), default_bound);
}

TEST(WillowShellClientTest, CreateAggregationConfigFailsOnEmptyDomain) {
InputSpec input_spec = CreateTestInputSpecProto();
ASSERT_GT(input_spec.group_by_vector_specs_size(), 0);

// Clear the string values to zero length.
input_spec.mutable_group_by_vector_specs(0)
->mutable_domain_spec()
->mutable_string_values()
->clear_values();

EXPECT_THAT(CreateAggregationConfig(input_spec, "test", 10),
StatusIs(absl::StatusCode::kInvalidArgument));

// Also fails if domain is not set.
input_spec.mutable_group_by_vector_specs(0)->clear_domain_spec();
EXPECT_THAT(CreateAggregationConfig(input_spec, "test", 10),
StatusIs(absl::StatusCode::kInvalidArgument));
}

} // namespace
} // namespace willow
} // namespace secure_aggregation