-
Notifications
You must be signed in to change notification settings - Fork 30
Byzantine robust aggregator - initial implementation #913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
tharvik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the great work! 🎉
a few things to move around/improve and we'll be ready to go!
can you also add a testcase to ensure that it behave well?
|
maybe later once your code is tested we could also do a small comparison to the older implementation here e53676c (that's a very old completely differently structured version of disco, but potentially achieving similar ML accuracy target under byzantine, which might be interesting to check if time remains) |
1cd555d to
40fe815
Compare
Co-authored-by: Valérian Rousset <tharvik@users.noreply.github.com>
8af37ff to
7726dbc
Compare
tharvik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
incredible work, that's so great! in addition I forgot about the MultiRoundAggregator and that's waay clearer to use this way.
can you also add a way to create a task with the byzantine aggregator? in TaskCreationForm, we need to add a way to select it, and with parameters you think are understable by laypersons
here goes a small patch for the core of disco (feel free to change it, I choose of using the default values but you know best what can be interesting to be set in the Task), adding a e2e test in server using this aggregator would also ensure that it works with the whole system.
diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts
index 23b3bc88..49550bb6 100644
--- a/discojs/src/aggregator/get.ts
+++ b/discojs/src/aggregator/get.ts
@@ -1,5 +1,6 @@
import type { DataType, Network, Task } from '../index.js'
import { aggregator } from '../index.js'
+import { ByzantineRobustAggregator } from './byzantine.js';
type AggregatorOptions = Partial<{
scheme: Task<DataType, Network>["trainingInformation"]["scheme"]; // if undefined, fallback on task.trainingInformation.scheme
@@ -33,6 +34,8 @@ export function getAggregator(
const scheme = options.scheme ?? task.trainingInformation.scheme
switch (task.trainingInformation.aggregationStrategy) {
+ case 'byzantine':
+ return new ByzantineRobustAggregator()
case 'mean':
if (scheme === 'decentralized') {
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts
index 56ca523d..ba3ace5d 100644
--- a/discojs/src/default_tasks/cifar10.ts
+++ b/discojs/src/default_tasks/cifar10.ts
@@ -35,7 +35,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
IMAGE_W: 224,
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
scheme: 'decentralized',
- aggregationStrategy: 'mean',
+ aggregationStrategy: 'byzantine',
privacy: { clippingRadius: 20, noiseScale: 1 },
minNbOfParticipants: 3,
maxShareValue: 100,
diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts
index 145fb223..391e0039 100644
--- a/discojs/src/task/training_information.ts
+++ b/discojs/src/task/training_information.ts
@@ -72,7 +72,10 @@ export namespace TrainingInformation {
.and(
z.union([
z.object({
- aggregationStrategy: z.literal("mean"),
+ aggregationStrategy: z.union([
+ z.literal("byzantine"),
+ z.literal("mean"),
+ ]),
}),
z.object({
aggregationStrategy: z.literal("secure"),
@@ -85,7 +88,10 @@ export namespace TrainingInformation {
federated: z
.object({
scheme: z.literal("federated"),
- aggregationStrategy: z.literal("mean"),
+ aggregationStrategy: z.union([
+ z.literal("mean"),
+ z.literal("byzantine"),
+ ]),
})
.merge(nonLocalNetworkSchema),
local: z.object({Co-authored-by: Valérian Rousset <5735566+tharvik@users.noreply.github.com>
…into byzantine-robust-aggregator
Co-authored-by: Valérian Rousset <5735566+tharvik@users.noreply.github.com>
0dcd3a7 to
15fa3af
Compare
tharvik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only one or two things to iron out: I'm currently unable to submit the form if I don't set the clipping radius for center clipping, dunno why but I suspect that it has smth to do with using the same name for two fields 🤷
| <FormLabel | ||
| v-model="weightClipping" | ||
| label="Weight clipping" | ||
| type="checkbox" | ||
| > | ||
| <div v-show="weightClipping" class="flex flex-col"> | ||
| <FormLabel | ||
| label="Maximum drift, measured by its norm, that can be made by the aggregated weights each round" | ||
| type="required" | ||
| > | ||
| <FormField | ||
| name="trainingInformation.privacy.clippingRadius" | ||
| placeholder="40" | ||
| as="input" | ||
| type="number" | ||
| /> | ||
| </FormLabel> | ||
| </div> | ||
| </FormLabel> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why removing this clippingRadius? there are now two types of clipping (yeah, a bit confusing, if you've a better name, go for it) no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we use other clipping anywhere. Not so sure, but I will double check.
Co-authored-by: Valérian Rousset <5735566+tharvik@users.noreply.github.com>
…into byzantine-robust-aggregator
Initial implementation of Byzantine robust aggregator with momentums and Center-Clipping aggregation.