Skip to content

Commit 15fa3af

Browse files
committed
Fix CR suggestions
1 parent 8b242e2 commit 15fa3af

File tree

3 files changed

+67
-59
lines changed

3 files changed

+67
-59
lines changed

discojs/src/aggregator/get.ts

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,61 +32,47 @@ export function getAggregator(
3232
options: AggregatorOptions = {},
3333
): aggregator.Aggregator {
3434
const scheme = options.scheme ?? task.trainingInformation.scheme
35+
36+
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
37+
38+
// If scheme == 'federated' then we only expect the server's contribution at each round
39+
// so we set the aggregation threshold to 1 contribution
40+
// If scheme == 'local' then we only expect our own contribution
41+
42+
const networkOptions: Required<AggregatorOptions> = {
43+
scheme,
44+
roundCutOff: 0,
45+
threshold: 1,
46+
thresholdType: scheme === "decentralized" ? "relative" : "absolute",
47+
...options, // user overrides defaults
48+
};
3549

3650
switch (task.trainingInformation.aggregationStrategy) {
3751
case 'byzantine': {
38-
const {
39-
clippingRadius = 1.0,
40-
maxIterations = 1,
41-
beta = 0.9,
52+
const {clippingRadius = 1.0, maxIterations = 1, beta = 0.9,
4253
} = task.trainingInformation;
4354

44-
if (scheme === "decentralized") {
45-
options = {
46-
roundCutOff: undefined,
47-
threshold: 1,
48-
thresholdType: "relative",
49-
...options,
50-
};
51-
} else {
52-
options = {
53-
roundCutOff: undefined,
54-
threshold: 1,
55-
thresholdType: "absolute",
56-
...options,
57-
};
58-
}
59-
6055
return new ByzantineRobustAggregator(
61-
options.roundCutOff ?? 0,
62-
options.threshold ?? 1,
63-
options.thresholdType,
56+
networkOptions.roundCutOff ?? 0,
57+
networkOptions.threshold ?? 1,
58+
networkOptions.thresholdType,
6459
clippingRadius,
6560
maxIterations,
6661
beta
6762
);
6863
}
6964
case 'mean':
70-
if (scheme === 'decentralized') {
71-
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
72-
options = {
73-
roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
74-
...options
75-
}
76-
} else {
77-
// If scheme == 'federated' then we only expect the server's contribution at each round
78-
// so we set the aggregation threshold to 1 contribution
79-
// If scheme == 'local' then we only expect our own contribution
80-
options = {
81-
roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
82-
...options
83-
}
84-
}
85-
return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType)
65+
return new aggregator.MeanAggregator(
66+
networkOptions.roundCutOff,
67+
networkOptions.threshold,
68+
networkOptions.thresholdType
69+
)
8670
case 'secure':
8771
if (scheme !== 'decentralized') {
8872
throw new Error('secure aggregation is currently supported for decentralized only')
8973
}
90-
return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue)
74+
return new aggregator.SecureAggregator(
75+
task.trainingInformation.maxShareValue
76+
)
9177
}
9278
}

discojs/src/task/training_information.ts

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ const nonLocalNetworkSchema = z.object({
2323
minNbOfParticipants: z.number().positive().int(),
2424
});
2525

