From 8200aed1f9ea6528bfec7db8148fba794a65c38b Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Wed, 18 Mar 2026 07:16:09 -0700 Subject: [PATCH] Refactor FCP Willow client implementation. PiperOrigin-RevId: 885576096 --- willow/api/BUILD | 3 ++ willow/api/client.cc | 43 +++++++++++++++++++++++++- willow/api/client.h | 27 ++++++++++++++-- willow/api/client_test.cc | 65 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 3 deletions(-) diff --git a/willow/api/BUILD b/willow/api/BUILD index 4af3800..f54b9d8 100644 --- a/willow/api/BUILD +++ b/willow/api/BUILD @@ -176,12 +176,15 @@ cc_library( ":client_cxx", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:string_view", "@cxx.rs//:core", "//ffi_utils:cxx_utils", "//ffi_utils:status_macros", "//willow/input_encoding:codec", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:input_spec_cc_proto", "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", ], diff --git a/willow/api/client.cc b/willow/api/client.cc index d875933..b36ed42 100644 --- a/willow/api/client.cc +++ b/willow/api/client.cc @@ -22,6 +22,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "ffi_utils/cxx_utils.h" #include "ffi_utils/status_macros.h" #include "include/cxx.h" @@ -29,14 +31,53 @@ #include "willow/input_encoding/codec.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/input_spec.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" namespace secure_aggregation { +absl::StatusOr CreateAggregationConfig( + const willow::InputSpec& input_spec_proto, absl::string_view key_id, + int64_t max_number_of_clients, int64_t max_number_of_decryptors, + int64_t max_decryptor_dropouts, int64_t default_max_metric_value) { + willow::AggregationConfigProto config_proto; + config_proto.set_max_number_of_clients(max_number_of_clients); + config_proto.set_max_number_of_decryptors(max_number_of_decryptors); + config_proto.set_max_decryptor_dropouts(max_decryptor_dropouts); + config_proto.set_key_id(std::string(key_id)); + // All metrics have same vector length, corresponding to the Cartesian product + // of group-by domains. + int64_t flattened_domain_size = 1; + for (const auto& group_by_spec : input_spec_proto.group_by_vector_specs()) { + if (group_by_spec.domain_spec().string_values().values_size() == 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Missing domain, invalid domain type (must be StringValues), or " + "empty string_values for group by vector: ", + group_by_spec.vector_name())); + } + flattened_domain_size *= + group_by_spec.domain_spec().string_values().values_size(); + } + // Build VectorConfig (length and bound) for each metric. + for (const auto& metric_spec : input_spec_proto.metric_vector_specs()) { + auto& vector_config = + (*config_proto.mutable_vector_configs())[metric_spec.vector_name()]; + vector_config.set_length(flattened_domain_size); + if (metric_spec.has_domain_spec() && + metric_spec.domain_spec().has_interval()) { + vector_config.set_bound( + static_cast(metric_spec.domain_spec().interval().max())); + } else { + vector_config.set_bound(default_max_metric_value); + } + } + return config_proto; +} + absl::StatusOr GenerateClientContribution( const willow::AggregationConfigProto& aggregation_config, const willow::EncodedData& encoded_data, - const willow::ShellAhePublicKey& public_key, const std::string& nonce) { + const willow::ShellAhePublicKey& public_key, absl::string_view nonce) { // Initialize client. std::string config_str = aggregation_config.SerializeAsString(); auto config_ptr = std::make_unique(std::move(config_str)); diff --git a/willow/api/client.h b/willow/api/client.h index 7d8551b..1b33c7f 100644 --- a/willow/api/client.h +++ b/willow/api/client.h @@ -17,23 +17,46 @@ #ifndef SECURE_AGGREGATION_WILLOW_API_CLIENT_H_ #define SECURE_AGGREGATION_WILLOW_API_CLIENT_H_ -#include +#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "willow/input_encoding/codec.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/input_spec.pb.h" #include "willow/proto/willow/messages.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" namespace secure_aggregation { +// Default maximum bound for metric vector elements if not specified in the +// InputSpec and not passed to CreateAggregationConfig. +inline constexpr int64_t kDefaultMaxMetricValue = 1LL << 30; + +// Default configuration for Willow clients and decryptors. +inline constexpr int64_t kDefaultMaxNumberOfClients = 10000000; +inline constexpr int64_t kDefaultMaxDecryptors = 1; +inline constexpr int64_t kDefaultMaxDecryptorDropouts = 0; + +// Creates an AggregationConfigProto from the given InputSpec and other +// parameters. For each metric, it builds a willow.VectorConfig proto, where +// the length is the Cartesian product of the group-by vector domains, and the +// bound is the interval max if specified, or default_max_metric_value +// otherwise. +absl::StatusOr CreateAggregationConfig( + const willow::InputSpec& input_spec, absl::string_view key_id, + int64_t max_number_of_clients = kDefaultMaxNumberOfClients, + int64_t max_number_of_decryptors = kDefaultMaxDecryptors, + int64_t max_decryptor_dropouts = kDefaultMaxDecryptorDropouts, + int64_t default_max_metric_value = kDefaultMaxMetricValue); + // Generates a client contribution by encrypting the encoded data with the // provided AHE public key. absl::StatusOr GenerateClientContribution( const willow::AggregationConfigProto& aggregation_config, const willow::EncodedData& encoded_data, - const willow::ShellAhePublicKey& public_key, const std::string& nonce); + const willow::ShellAhePublicKey& public_key, absl::string_view nonce); } // namespace secure_aggregation diff --git a/willow/api/client_test.cc b/willow/api/client_test.cc index 73e3ae6..05f5956 100644 --- a/willow/api/client_test.cc +++ b/willow/api/client_test.cc @@ -16,6 +16,7 @@ #include "willow/api/client.h" +#include #include #include @@ -139,6 +140,70 @@ TEST(WillowShellClientTest, InvalidAggregationConfig) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(WillowShellClientTest, CreateAggregationConfigSuccess) { + InputSpec input_spec = CreateTestInputSpecProto(); + SECAGG_ASSERT_OK_AND_ASSIGN( + AggregationConfigProto config, + CreateAggregationConfig(input_spec, /*key_id=*/"test", + /*max_number_of_clients=*/10)); + + AggregationConfigProto expected_config = CreateTestAggregationConfigProto(); + EXPECT_EQ(config.max_number_of_decryptors(), + expected_config.max_number_of_decryptors()); + EXPECT_EQ(config.max_number_of_clients(), + expected_config.max_number_of_clients()); + EXPECT_EQ(config.key_id(), expected_config.key_id()); + + const auto& vector_configs = config.vector_configs(); + const auto& expected_vector_configs = expected_config.vector_configs(); + EXPECT_EQ(vector_configs.size(), expected_vector_configs.size()); + for (const auto& [key, value] : expected_vector_configs) { + ASSERT_TRUE(vector_configs.contains(key)); + EXPECT_EQ(vector_configs.at(key).length(), value.length()); + EXPECT_EQ(vector_configs.at(key).bound(), value.bound()); + } +} + +TEST(WillowShellClientTest, CreateAggregationConfigDefaultBound) { + InputSpec input_spec = CreateTestInputSpecProto(); + // Clear the interval to verify that the default bound is used. + input_spec.mutable_metric_vector_specs(0) + ->mutable_domain_spec() + ->clear_interval(); + + int64_t default_bound = 12345; + SECAGG_ASSERT_OK_AND_ASSIGN( + AggregationConfigProto config, + CreateAggregationConfig(input_spec, /*key_id=*/"test", + /*max_number_of_clients=*/10, + /*max_number_of_decryptors=*/1, + /*max_decryptor_dropouts=*/0, + /*default_max_metric_value=*/default_bound)); + + const auto& vector_configs = config.vector_configs(); + ASSERT_TRUE(vector_configs.contains("metric1")); + EXPECT_EQ(vector_configs.at("metric1").bound(), default_bound); +} + +TEST(WillowShellClientTest, CreateAggregationConfigFailsOnEmptyDomain) { + InputSpec input_spec = CreateTestInputSpecProto(); + ASSERT_GT(input_spec.group_by_vector_specs_size(), 0); + + // Clear the string values to zero length. + input_spec.mutable_group_by_vector_specs(0) + ->mutable_domain_spec() + ->mutable_string_values() + ->clear_values(); + + EXPECT_THAT(CreateAggregationConfig(input_spec, "test", 10), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // Also fails if domain is not set. + input_spec.mutable_group_by_vector_specs(0)->clear_domain_spec(); + EXPECT_THAT(CreateAggregationConfig(input_spec, "test", 10), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + } // namespace } // namespace willow } // namespace secure_aggregation