26+
const byzantineSchema = z.object({
27+
aggregationStrategy: z.literal("byzantine"),
28+
clippingRadius: z.number().positive().optional().default(1.0),
29+
maxIterations: z.number().int().positive().optional().default(1),
30+
beta: z.number().min(0).max(1).optional().default(0.9),
31+
});
32+
33+
2634
export namespace TrainingInformation {
2735
export const baseSchema = z.object({
2836
// number of epochs to run training for
@@ -38,6 +46,8 @@ export namespace TrainingInformation {
3846
tensorBackend: z.enum(["gpt", "tfjs"]),
3947
});
4048

49+
50+
4151
export const dataTypeToSchema = {
4252
image: z.object({
4353
// classes, e.g. if two class of images, one with dogs and one with cats, then we would
@@ -74,12 +84,7 @@ export namespace TrainingInformation {
7484
z.object({
7585
aggregationStrategy: z.literal("mean"),
7686
}),
77-
z.object({
78-
aggregationStrategy: z.literal("byzantine"),
79-
clippingRadius: z.number().positive().optional().default(1.0),
80-
maxIterations: z.number().int().positive().optional().default(1),
81-
beta: z.number().min(0).max(1).optional().default(0.9),
82-
}),
87+
byzantineSchema,
8388
z.object({
8489
aggregationStrategy: z.literal("secure"),
8590
// Secure Aggregation: maximum absolute value of a number in a randomly generated share
@@ -98,12 +103,7 @@ export namespace TrainingInformation {
98103
z.object({
99104
aggregationStrategy: z.literal("mean"),
100105
}),
101-
z.object({
102-
aggregationStrategy: z.literal("byzantine"),
103-
clippingRadius: z.number().positive().optional().default(1.0),
104-
maxIterations: z.number().int().positive().optional().default(1),
105-
beta: z.number().min(0).max(1).optional().default(0.9),
106-
}),
106+
byzantineSchema,
107107
]),
108108
),
109109
local: z.object({

webapp/src/components/task_creation_form/TaskCreationForm.vue

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,17 @@
314314
v-model="aggregationStrategy"
315315
name="trainingInformation.aggregationStrategy"
316316
as="select"
317-
:disabled="scheme !== 'decentralized'"
317+
:disabled="scheme == 'local'"
318318
>
319-
<option value="mean">Mean</option>
320-
<option value="secure">Secure</option>
321-
<option value="byzantine">Byzantine</option>
319+
<!-- Federated supports mean & byzantine -->
320+
<option v-if="scheme === 'federated'" value="mean">Mean</option>
321+
<option v-if="scheme === 'federated'" value="byzantine">Byzantine</option>
322+
<!-- Decentralized supports mean, byzantine & secure -->
323+
<option v-if="scheme === 'decentralized'" value="mean">Mean</option>
324+
<option v-if="scheme === 'decentralized'" value="byzantine">Byzantine</option>
325+
<option v-if="scheme === 'decentralized'" value="secure">Secure</option>
326+
<!-- Local supports only mean -->
327+
<option v-if="scheme === 'local'" value="mean">Mean</option>
322328
</FormField>
323329

324330
<FormLabel
@@ -338,33 +344,33 @@
338344
<FormLabel
339345
v-show="aggregationStrategy === 'byzantine'"
340346
label="Clipping radius (λ) for Centered Clipping"
341-
type="required"
342347
>
343348
<FormField
344349
name="trainingInformation.clippingRadius"
345350
placeholder="1.0"
346351
as="input"
347352
type="number"
353+
min="0"
354+
step="0.01"
348355
/>
349356
</FormLabel>
350357

351358
<FormLabel
352359
v-show="aggregationStrategy === 'byzantine'"
353360
label="Max Centered Clipping iterations (L)"
354-
type="required"
355361
>
356362
<FormField
357363
name="trainingInformation.maxIterations"
358364
placeholder="1"
359365
as="input"
360366
type="number"
367+
min="1"
361368
/>
362369
</FormLabel>
363370

364371
<FormLabel
365372
v-show="aggregationStrategy === 'byzantine'"
366373
label="Momentum coefficient (β) to smooth the aggregation over multiple rounds"
367-
type="required"
368374
>
369375
<FormField
370376
name="trainingInformation.beta"
@@ -622,6 +628,20 @@ window.onbeforeunload = (event) => {
622628
event.preventDefault();
623629
};
624630
631+
const byzantineParams = z.object({
632+
clippingRadius: z
633+
.number()
634+
.positive("Clipping radius must be positive"),
635+
maxIterations: z
636+
.number()
637+
.int("Max iterations must be an integer")
638+
.positive("Max iterations must be > 0"),
639+
beta: z
640+
.number()
641+
.min(0, "Momentum β must be ≥ 0")
642+
.max(1, "Momentum β must be ≤ 1"),
643+
});
644+
625645
const nonLocalNetwork = {
626646
privacy: z
627647
.object({
@@ -667,6 +687,7 @@ const trainingInformationNetworks = z.union([
667687
}),
668688
z.object({
669689
aggregationStrategy: z.literal("byzantine"),
690+
...byzantineParams.shape,
670691
}),
671692
z.object({
672693
aggregationStrategy: z.literal("secure"),
@@ -686,6 +707,7 @@ const trainingInformationNetworks = z.union([
686707
}),
687708
z.object({
688709
aggregationStrategy: z.literal("byzantine"),
710+
...byzantineParams.shape,
689711
}),
690712
]),
691713
),

0 commit comments

Comments
 (0)