From a9ecfc384dbc0b4fad66a6f2ac055583ef1a4c01 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 21 Apr 2025 19:54:29 +0200 Subject: [PATCH 01/18] Byzantine robust aggregator - initial implementation --- discojs/src/aggregator/byzantine.ts | 185 ++++++++++++++++++++++++++++ discojs/src/aggregator/index.ts | 1 + 2 files changed, 186 insertions(+) create mode 100644 discojs/src/aggregator/byzantine.ts diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts new file mode 100644 index 000000000..1f784b86a --- /dev/null +++ b/discojs/src/aggregator/byzantine.ts @@ -0,0 +1,185 @@ +import createDebug from "debug"; +import { List, Map } from "immutable"; + +import { AggregationStep, Aggregator } from "./aggregator.js"; +import { WeightsContainer, client } from "../index.js"; +import { aggregation } from "../index.js"; +import * as tf from '@tensorflow/tfjs' + + +const debug = createDebug("discojs:aggregator:mean"); + +type ThresholdType = 'relative' | 'absolute' + +/** + * Mean aggregator whose aggregation step consists in computing the mean of the received weights. + * + */ +export class ByzantineRobustAggregator extends Aggregator { + readonly #threshold: number; + readonly #thresholdType: ThresholdType; + private readonly clippingRadius: number; + private readonly maxIterations: number; + #minNbOfParticipants: number | undefined; + private momentumHistory: Map = Map() + + /** + * Create a mean aggregator that averages all weight updates received when a specified threshold is met. + * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that + * only accepts contributions from the current round (drops contributions from previous rounds). + * + * @param threshold - how many contributions trigger an aggregation step. + * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. + * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. 
+ * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions + * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), + * set `threshold = 1` and `thresholdType = 'absolute'` + * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, + * If `threshold != 1` then the specified thresholdType is ignored and overwritten + * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution + * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, + * @param roundCutoff - from how many past rounds do we still accept contributions. + * If 0 then only accept contributions from the current round, + * if 1 then the current round and the previous one, etc. + */ + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius: number = 1.0, maxIterations: number = 10) { + + if (threshold <= 0) throw new Error("threshold must be strictly positive"); + if (threshold > 1 && (!Number.isInteger(threshold))) + throw new Error("absolute thresholds must be integral"); + + super(roundCutoff, 1); + this.#threshold = threshold; + this.clippingRadius = clippingRadius + this.maxIterations = maxIterations + + if (threshold < 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'absolute') { + throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) + } + this.#thresholdType = 'relative' + } + else if (threshold > 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'relative') { + throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) + } + 
this.#thresholdType = 'absolute' + } + // remaining case: threshold == 1 + else { + // Print a warning regarding the default behavior when thresholdType is not specified + if (thresholdType === undefined) { + // TODO enforce validity by splitting the different threshold types into separate classes instead of warning + debug( + "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + + "To instead wait for a single contribution, set thresholdType = 'absolute'" + ) + this.#thresholdType = 'relative' + } else { + this.#thresholdType = thresholdType + } + } + } + + /** Checks whether the contributions buffer is full. */ + override isFull(): boolean { + // Make sure that we are over the minimum number of participants + // if specified + if (this.#minNbOfParticipants !== undefined && + this.nodes.size < this.#minNbOfParticipants) return false + + const thresholdValue = + this.#thresholdType == 'relative' + ? this.#threshold * this.nodes.size + : this.#threshold; + + return (this.contributions.get(0)?.size ?? 0) >= thresholdValue; + } + + set minNbOfParticipants(minNbOfParticipants: number) { + this.#minNbOfParticipants = minNbOfParticipants + } + + override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { + + this.log( + this.contributions.hasIn([0, nodeId]) ? 
AggregationStep.UPDATE : AggregationStep.ADD, + nodeId, + ); + + const beta: number = 0.9; // Momentum parameter + + // Apply worker-side momentum + + const prevMomentum: WeightsContainer = this.contributions.getIn([0, nodeId]) as WeightsContainer; + let momentum: WeightsContainer; + + if (prevMomentum) { + // m_t = (1 - beta) * grad + beta * m_{t-1} + momentum = contribution.map((g, i) => + g.mul(1 - beta).add(prevMomentum.weights[i].mul(beta)) + ); + } else { + momentum = contribution.map(g => g.mul(1 - beta)); + } + + this.contributions = this.contributions.setIn([0, nodeId], momentum); + } + + override aggregate(): WeightsContainer { + const currentContributions = this.contributions.get(0); + + if (currentContributions === undefined) + throw new Error("aggregating without any contribution"); + + this.log(AggregationStep.AGGREGATE); + + let v = aggregation.avg(currentContributions.values()); + + for (let iter = 0; iter < this.maxIterations; iter++) { + const updated = currentContributions.map(m => { + const diff = m.sub(v) + const norm = this.euclideanNorm(diff) + const scale = tf.tidy(() => + tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm)) + ) + return diff.mul(scale).add(v) + }) + + v = aggregation.avg(updated.values()); + } + + return v + } + + private euclideanNorm(w: WeightsContainer): tf.Scalar { + return tf.tidy(() => { + // Start with a scalar value of 0 + let sumSquares = tf.scalar(0); + + // Iterate through weights and accumulate sum of squares + for (const tensor of w.weights) { + // Square each tensor + const squared = tf.square(tensor); + + // Sum the squared values - convert the result to a scalar + const summed = tf.sum(squared); + + // Ensure we're adding scalars + sumSquares = tf.add(sumSquares, summed.asScalar()); + } + + // Take the square root to get the norm + return tf.sqrt(sumSquares); + }); + } + + override makePayloads( + weights: WeightsContainer, + ): Map { + // Communicate our local weights to every other node, 
be it a peer or a server + return this.nodes.toMap().map(() => weights); + } +} diff --git a/discojs/src/aggregator/index.ts b/discojs/src/aggregator/index.ts index 3310f7d16..6dca014f4 100644 --- a/discojs/src/aggregator/index.ts +++ b/discojs/src/aggregator/index.ts @@ -1,5 +1,6 @@ export { Aggregator, AggregationStep } from './aggregator.js' export { MeanAggregator } from './mean.js' export { SecureAggregator } from './secure.js' +export { ByzantineRobustAggregator } from './byzantine.js' export { getAggregator } from './get.js' \ No newline at end of file From f36ec7d74ad8b2cd279baf37a72220ee3c5421d8 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 30 Apr 2025 11:54:51 +0200 Subject: [PATCH 02/18] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset --- discojs/src/aggregator/byzantine.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 1f784b86a..224b930fa 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -42,7 +42,7 @@ export class ByzantineRobustAggregator extends Aggregator { * If 0 then only accept contributions from the current round, * if 1 then the current round and the previous one, etc. 
*/ - constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius: number = 1.0, maxIterations: number = 10) { + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius = 1.0, maxIterations = 10) { if (threshold <= 0) throw new Error("threshold must be strictly positive"); if (threshold > 1 && (!Number.isInteger(threshold))) @@ -113,13 +113,13 @@ export class ByzantineRobustAggregator extends Aggregator { // Apply worker-side momentum - const prevMomentum: WeightsContainer = this.contributions.getIn([0, nodeId]) as WeightsContainer; + const prevMomentum = this.contributions.getIn([0, nodeId]) as WeightsContainer | undefined; let momentum: WeightsContainer; if (prevMomentum) { // m_t = (1 - beta) * grad + beta * m_{t-1} - momentum = contribution.map((g, i) => - g.mul(1 - beta).add(prevMomentum.weights[i].mul(beta)) + momentum = contribution.mapWith(prevMomentum, (g, prev) => + g.mul(1 - beta).add(prev.mul(beta)) ); } else { momentum = contribution.map(g => g.mul(1 - beta)); From cdb4e7d723031553fbf6fd6f6a92c4beb08c9586 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 21 May 2025 12:38:23 +0200 Subject: [PATCH 03/18] Multiround aggregator added and mean refactored --- discojs/src/aggregator/mean.ts | 88 ++------------------------ discojs/src/aggregator/multiround.ts | 94 ++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 82 deletions(-) create mode 100644 discojs/src/aggregator/multiround.ts diff --git a/discojs/src/aggregator/mean.ts b/discojs/src/aggregator/mean.ts index cea0dba9f..10326bcff 100644 --- a/discojs/src/aggregator/mean.ts +++ b/discojs/src/aggregator/mean.ts @@ -1,114 +1,38 @@ -import createDebug from "debug"; import type { Map } from "immutable"; - -import { AggregationStep, Aggregator } from "./aggregator.js"; +import { AggregationStep } from "./aggregator.js"; +import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; import type { WeightsContainer, client 
} from "../index.js"; import { aggregation } from "../index.js"; +import createDebug from "debug" const debug = createDebug("discojs:aggregator:mean"); -type ThresholdType = 'relative' | 'absolute' /** * Mean aggregator whose aggregation step consists in computing the mean of the received weights. * */ -export class MeanAggregator extends Aggregator { - readonly #threshold: number; - readonly #thresholdType: ThresholdType; - #minNbOfParticipants: number | undefined; - +export class MeanAggregator extends MultiRoundAggregator { /** * Create a mean aggregator that averages all weight updates received when a specified threshold is met. * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that * only accepts contributions from the current round (drops contributions from previous rounds). - * - * @param threshold - how many contributions trigger an aggregation step. - * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. - * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. - * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions - * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), - * set `threshold = 1` and `thresholdType = 'absolute'` - * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, - * If `threshold != 1` then the specified thresholdType is ignored and overwritten - * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution - * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, - * @param roundCutoff - from how many past rounds do we still accept contributions. 
- * If 0 then only accept contributions from the current round, - * if 1 then the current round and the previous one, etc. */ constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) { - if (threshold <= 0) throw new Error("threshold must be strictly positive"); - if (threshold > 1 && (!Number.isInteger(threshold))) - throw new Error("absolute thresholds must be integral"); - - - super(roundCutoff, 1); - this.#threshold = threshold; - - if (threshold < 1) { - // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'absolute') { - throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) - } - this.#thresholdType = 'relative' - } - else if (threshold > 1) { - // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'relative') { - throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) - } - this.#thresholdType = 'absolute' - } - // remaining case: threshold == 1 - else { - // Print a warning regarding the default behavior when thresholdType is not specified - if (thresholdType === undefined) { - // TODO enforce validity by splitting the different threshold types into separate classes instead of warning - debug( - "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + - "To instead wait for a single contribution, set thresholdType = 'absolute'" - ) - this.#thresholdType = 'relative' - } else { - this.#thresholdType = thresholdType - } - } - } - - /** Checks whether the contributions buffer is full. 
*/ - override isFull(): boolean { - // Make sure that we are over the minimum number of participants - // if specified - if (this.#minNbOfParticipants !== undefined && - this.nodes.size < this.#minNbOfParticipants) return false - - const thresholdValue = - this.#thresholdType == 'relative' - ? this.#threshold * this.nodes.size - : this.#threshold; - - return (this.contributions.get(0)?.size ?? 0) >= thresholdValue; - } - - set minNbOfParticipants(minNbOfParticipants: number) { - this.#minNbOfParticipants = minNbOfParticipants + super(roundCutoff, threshold, thresholdType); } override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { - this.log( this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId, ); - this.contributions = this.contributions.setIn([0, nodeId], contribution); } override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); - if (currentContributions === undefined) - throw new Error("aggregating without any contribution"); + if (!currentContributions) throw new Error("aggregating without any contribution"); this.log(AggregationStep.AGGREGATE); diff --git a/discojs/src/aggregator/multiround.ts b/discojs/src/aggregator/multiround.ts new file mode 100644 index 000000000..f0131103a --- /dev/null +++ b/discojs/src/aggregator/multiround.ts @@ -0,0 +1,94 @@ +import { Aggregator, AggregationStep } from "./aggregator.js"; +import createDebug from "debug"; + +export type ThresholdType = 'relative' | 'absolute'; + +const debug = createDebug("discojs:aggregator:multiround"); + +/** + * Base class for multi-round aggregators. + * Multi-round aggregators are aggregators that wait for a certain number of contributions before aggregating. + * They can be used to implement different aggregation strategies, such as Byzantine robust aggregation or Mean Aggregator. 
+ */ +export abstract class MultiRoundAggregator extends Aggregator { + readonly #threshold: number; + readonly #thresholdType: ThresholdType; + #minNbOfParticipants: number | undefined; + + /** + * Abstract class of a multi-round aggregator that wait for a certain number of contributions before aggregating + * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that + * only accepts contributions from the current round (drops contributions from previous rounds). + * + * @param threshold - how many contributions trigger an aggregation step. + * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. + * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. + * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions + * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), + * set `threshold = 1` and `thresholdType = 'absolute'` + * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, + * If `threshold != 1` then the specified thresholdType is ignored and overwritten + * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution + * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, + * @param roundCutoff - from how many past rounds do we still accept contributions. + * If 0 then only accept contributions from the current round, + * if 1 then the current round and the previous one, etc. 
+ */ + + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) { + if (threshold <= 0) throw new Error("threshold must be strictly positive"); + if (threshold > 1 && (!Number.isInteger(threshold))) + throw new Error("absolute thresholds must be integral"); + + super(roundCutoff, 1); + this.#threshold = threshold; + + if (threshold < 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'absolute') { + throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) + } + this.#thresholdType = 'relative' + } + else if (threshold > 1) { + // Throw exception if threshold and thresholdType are conflicting + if (thresholdType === 'relative') { + throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) + } + this.#thresholdType = 'absolute' + } + // remaining case: threshold == 1 + else { + // Print a warning regarding the default behavior when thresholdType is not specified + if (thresholdType === undefined) { + // TODO enforce validity by splitting the different threshold types into separate classes instead of warning + debug( + "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + + "To instead wait for a single contribution, set thresholdType = 'absolute'" + ) + this.#thresholdType = 'relative' + } else { + this.#thresholdType = thresholdType + } + } + } + + /** Checks whether the contributions buffer is full. */ + override isFull(): boolean { + // Make sure that we are over the minimum number of participants + // if specified + if (this.#minNbOfParticipants !== undefined && + this.nodes.size < this.#minNbOfParticipants) return false; + + const thresholdValue = + this.#thresholdType == 'relative' + ? 
this.#threshold * this.nodes.size + : this.#threshold; + + return (this.contributions.get(0)?.size ?? 0) >= thresholdValue; + } + + set minNbOfParticipants(minNbOfParticipants: number) { + this.#minNbOfParticipants = minNbOfParticipants; + } +} From cb95c7468971c04cf2f3d33d1891d1a41b452658 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 27 May 2025 06:29:02 +0200 Subject: [PATCH 04/18] Added comments and descriptions, as well as history map. --- discojs/src/aggregator/byzantine.ts | 273 ++++++++++------------------ 1 file changed, 99 insertions(+), 174 deletions(-) diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 224b930fa..8a8554cbb 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -1,185 +1,110 @@ -import createDebug from "debug"; -import { List, Map } from "immutable"; - -import { AggregationStep, Aggregator } from "./aggregator.js"; -import { WeightsContainer, client } from "../index.js"; +import { Map } from "immutable"; +import * as tf from '@tensorflow/tfjs'; +import { AggregationStep } from "./aggregator.js"; +import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; +import type { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; -import * as tf from '@tensorflow/tfjs' - - -const debug = createDebug("discojs:aggregator:mean"); - -type ThresholdType = 'relative' | 'absolute' -/** - * Mean aggregator whose aggregation step consists in computing the mean of the received weights. +/** + * Byzantine-robust aggregator using Centered Clipping (CC), based on the + * "Learning from History for Byzantine Robust Optimization" paper: https://arxiv.org/abs/2012.10333 * + * This class implements a gradient aggregation rule that clips updates + * in an iterative fashion to mitigate the influence of Byzantine nodes, as well as momentum calculations. 
+ */ +export class ByzantineRobustAggregator extends MultiRoundAggregator { + private readonly clippingRadius: number; + private readonly maxIterations: number; + private readonly beta: number; + private momentums: Map = Map(); + + /** + @property clippingRadius The clipping threshold (λ) used to limit the influence of outlier updates. + * - Type: `number` + * - Determines the maximum norm allowed for the difference between a client update and the current estimate. + * - Used in the Centered Clipping step to compute a scaling factor for updates. + * - Smaller values clip more aggressively. + * - Default value is 1.0. + * + * @property maxIterations The number of iterations (L) to run the Centered Clipping update loop. + * - Type: `number` + * - Controls how many refinement steps are used to compute the final aggregate `v`. + * - Default value is 1. + * * @property beta The momentum coefficient used to smooth the aggregation over multiple rounds. + * - Type: `number` + * - Must be between 0 and 1. + * - Used to compute the exponential moving average of past aggregates (i.e., momentum vector). + * The update typically looks like: `v_t = beta * v_{t-1} + (1 - beta) * g_t`, where `g_t` is the current clipped average. + * - A higher beta gives more weight to past rounds (more smoothing), while a lower beta makes the aggregator more responsive to new updates. */ -export class ByzantineRobustAggregator extends Aggregator { - readonly #threshold: number; - readonly #thresholdType: ThresholdType; - private readonly clippingRadius: number; - private readonly maxIterations: number; - #minNbOfParticipants: number | undefined; - private momentumHistory: Map = Map() - - /** - * Create a mean aggregator that averages all weight updates received when a specified threshold is met. - * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that - * only accepts contributions from the current round (drops contributions from previous rounds). 
- * - * @param threshold - how many contributions trigger an aggregation step. - * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. - * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. - * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions - * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), - * set `threshold = 1` and `thresholdType = 'absolute'` - * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, - * If `threshold != 1` then the specified thresholdType is ignored and overwritten - * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution - * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, - * @param roundCutoff - from how many past rounds do we still accept contributions. - * If 0 then only accept contributions from the current round, - * if 1 then the current round and the previous one, etc. 
- */ - constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius = 1.0, maxIterations = 10) { - - if (threshold <= 0) throw new Error("threshold must be strictly positive"); - if (threshold > 1 && (!Number.isInteger(threshold))) - throw new Error("absolute thresholds must be integral"); - - super(roundCutoff, 1); - this.#threshold = threshold; - this.clippingRadius = clippingRadius - this.maxIterations = maxIterations - - if (threshold < 1) { - // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'absolute') { - throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) - } - this.#thresholdType = 'relative' - } - else if (threshold > 1) { - // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'relative') { - throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) - } - this.#thresholdType = 'absolute' - } - // remaining case: threshold == 1 - else { - // Print a warning regarding the default behavior when thresholdType is not specified - if (thresholdType === undefined) { - // TODO enforce validity by splitting the different threshold types into separate classes instead of warning - debug( - "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + - "To instead wait for a single contribution, set thresholdType = 'absolute'" - ) - this.#thresholdType = 'relative' - } else { - this.#thresholdType = thresholdType - } - } - } - - /** Checks whether the contributions buffer is full. 
*/ - override isFull(): boolean { - // Make sure that we are over the minimum number of participants - // if specified - if (this.#minNbOfParticipants !== undefined && - this.nodes.size < this.#minNbOfParticipants) return false - - const thresholdValue = - this.#thresholdType == 'relative' - ? this.#threshold * this.nodes.size - : this.#threshold; - - return (this.contributions.get(0)?.size ?? 0) >= thresholdValue; - } - - set minNbOfParticipants(minNbOfParticipants: number) { - this.#minNbOfParticipants = minNbOfParticipants - } - - override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { - - this.log( - this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, - nodeId, - ); - - const beta: number = 0.9; // Momentum parameter - - // Apply worker-side momentum - - const prevMomentum = this.contributions.getIn([0, nodeId]) as WeightsContainer | undefined; - let momentum: WeightsContainer; - if (prevMomentum) { - // m_t = (1 - beta) * grad + beta * m_{t-1} - momentum = contribution.mapWith(prevMomentum, (g, prev) => - g.mul(1 - beta).add(prev.mul(beta)) - ); - } else { - momentum = contribution.map(g => g.mul(1 - beta)); - } - this.contributions = this.contributions.setIn([0, nodeId], momentum); + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius = 1.0, maxIterations = 1, beta = 0.9) { + super(roundCutoff, threshold, thresholdType); + if (clippingRadius <= 0) throw new Error("Clipping radius needs to be positive number > 0."); + if (maxIterations < 1) throw new Error("There must be at least one iteration for clipping."); + if (!Number.isInteger(maxIterations)) throw new Error("Number of iterations must be intiger value."); + if ((beta < 0) || (beta > 1)) throw new Error("Beta must be between 0 and 1, since it is coeficient."); + this.clippingRadius = clippingRadius; + this.maxIterations = maxIterations; + this.beta = beta; + } + + override _add(nodeId: client.NodeID, 
contribution: WeightsContainer): void { + this.log( + this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, + nodeId, + ); + + const prevMomentum = this.momentums.get(nodeId); + const momentum = prevMomentum + ? contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) + : contribution.map(g => g.mul(1 - this.beta)); + + this.momentums = this.momentums.set(nodeId, momentum); + this.contributions = this.contributions.setIn([0, nodeId], momentum); + } + + override aggregate(): WeightsContainer { + const currentContributions = this.contributions.get(0); + if (!currentContributions) throw new Error("aggregating without any contribution"); + + this.log(AggregationStep.AGGREGATE); + + // Step 1: initialize v to average of momentum + let v = aggregation.avg(currentContributions.values()); + + // Step 2: Iterate Centered Clipping + for (let l = 0; l < this.maxIterations; l++) { + const clippedDiffs = Array.from(currentContributions.values()).map(m => { + const diff = m.sub(v); + const norm = euclideanNorm(diff); + const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); + return diff.mul(scale); + }); + + const avgClip = aggregation.avg(clippedDiffs); + v = v.add(avgClip.mul(1 / currentContributions.size)); } - override aggregate(): WeightsContainer { - const currentContributions = this.contributions.get(0); - - if (currentContributions === undefined) - throw new Error("aggregating without any contribution"); - - this.log(AggregationStep.AGGREGATE); - - let v = aggregation.avg(currentContributions.values()); + return v; + } - for (let iter = 0; iter < this.maxIterations; iter++) { - const updated = currentContributions.map(m => { - const diff = m.sub(v) - const norm = this.euclideanNorm(diff) - const scale = tf.tidy(() => - tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm)) - ) - return diff.mul(scale).add(v) - }) - - v = 
aggregation.avg(updated.values()); - } - - return v - } - - private euclideanNorm(w: WeightsContainer): tf.Scalar { - return tf.tidy(() => { - // Start with a scalar value of 0 - let sumSquares = tf.scalar(0); - - // Iterate through weights and accumulate sum of squares - for (const tensor of w.weights) { - // Square each tensor - const squared = tf.square(tensor); - - // Sum the squared values - convert the result to a scalar - const summed = tf.sum(squared); - - // Ensure we're adding scalars - sumSquares = tf.add(sumSquares, summed.asScalar()); - } - - // Take the square root to get the norm - return tf.sqrt(sumSquares); - }); - } + override makePayloads(weights: WeightsContainer): Map { + // Communicate our local weights to every other node, be it a peer or a server + return this.nodes.toMap().map(() => weights); + } +} - override makePayloads( - weights: WeightsContainer, - ): Map { - // Communicate our local weights to every other node, be it a peer or a server - return this.nodes.toMap().map(() => weights); +function euclideanNorm(w: WeightsContainer): tf.Scalar { + // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. 
+ return tf.tidy(() => { + let sumSquares = tf.scalar(0); + for (const tensor of w.weights) { + const squared = tf.square(tensor); + const summed = tf.sum(squared); + sumSquares = tf.add(sumSquares, summed.asScalar()); } -} + return tf.sqrt(sumSquares); + }); +} \ No newline at end of file From 33c120c50bd12421951c713fb179b46d9679e57f Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 11 Jun 2025 14:54:58 +0200 Subject: [PATCH 05/18] debugging tests --- discojs/src/aggregator/byzantine.spec.ts | 127 +++++++++++++++++++++++ discojs/src/aggregator/byzantine.ts | 74 ++++++++++--- discojs/src/weights/weights_container.ts | 32 +++--- 3 files changed, 206 insertions(+), 27 deletions(-) create mode 100644 discojs/src/aggregator/byzantine.spec.ts diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts new file mode 100644 index 000000000..9b3178129 --- /dev/null +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -0,0 +1,127 @@ +import { expect } from "chai"; +import { Set } from "immutable"; +import * as tf from "@tensorflow/tfjs"; + +import { WeightsContainer } from "../index.js"; +import { ByzantineRobustAggregator } from "./byzantine.js"; + +// Helper to convert WeightsContainer → number[][] for easy assertions +async function WSIntoArrays(ws: WeightsContainer): Promise { + return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); +} + +describe("ByzantineRobustAggregator", () => { + it("throws on invalid constructor parameters", () => { + expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 0, 1, 0.5)).to.throw(); + expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 0, 0.5)).to.throw(); + expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1.1, 0.5)).to.throw(); + expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1, 1.5)).to.throw(); + }); + + it("performs basic mean when clippingRadius is large and beta = 0", async () => { + const agg = new 
ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0); + const [id1, id2] = ["c1", "c2"]; + agg.setNodes(Set.of(id1, id2)); + + const p = agg.getPromiseForAggregation(); + agg.add(id1, WeightsContainer.of([1], [2]), 0); + agg.add(id2, WeightsContainer.of([3], [4]), 0); + + const out = await p; + const arr = await WSIntoArrays(out); + expect(arr).to.deep.equal([[2], [3]]); + }); + + it("clips a single outlier with small radius", async () => { + const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + const [c1, c2, bad] = ["c1", "c2", "bad"]; + agg.setNodes(Set.of(c1, c2, bad)); + + const p = agg.getPromiseForAggregation(); + agg.add(c1, WeightsContainer.of([1]), 0); + agg.add(c2, WeightsContainer.of([1]), 0); + agg.add(bad, WeightsContainer.of([100]), 0); + + const out = await p; + const arr = await WSIntoArrays(out); + expect(arr[0][0]).to.be.closeTo(1, 1e-6); + }); + + it("applies multiple clipping iterations (maxIterations > 1)", async () => { + const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1.0, 3, 0); + const [c1, bad] = ["c1", "bad"]; + agg.setNodes(Set.of(c1, bad)); + + const p = agg.getPromiseForAggregation(); + agg.add(c1, WeightsContainer.of([0]), 0); + agg.add(bad, WeightsContainer.of([10]), 0); + + const out = await p; + const arr = await WSIntoArrays(out); + expect(arr[0][0]).to.be.lessThan(1); // clipped closer to 0 + }); + + it("uses momentum when beta > 0", async () => { + const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5); + const [c1, c2] = ["c1", "c2"]; + agg.setNodes(Set.of(c1, c2)); + + const p1 = agg.getPromiseForAggregation(); + agg.add(c1, WeightsContainer.of([2]), 0); + agg.add(c2, WeightsContainer.of([2]), 0); + const out1 = await p1; + const arr1 = await WSIntoArrays(out1); + expect(arr1[0][0]).to.equal(2); + + const p2 = agg.getPromiseForAggregation(); + agg.add(c1, WeightsContainer.of([4]), 1); + agg.add(c2, WeightsContainer.of([4]), 1); + const out2 = await p2; + const arr2 = await 
WSIntoArrays(out2); + + // With momentum = 0.5, result = 0.5 * prev + 0.5 * current = 3.0 + expect(arr2[0][0]).to.be.closeTo(3, 1e-6); + }); + + it("respects roundCutoff — ignores old contributions", async () => { + const agg = new ByzantineRobustAggregator(1, 1, 'absolute', 1e6, 1, 0); + const id = "c1"; + agg.setNodes(Set.of(id)); + + // Round 0 + const p0 = agg.getPromiseForAggregation(); + agg.add(id, WeightsContainer.of([10]), 0); + const out0 = await p0; + const arr0 = await WSIntoArrays(out0); + expect(arr0[0][0]).to.equal(10); + + // Round 2 with cutoff=1 → contributions from round 0 should be discarded + const p2 = agg.getPromiseForAggregation(); + agg.add(id, WeightsContainer.of([20]), 2); + const out2 = await p2; + const arr2 = await WSIntoArrays(out2); + expect(arr2[0][0]).to.equal(20); + }); + + it("waits for minNbOfParticipants even with threshold met", async () => { + const agg = new ByzantineRobustAggregator(0, 1, 'absolute', 1e6, 1, 0); + agg.minNbOfParticipants = 2; + + const [c1, c2] = ["c1", "c2"]; + agg.setNodes(Set.of(c1, c2)); + + const p = agg.getPromiseForAggregation(); + agg.add(c1, WeightsContainer.of([5]), 0); + + // Should not emit yet: + let resolved = false; + p.then(() => (resolved = true)); + await new Promise(r => setTimeout(r, 50)); + expect(resolved).to.be.false; + + agg.add(c2, WeightsContainer.of([7]), 0); + const out = await p; + const arr = await WSIntoArrays(out); + expect(arr[0][0]).to.equal(6); + }); +}); diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 8a8554cbb..68f8fb8ff 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -65,31 +65,83 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { this.contributions = this.contributions.setIn([0, nodeId], momentum); } + // override aggregate(): WeightsContainer { + // const currentContributions = this.contributions.get(0); + // if (!currentContributions) throw new 
Error("aggregating without any contribution"); + + // this.log(AggregationStep.AGGREGATE); + + // if (!isFinite(this.clippingRadius)) { + // return aggregation.avg(currentContributions.values()); // Identity fallback for large radius + // } + + // // Step 1: initialize v to average of momentum + // let v = aggregation.avg(currentContributions.values()); + + // // Step 2: Iterate Centered Clipping + // // for (let l = 0; l < this.maxIterations; l++) { + // // const clippedDiffs = Array.from(currentContributions.values()).map(m => { + // // const diff = m.sub(v); + // // const norm = euclideanNorm(diff); + // // const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); + // // return diff.mul(scale); + // // }); + + // // const avgClip = aggregation.avg(clippedDiffs); + // // v = v.add(avgClip.mul(1 / currentContributions.size)); + // // } + + // for (let l = 0; l < this.maxIterations; l++) { + // const clippedDiffs = Array.from(currentContributions.values()).map(m => tf.tidy(() => { + // const diff = m.sub(v); + // const norm = euclideanNorm(diff); + // const scale = tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm)); + // return diff.mul(scale); + // })); + + // const avgClip = aggregation.avg(clippedDiffs); + // v = tf.tidy(() => v.add(avgClip)); + // clippedDiffs.forEach(d => d.dispose()); + // } + + // return v; + // } + override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); if (!currentContributions) throw new Error("aggregating without any contribution"); this.log(AggregationStep.AGGREGATE); - // Step 1: initialize v to average of momentum + // If clipping radius is infinite, fall back to simple mean + if (!isFinite(this.clippingRadius)) { + return aggregation.avg(currentContributions.values()); + } + + // Step 1: Initialize v to average of momentums let v = aggregation.avg(currentContributions.values()); - // Step 2: Iterate Centered Clipping + // Step 2: 
Iterative Centered ClippingF for (let l = 0; l < this.maxIterations; l++) { const clippedDiffs = Array.from(currentContributions.values()).map(m => { const diff = m.sub(v); - const norm = euclideanNorm(diff); + const norm = tf.tidy(() => euclideanNorm(diff)); const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); - return diff.mul(scale); + const clipped = diff.mul(scale); + norm.dispose(); scale.dispose(); + return clipped; }); const avgClip = aggregation.avg(clippedDiffs); - v = v.add(avgClip.mul(1 / currentContributions.size)); + const newV = v.add(avgClip); + clippedDiffs.forEach(d => d.dispose()); + v.dispose(); // Safe if v is no longer needed + v = newV; } - return v; } + override makePayloads(weights: WeightsContainer): Map { // Communicate our local weights to every other node, be it a peer or a server return this.nodes.toMap().map(() => weights); @@ -99,12 +151,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. 
return tf.tidy(() => { - let sumSquares = tf.scalar(0); - for (const tensor of w.weights) { - const squared = tf.square(tensor); - const summed = tf.sum(squared); - sumSquares = tf.add(sumSquares, summed.asScalar()); - } - return tf.sqrt(sumSquares); + const norms = w.weights.map(t => tf.sum(tf.square(t)) as tf.Scalar); + const total = norms.reduce((a, b) => tf.add(a, b)) as tf.Scalar; + return tf.sqrt(total); }); } \ No newline at end of file diff --git a/discojs/src/weights/weights_container.ts b/discojs/src/weights/weights_container.ts index 4f1a6de38..9d28d2fd5 100644 --- a/discojs/src/weights/weights_container.ts +++ b/discojs/src/weights/weights_container.ts @@ -16,12 +16,12 @@ export class WeightsContainer { * The iterable's elements can either be regular TF.js tensors or number arrays. * @param weights The weights iterable to build the weights container from */ - constructor (weights: Iterable) { + constructor(weights: Iterable) { this._weights = List(weights).map((w) => w instanceof tf.Tensor ? 
w : tf.tensor(w)) } - get weights (): Weights { + get weights(): Weights { return this._weights.toArray() } @@ -31,7 +31,7 @@ export class WeightsContainer { * @param other The other weights container * @returns A new subtracted weights container */ - add (other: WeightsContainer): WeightsContainer { + add(other: WeightsContainer): WeightsContainer { return this.mapWith(other, tf.add) } @@ -41,7 +41,7 @@ export class WeightsContainer { * @param other The other weights container * @returns A new subtracted weights container */ - sub (other: WeightsContainer): WeightsContainer { + sub(other: WeightsContainer): WeightsContainer { return this.mapWith(other, tf.sub) } @@ -51,7 +51,7 @@ export class WeightsContainer { * @param other The other weights container * @returns A new multiplied weights container */ - mul (other: TensorLike | number): WeightsContainer { + mul(other: TensorLike | number): WeightsContainer { return new WeightsContainer( this._weights .map(w => w.mul(other)) @@ -65,7 +65,7 @@ export class WeightsContainer { * @param fn The binary operator * @returns The mapping's result */ - mapWith (other: WeightsContainer, fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor): WeightsContainer { + mapWith(other: WeightsContainer, fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor): WeightsContainer { return new WeightsContainer( this._weights .zip(other._weights) @@ -73,13 +73,13 @@ export class WeightsContainer { ) } - map (fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer - map (fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer - map (fn: ((t: tf.Tensor) => tf.Tensor) | ((t: tf.Tensor, i: number) => tf.Tensor)): WeightsContainer { + map(fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer + map(fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer + map(fn: ((t: tf.Tensor) => tf.Tensor) | ((t: tf.Tensor, i: number) => tf.Tensor)): WeightsContainer { return new WeightsContainer(this._weights.map(fn)) } - reduce (fn: (acc: tf.Tensor, t: tf.Tensor) => 
tf.Tensor): tf.Tensor { + reduce(fn: (acc: tf.Tensor, t: tf.Tensor) => tf.Tensor): tf.Tensor { return this._weights.reduce(fn) } @@ -88,29 +88,33 @@ export class WeightsContainer { * @param index The tensor's index * @returns The tensor located at the index */ - get (index: number): tf.Tensor | undefined { + get(index: number): tf.Tensor | undefined { return this._weights.get(index) } - concat (other: WeightsContainer): WeightsContainer { + concat(other: WeightsContainer): WeightsContainer { return WeightsContainer.of( ...this.weights, ...other.weights ) } - equals (other: WeightsContainer, margin = 0): boolean { + equals(other: WeightsContainer, margin = 0): boolean { return this._weights .zip(other._weights) .every(([w1, w2]) => w1.sub(w2).abs().lessEqual(margin).all().dataSync()[0] === 1) } + + dispose(): void { + this._weights.forEach(w => w.dispose()); + } /** * Instantiates a new weights container from the given tensors or arrays of numbers. * @param weights The tensors or number arrays * @returns The instantiated weights container */ - static of (...weights: TensorLike[]): WeightsContainer { + static of(...weights: TensorLike[]): WeightsContainer { return new this(weights) } } From c282dc62129fba4a7072e26d7db2cfe8d676590d Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 18 Jun 2025 14:38:06 +0200 Subject: [PATCH 06/18] Byzantine correct aggreagation --- discojs/src/aggregator/byzantine.spec.ts | 22 -------- discojs/src/aggregator/byzantine.ts | 72 +++++++----------------- 2 files changed, 20 insertions(+), 74 deletions(-) diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index 9b3178129..b1250c4b8 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -102,26 +102,4 @@ describe("ByzantineRobustAggregator", () => { const arr2 = await WSIntoArrays(out2); expect(arr2[0][0]).to.equal(20); }); - - it("waits for minNbOfParticipants even with threshold met", async () 
=> { - const agg = new ByzantineRobustAggregator(0, 1, 'absolute', 1e6, 1, 0); - agg.minNbOfParticipants = 2; - - const [c1, c2] = ["c1", "c2"]; - agg.setNodes(Set.of(c1, c2)); - - const p = agg.getPromiseForAggregation(); - agg.add(c1, WeightsContainer.of([5]), 0); - - // Should not emit yet: - let resolved = false; - p.then(() => (resolved = true)); - await new Promise(r => setTimeout(r, 50)); - expect(resolved).to.be.false; - - agg.add(c2, WeightsContainer.of([7]), 0); - const out = await p; - const arr = await WSIntoArrays(out); - expect(arr[0][0]).to.equal(6); - }); }); diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 68f8fb8ff..2e17c18c4 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -2,8 +2,9 @@ import { Map } from "immutable"; import * as tf from '@tensorflow/tfjs'; import { AggregationStep } from "./aggregator.js"; import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; -import type { WeightsContainer, client } from "../index.js"; +import { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; +import { Repeat } from "immutable"; /** * Byzantine-robust aggregator using Centered Clipping (CC), based on the @@ -16,7 +17,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { private readonly clippingRadius: number; private readonly maxIterations: number; private readonly beta: number; - private momentums: Map = Map(); + private historyMomentums: Map = Map(); + private prevAggregate: WeightsContainer | null = null; /** @property clippingRadius The clipping threshold (λ) used to limit the influence of outlier updates. @@ -56,57 +58,15 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { nodeId, ); - const prevMomentum = this.momentums.get(nodeId); - const momentum = prevMomentum + const prevMomentum = this.historyMomentums.get(nodeId); + const newMomentum = prevMomentum ? 
contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) - : contribution.map(g => g.mul(1 - this.beta)); + : contribution; // no scaling on first momentum - this.momentums = this.momentums.set(nodeId, momentum); - this.contributions = this.contributions.setIn([0, nodeId], momentum); + this.historyMomentums = this.historyMomentums.set(nodeId, newMomentum); + this.contributions = this.contributions.setIn([0, nodeId], newMomentum); } - // override aggregate(): WeightsContainer { - // const currentContributions = this.contributions.get(0); - // if (!currentContributions) throw new Error("aggregating without any contribution"); - - // this.log(AggregationStep.AGGREGATE); - - // if (!isFinite(this.clippingRadius)) { - // return aggregation.avg(currentContributions.values()); // Identity fallback for large radius - // } - - // // Step 1: initialize v to average of momentum - // let v = aggregation.avg(currentContributions.values()); - - // // Step 2: Iterate Centered Clipping - // // for (let l = 0; l < this.maxIterations; l++) { - // // const clippedDiffs = Array.from(currentContributions.values()).map(m => { - // // const diff = m.sub(v); - // // const norm = euclideanNorm(diff); - // // const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); - // // return diff.mul(scale); - // // }); - - // // const avgClip = aggregation.avg(clippedDiffs); - // // v = v.add(avgClip.mul(1 / currentContributions.size)); - // // } - - // for (let l = 0; l < this.maxIterations; l++) { - // const clippedDiffs = Array.from(currentContributions.values()).map(m => tf.tidy(() => { - // const diff = m.sub(v); - // const norm = euclideanNorm(diff); - // const scale = tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm)); - // return diff.mul(scale); - // })); - - // const avgClip = aggregation.avg(clippedDiffs); - // v = tf.tidy(() => v.add(avgClip)); - // clippedDiffs.forEach(d => d.dispose()); - // } 
- - // return v; - // } - override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); if (!currentContributions) throw new Error("aggregating without any contribution"); @@ -118,9 +78,15 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { return aggregation.avg(currentContributions.values()); } - // Step 1: Initialize v to average of momentums - let v = aggregation.avg(currentContributions.values()); - + // Step 1: Initialize v to average of previous aggregations + let v: WeightsContainer; + if (this.prevAggregate) { + v = this.prevAggregate; + } else { + // Use shape of the first contribution to create zero vector + const sample = currentContributions.values().next().value; + v = sample.map((t: any) => tf.zerosLike(t)); + } // Step 2: Iterative Centered ClippingF for (let l = 0; l < this.maxIterations; l++) { const clippedDiffs = Array.from(currentContributions.values()).map(m => { @@ -138,6 +104,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { v.dispose(); // Safe if v is no longer needed v = newV; } + // Step 3: Update momentum history + this.prevAggregate = v; return v; } From eb796369a5524f3c846fbdcf345c2cb627327955 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 25 Jun 2025 15:05:14 +0200 Subject: [PATCH 07/18] Secure aggr with momentums --- discojs/src/aggregator/secure_history.spec.ts | 141 ++++++++++++++++++ discojs/src/aggregator/secure_history.ts | 43 ++++++ 2 files changed, 184 insertions(+) create mode 100644 discojs/src/aggregator/secure_history.spec.ts create mode 100644 discojs/src/aggregator/secure_history.ts diff --git a/discojs/src/aggregator/secure_history.spec.ts b/discojs/src/aggregator/secure_history.spec.ts new file mode 100644 index 000000000..941cabcae --- /dev/null +++ b/discojs/src/aggregator/secure_history.spec.ts @@ -0,0 +1,141 @@ +import { List, Set, Range, Map } from "immutable"; +import { assert, expect } from "chai"; + +import { + aggregator as 
aggregators, + aggregation, + WeightsContainer, +} from "../index.js"; + +import { SecureHistoryAggregator } from "./secure_history.js"; +import { SecureAggregator } from "./secure.js"; + +import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js"; + +describe("secure history aggregator", function () { + const epsilon = 1e-4; + + const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]); + const secrets = List.of( + WeightsContainer.of([1, 2, 3, -1], [-5, 6]), + WeightsContainer.of([2, 3, 7, 1], [-10, 5]), + WeightsContainer.of([3, 1, 5, 3], [-15, 19]), + ); + + function buildShares(): List> { + const nodes = Set(secrets.keys()).map(String); + return secrets.map((secret) => { + const aggregator = new SecureHistoryAggregator(); + aggregator.setNodes(nodes); + return aggregator.generateAllShares(secret); + }); + } + + function buildPartialSums( + allShares: List>, + ): List { + return Range(0, secrets.size) + .map((idx) => allShares.map((shares) => shares.get(idx))) + .map((shares) => aggregation.sum(shares as List)) + .toList(); + } + + it("recovers secrets from shares", () => { + const recovered = buildShares().map((shares) => aggregation.sum(shares)); + assert.isTrue( + ( + recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]> + ).every(([actual, expected]) => actual.equals(expected, epsilon)), + ); + }); + + it("aggregates partial sums with momentum smoothing", () => { + const aggregator = new SecureHistoryAggregator(100, 0.8); + const nodes = Set(secrets.keys()).map(String); + aggregator.setNodes(nodes); + + // simulate first communication round contributions (shares) + const sharesRound0 = buildShares(); + sharesRound0.forEach((shares, idx) => { + shares.forEach((share, nodeIdx) => { + aggregator.add(nodeIdx.toString(), share, 0); + }); + }); + + // aggregate round 0 sums + const sumRound0 = aggregator.aggregate(); + expect(sumRound0.equals(aggregation.sum(sharesRound0.get(0)!), epsilon)).to.be.true; + + // // simulate 
second communication round partial sums + // const partialSums = buildPartialSums(sharesRound0); + // partialSums.forEach((partialSum, nodeIdx) => { + // aggregator.add(nodeIdx.toString(), partialSum, 1); + // }); + + // // First aggregation with momentum - no previous momentum, so just average + // let agg1 = aggregator.aggregate(); + // const avgPartialSum = aggregation.avg(partialSums); + // expect(agg1.equals(avgPartialSum, epsilon)).to.be.true; + + // // Add another set of partial sums with slight modification + // const partialSums2 = partialSums.map(ws => + // ws.map(t => t.mul(1.1)) + // ); + + // partialSums2.forEach((partialSum, nodeIdx) => { + // aggregator.add(nodeIdx.toString(), partialSum, 1); + // }); + + // // Now momentum should smooth the updated average and previous aggregate + // const agg2 = aggregator.aggregate(); + + // // agg2 should be between avgPartialSum and new partial sums average weighted by beta + // const avgPartialSum2 = aggregation.avg(partialSums2); + // // expected = beta * agg1 + (1 - beta) * avgPartialSum2 + // const expectedAgg2 = agg1.mapWith(avgPartialSum2, (a, b) => + // a.mul(aggregator['beta']).add(b.mul(1 - aggregator['beta'])) + // ); + + // // Compare agg2 and expectedAgg2 elementwise + // expect(agg2.equals(expectedAgg2, epsilon)).to.be.true; + }); + + it("behaves similar to SecureAggregator without momentum (beta=0)", async () => { + class TestSecureHistoryAggregator extends SecureHistoryAggregator { + constructor() { + super(0, 0); // beta=0 disables momentum smoothing + } + } + const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing + const secureNetwork = setupNetwork(SecureAggregator); + + const secureHistoryResults = await communicate( + Map( + secureHistoryNetwork + .entrySeq() + .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); + const secureResults = await communicate( + Map( + secureNetwork + .entrySeq() 
+ .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); + + List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays))) + .flatMap((x) => x) + .flatMap((x) => x) + .zipAll( + List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays))) + .flatMap((x) => x) + .flatMap((x) => x), + ) + .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001)); + }); +}); diff --git a/discojs/src/aggregator/secure_history.ts b/discojs/src/aggregator/secure_history.ts new file mode 100644 index 000000000..46fe6d846 --- /dev/null +++ b/discojs/src/aggregator/secure_history.ts @@ -0,0 +1,43 @@ +import type { WeightsContainer, client } from "../index.js"; +import { SecureAggregator } from "./secure.js"; +import * as tf from "@tensorflow/tfjs"; +import { aggregation } from "../index.js"; + +export class SecureHistoryAggregator extends SecureAggregator { + private prevAggregate: WeightsContainer | null = null; + private readonly beta: number; + + constructor(maxShareValue = 100, beta = 0.9) { + super(maxShareValue); + this.beta = beta; + this.prevAggregate = null; + } + + override aggregate(): WeightsContainer { + // Call the base class aggregate for rounds other than 1 + if (this.communicationRound !== 1) { + return super.aggregate(); + } + + // For communication round 1, do average + momentum smoothing + const currentContributions = this.contributions.get(1); + if (!currentContributions) throw new Error("aggregating without any contribution"); + + const avg = aggregation.avg(currentContributions.values()); + + if (this.prevAggregate === null) { + this.prevAggregate = avg; + return avg; + } + + const updatedMomentum = this.prevAggregate.mapWith(avg, (prevT, currT) => + prevT.mul(this.beta).add(currT.mul(1 - this.beta)) + ); + + // Dispose old tensors to avoid memory leaks + this.prevAggregate.weights.forEach(t => t.dispose()); + this.prevAggregate = updatedMomentum; + + return 
updatedMomentum; + } +} From 96a5e4196a3801ce49e31db98af0830c725cc217 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 26 Jun 2025 14:44:01 +0200 Subject: [PATCH 08/18] Secure-history aggr, test fix --- discojs/src/aggregator/aggregator.ts | 2 +- discojs/src/aggregator/secure_history.spec.ts | 273 ++++++++++-------- 2 files changed, 148 insertions(+), 127 deletions(-) diff --git a/discojs/src/aggregator/aggregator.ts b/discojs/src/aggregator/aggregator.ts index 41611a0f4..f21aff2aa 100644 --- a/discojs/src/aggregator/aggregator.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -89,7 +89,7 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon throw new Error("Tried adding an invalid contribution. Handle this case before calling add.") // call the abstract method _add, implemented by subclasses - this._add(nodeId, contribution, communicationRound) + this._add(nodeId, contribution, communicationRound ?? this.communicationRound) // If the aggregator has enough contributions then aggregate the weights // and emit the 'aggregation' event if (this.isFull()) { diff --git a/discojs/src/aggregator/secure_history.spec.ts b/discojs/src/aggregator/secure_history.spec.ts index 941cabcae..2b7418629 100644 --- a/discojs/src/aggregator/secure_history.spec.ts +++ b/discojs/src/aggregator/secure_history.spec.ts @@ -1,141 +1,162 @@ import { List, Set, Range, Map } from "immutable"; import { assert, expect } from "chai"; +import * as tf from "@tensorflow/tfjs"; import { - aggregator as aggregators, - aggregation, - WeightsContainer, + aggregator as aggregators, + aggregation, + WeightsContainer, } from "../index.js"; -import { SecureHistoryAggregator } from "./secure_history.js"; +import { SecureHistoryAggregator } from "./secure_history.js"; import { SecureAggregator } from "./secure.js"; import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js"; -describe("secure history aggregator", function () { - const epsilon = 1e-4; - - 
const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]); - const secrets = List.of( - WeightsContainer.of([1, 2, 3, -1], [-5, 6]), - WeightsContainer.of([2, 3, 7, 1], [-10, 5]), - WeightsContainer.of([3, 1, 5, 3], [-15, 19]), - ); - - function buildShares(): List> { - const nodes = Set(secrets.keys()).map(String); - return secrets.map((secret) => { - const aggregator = new SecureHistoryAggregator(); - aggregator.setNodes(nodes); - return aggregator.generateAllShares(secret); - }); - } - - function buildPartialSums( - allShares: List>, - ): List { - return Range(0, secrets.size) - .map((idx) => allShares.map((shares) => shares.get(idx))) - .map((shares) => aggregation.sum(shares as List)) - .toList(); - } - - it("recovers secrets from shares", () => { - const recovered = buildShares().map((shares) => aggregation.sum(shares)); - assert.isTrue( - ( - recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]> - ).every(([actual, expected]) => actual.equals(expected, epsilon)), +describe("Secure history aggregator", function () { + const epsilon = 1e-4; + + const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]); + const secrets = List.of( + WeightsContainer.of([1, 2, 3, -1], [-5, 6]), + WeightsContainer.of([2, 3, 7, 1], [-10, 5]), + WeightsContainer.of([3, 1, 5, 3], [-15, 19]), ); - }); - - it("aggregates partial sums with momentum smoothing", () => { - const aggregator = new SecureHistoryAggregator(100, 0.8); - const nodes = Set(secrets.keys()).map(String); - aggregator.setNodes(nodes); - - // simulate first communication round contributions (shares) - const sharesRound0 = buildShares(); - sharesRound0.forEach((shares, idx) => { - shares.forEach((share, nodeIdx) => { - aggregator.add(nodeIdx.toString(), share, 0); - }); - }); - // aggregate round 0 sums - const sumRound0 = aggregator.aggregate(); - expect(sumRound0.equals(aggregation.sum(sharesRound0.get(0)!), epsilon)).to.be.true; - - // // simulate second communication round partial sums - // 
const partialSums = buildPartialSums(sharesRound0); - // partialSums.forEach((partialSum, nodeIdx) => { - // aggregator.add(nodeIdx.toString(), partialSum, 1); - // }); - - // // First aggregation with momentum - no previous momentum, so just average - // let agg1 = aggregator.aggregate(); - // const avgPartialSum = aggregation.avg(partialSums); - // expect(agg1.equals(avgPartialSum, epsilon)).to.be.true; - - // // Add another set of partial sums with slight modification - // const partialSums2 = partialSums.map(ws => - // ws.map(t => t.mul(1.1)) - // ); - - // partialSums2.forEach((partialSum, nodeIdx) => { - // aggregator.add(nodeIdx.toString(), partialSum, 1); - // }); - - // // Now momentum should smooth the updated average and previous aggregate - // const agg2 = aggregator.aggregate(); - - // // agg2 should be between avgPartialSum and new partial sums average weighted by beta - // const avgPartialSum2 = aggregation.avg(partialSums2); - // // expected = beta * agg1 + (1 - beta) * avgPartialSum2 - // const expectedAgg2 = agg1.mapWith(avgPartialSum2, (a, b) => - // a.mul(aggregator['beta']).add(b.mul(1 - aggregator['beta'])) - // ); - - // // Compare agg2 and expectedAgg2 elementwise - // expect(agg2.equals(expectedAgg2, epsilon)).to.be.true; - }); - - it("behaves similar to SecureAggregator without momentum (beta=0)", async () => { - class TestSecureHistoryAggregator extends SecureHistoryAggregator { - constructor() { - super(0, 0); // beta=0 disables momentum smoothing - } + function buildShares(): List> { + const nodes = Set(secrets.keys()).map(String); + return secrets.map((secret) => { + const aggregator = new SecureHistoryAggregator(); + aggregator.setNodes(nodes); + return aggregator.generateAllShares(secret); + }); } - const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing - const secureNetwork = setupNetwork(SecureAggregator); - - const secureHistoryResults = await communicate( - Map( - 
secureHistoryNetwork - .entrySeq() - .zip(Range(0, 3)) - .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), - 0, - ); - const secureResults = await communicate( - Map( - secureNetwork - .entrySeq() - .zip(Range(0, 3)) - .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), - 0, - ); - List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays))) - .flatMap((x) => x) - .flatMap((x) => x) - .zipAll( - List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays))) - .flatMap((x) => x) - .flatMap((x) => x), - ) - .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001)); - }); + function buildPartialSums( + allShares: List>, + ): List { + return Range(0, secrets.size) + .map((idx) => allShares.map((shares) => shares.get(idx))) + .map((shares) => aggregation.sum(shares as List)) + .toList(); + } + + it("recovers secrets from shares", () => { + const recovered = buildShares().map((shares) => aggregation.sum(shares)); + assert.isTrue( + ( + recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]> + ).every(([actual, expected]) => actual.equals(expected, epsilon)), + ); + }); + + it("aggregates partial sums with momentum smoothing", async () => { + const aggregator = new SecureHistoryAggregator(100, 0.8); + const nodes = Set(secrets.keys()).map(String); + aggregator.setNodes(nodes); + + // Prepare to capture aggregation result + const aggregationPromise = aggregator.getPromiseForAggregation(); + + const sharesRound0 = buildShares(); + + let partialSums = Range(0, nodes.size).map((receiverIdx) => { + const receivedShares = sharesRound0.map(shares => shares.get(receiverIdx)!); + return aggregation.sum(receivedShares as List); + }).toList(); + + // Add one total contribution per node + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 0); + }); + + const sumRound0 = await aggregationPromise; + + const 
expectedSum = aggregation.sum( + sharesRound0.flatMap(x => x) // flatten to List + ); + expect(sumRound0.equals(expectedSum, epsilon)).to.be.true; + + + // simulate second communication round partial sums + const aggregationPromise2 = aggregator.getPromiseForAggregation(); + + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 0); + }); + const sumRound1 = await aggregationPromise2; + + // First aggregation with momentum - no previous momentum, so just average + const avgPartialSum = aggregation.avg(partialSums); + expect(sumRound1.equals(avgPartialSum, epsilon)).to.be.true; + + const dummyPromise = aggregator.getPromiseForAggregation(); + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 1); // round 0 of next aggregation round + }); + await dummyPromise; + + const aggregationPromise3 = aggregator.getPromiseForAggregation(); + // Add another set of partial sums with slight modification + const partialSums2 = partialSums.map(ws => + ws.map((tensor) => tf.mul(tensor, 1.1)) + ); + + // Step 3: Add new partial sums to aggregator + partialSums2.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 1); + }); + const sumRound2 = await aggregationPromise3; + + const avgPartialSum2 = aggregation.avg(partialSums2); + const expectedSumRound2 = avgPartialSum.mapWith(avgPartialSum2, (prev, curr) => + prev.mul(0.8).add(curr.mul(0.2)) // 0.8 = beta, 0.2 = (1 - beta) + ); + + // Compare the actual result to the expected smoothed result + expect(sumRound2.equals(expectedSumRound2, 1e-3)).to.be.true; + }); + + it("behaves similar to SecureAggregator without momentum (beta=0)", async () => { + class TestSecureHistoryAggregator extends SecureHistoryAggregator { + constructor() { + super(0, 0); // beta=0 disables momentum smoothing + } + } + const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // 
beta=0 disables momentum smoothing + const secureNetwork = setupNetwork(SecureAggregator); + + const secureHistoryResults = await communicate( + Map( + secureHistoryNetwork + .entrySeq() + .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); + const secureResults = await communicate( + Map( + secureNetwork + .entrySeq() + .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); + + List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays))) + .flatMap((x) => x) + .flatMap((x) => x) + .zipAll( + List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays))) + .flatMap((x) => x) + .flatMap((x) => x), + ) + .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001)); + }); }); From 89fa8ed0fab99597b4c26d81cac449e0789d15e1 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 26 Jun 2025 17:26:12 +0200 Subject: [PATCH 09/18] Docstrings and comments added to secure history aggregator --- discojs/src/aggregator/secure_history.spec.ts | 5 +++-- discojs/src/aggregator/secure_history.ts | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/discojs/src/aggregator/secure_history.spec.ts b/discojs/src/aggregator/secure_history.spec.ts index 2b7418629..740a5fbbf 100644 --- a/discojs/src/aggregator/secure_history.spec.ts +++ b/discojs/src/aggregator/secure_history.spec.ts @@ -92,6 +92,7 @@ describe("Secure history aggregator", function () { const avgPartialSum = aggregation.avg(partialSums); expect(sumRound1.equals(avgPartialSum, epsilon)).to.be.true; + // Now we simulate a second round of aggregation with momentum smoothing const dummyPromise = aggregator.getPromiseForAggregation(); partialSums.forEach((partialSum, idx) => { const nodeId = idx.toString(); @@ -105,7 +106,7 @@ describe("Secure history aggregator", function () { ws.map((tensor) => tf.mul(tensor, 1.1)) ); - // Step 3: Add new 
partial sums to aggregator + // Add the modified partial sums to the aggregator partialSums2.forEach((partialSum, idx) => { const nodeId = idx.toString(); aggregator.add(nodeId, partialSum, 1); @@ -117,7 +118,7 @@ describe("Secure history aggregator", function () { prev.mul(0.8).add(curr.mul(0.2)) // 0.8 = beta, 0.2 = (1 - beta) ); - // Compare the actual result to the expected smoothed result + // Compare the actual result to the expected smoothed result using momentum expect(sumRound2.equals(expectedSumRound2, 1e-3)).to.be.true; }); diff --git a/discojs/src/aggregator/secure_history.ts b/discojs/src/aggregator/secure_history.ts index 46fe6d846..59c28381b 100644 --- a/discojs/src/aggregator/secure_history.ts +++ b/discojs/src/aggregator/secure_history.ts @@ -3,10 +3,30 @@ import { SecureAggregator } from "./secure.js"; import * as tf from "@tensorflow/tfjs"; import { aggregation } from "../index.js"; +/** + * Aggregator that implements secure multi-party computation with history-based momentum smoothing. + * It aggregates contributions in two communication rounds: + * - In the first round, nodes send their secret shares to each other. + * - In the second round, they sum their received shares and communicate the result. + * Finally, nodes average the received partial sums to establish the aggregation result. + * This aggregator also applies momentum smoothing based on the previous aggregation result. + * It uses a beta parameter to control the smoothing effect. + * The first aggregation round uses the average of contributions, while subsequent rounds apply momentum smoothing. + * This allows for a more stable aggregation result over time, reducing the impact of outliers. 
+ * * @extends SecureAggregator + * * @example + * const aggregator = new SecureHistoryAggregator(100, 0.9); + */ + export class SecureHistoryAggregator extends SecureAggregator { private prevAggregate: WeightsContainer | null = null; private readonly beta: number; + /** + * @param maxShareValue - The maximum value for each share. + * @param beta - The momentum smoothing factor (0 < beta < 1). + */ + constructor(maxShareValue = 100, beta = 0.9) { super(maxShareValue); this.beta = beta; From 7726dbc584f31b827621805c4779d73d88f273c8 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 6 Nov 2025 01:15:58 +0100 Subject: [PATCH 10/18] Fix linter errors on Secure-history and Byzantine aggregator --- discojs/src/aggregator/byzantine.spec.ts | 4 +--- discojs/src/aggregator/byzantine.ts | 9 ++++----- discojs/src/aggregator/mean.ts | 4 ---- discojs/src/aggregator/multiround.ts | 2 +- discojs/src/aggregator/secure_history.spec.ts | 18 ++++-------------- discojs/src/aggregator/secure_history.ts | 3 +-- 6 files changed, 11 insertions(+), 29 deletions(-) diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index b1250c4b8..d300fbb4c 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -1,6 +1,5 @@ -import { expect } from "chai"; import { Set } from "immutable"; -import * as tf from "@tensorflow/tfjs"; +import { describe, expect, it } from "vitest"; import { WeightsContainer } from "../index.js"; import { ByzantineRobustAggregator } from "./byzantine.js"; @@ -88,7 +87,6 @@ describe("ByzantineRobustAggregator", () => { const id = "c1"; agg.setNodes(Set.of(id)); - // Round 0 const p0 = agg.getPromiseForAggregation(); agg.add(id, WeightsContainer.of([10]), 0); const out0 = await p0; diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 2e17c18c4..d989b8a79 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ 
-4,7 +4,6 @@ import { AggregationStep } from "./aggregator.js"; import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; import { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; -import { Repeat } from "immutable"; /** * Byzantine-robust aggregator using Centered Clipping (CC), based on the @@ -84,8 +83,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { v = this.prevAggregate; } else { // Use shape of the first contribution to create zero vector - const sample = currentContributions.values().next().value; - v = sample.map((t: any) => tf.zerosLike(t)); + const sample = currentContributions.values().next().value as WeightsContainer; + v = sample.map((t: tf.Tensor) => tf.zerosLike(t)); } // Step 2: Iterative Centered ClippingF for (let l = 0; l < this.maxIterations; l++) { @@ -119,8 +118,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. 
return tf.tidy(() => { - const norms = w.weights.map(t => tf.sum(tf.square(t)) as tf.Scalar); - const total = norms.reduce((a, b) => tf.add(a, b)) as tf.Scalar; + const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); + const total = norms.reduce((a, b) => tf.add(a, b)); return tf.sqrt(total); }); } \ No newline at end of file diff --git a/discojs/src/aggregator/mean.ts b/discojs/src/aggregator/mean.ts index 10326bcff..eb24d370a 100644 --- a/discojs/src/aggregator/mean.ts +++ b/discojs/src/aggregator/mean.ts @@ -3,10 +3,6 @@ import { AggregationStep } from "./aggregator.js"; import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; import type { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; -import createDebug from "debug" - -const debug = createDebug("discojs:aggregator:mean"); - /** * Mean aggregator whose aggregation step consists in computing the mean of the received weights. diff --git a/discojs/src/aggregator/multiround.ts b/discojs/src/aggregator/multiround.ts index f0131103a..ff6d4f021 100644 --- a/discojs/src/aggregator/multiround.ts +++ b/discojs/src/aggregator/multiround.ts @@ -1,4 +1,4 @@ -import { Aggregator, AggregationStep } from "./aggregator.js"; +import { Aggregator } from "./aggregator.js"; import createDebug from "debug"; export type ThresholdType = 'relative' | 'absolute'; diff --git a/discojs/src/aggregator/secure_history.spec.ts b/discojs/src/aggregator/secure_history.spec.ts index 740a5fbbf..aa9a56776 100644 --- a/discojs/src/aggregator/secure_history.spec.ts +++ b/discojs/src/aggregator/secure_history.spec.ts @@ -1,9 +1,9 @@ import { List, Set, Range, Map } from "immutable"; -import { assert, expect } from "chai"; +import { describe, expect, it, assert } from "vitest"; + import * as tf from "@tensorflow/tfjs"; import { - aggregator as aggregators, aggregation, WeightsContainer, } from "../index.js"; @@ -16,7 +16,6 @@ import { wsIntoArrays, communicate, setupNetwork } 
from "../aggregator.spec.js"; describe("Secure history aggregator", function () { const epsilon = 1e-4; - const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]); const secrets = List.of( WeightsContainer.of([1, 2, 3, -1], [-5, 6]), WeightsContainer.of([2, 3, 7, 1], [-10, 5]), @@ -32,15 +31,6 @@ describe("Secure history aggregator", function () { }); } - function buildPartialSums( - allShares: List>, - ): List { - return Range(0, secrets.size) - .map((idx) => allShares.map((shares) => shares.get(idx))) - .map((shares) => aggregation.sum(shares as List)) - .toList(); - } - it("recovers secrets from shares", () => { const recovered = buildShares().map((shares) => aggregation.sum(shares)); assert.isTrue( @@ -60,9 +50,9 @@ describe("Secure history aggregator", function () { const sharesRound0 = buildShares(); - let partialSums = Range(0, nodes.size).map((receiverIdx) => { + const partialSums = Range(0, nodes.size).map((receiverIdx) => { const receivedShares = sharesRound0.map(shares => shares.get(receiverIdx)!); - return aggregation.sum(receivedShares as List); + return aggregation.sum(receivedShares); }).toList(); // Add one total contribution per node diff --git a/discojs/src/aggregator/secure_history.ts b/discojs/src/aggregator/secure_history.ts index 59c28381b..a24b5acab 100644 --- a/discojs/src/aggregator/secure_history.ts +++ b/discojs/src/aggregator/secure_history.ts @@ -1,6 +1,5 @@ -import type { WeightsContainer, client } from "../index.js"; +import type { WeightsContainer } from "../index.js"; import { SecureAggregator } from "./secure.js"; -import * as tf from "@tensorflow/tfjs"; import { aggregation } from "../index.js"; /** From c2f5e663700bb7205eae088e857f4e08e0138264 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 24 Nov 2025 12:39:19 +0100 Subject: [PATCH 11/18] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset 
<5735566+tharvik@users.noreply.github.com> --- discojs/src/aggregator/byzantine.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index d989b8a79..64d5cbe43 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -44,7 +44,7 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { super(roundCutoff, threshold, thresholdType); if (clippingRadius <= 0) throw new Error("Clipping radius needs to be positive number > 0."); if (maxIterations < 1) throw new Error("There must be at least one iteration for clipping."); - if (!Number.isInteger(maxIterations)) throw new Error("Number of iterations must be intiger value."); + if (!Number.isInteger(maxIterations)) throw new Error("Number of iterations must be an integer."); if ((beta < 0) || (beta > 1)) throw new Error("Beta must be between 0 and 1, since it is coeficient."); this.clippingRadius = clippingRadius; this.maxIterations = maxIterations; @@ -83,10 +83,11 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { v = this.prevAggregate; } else { // Use shape of the first contribution to create zero vector - const sample = currentContributions.values().next().value as WeightsContainer; - v = sample.map((t: tf.Tensor) => tf.zerosLike(t)); + const first = currentContributions.values().next(); + if (first.done) throw new Error("zero sized contribution") + v = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); } - // Step 2: Iterative Centered ClippingF + // Step 2: Iterative Centered Clipping for (let l = 0; l < this.maxIterations; l++) { const clippedDiffs = Array.from(currentContributions.values()).map(m => { const diff = m.sub(v); From 34ab126858b5b6c0bbc58d9bcdbb483c7ae4f40f Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 27 Nov 2025 23:54:03 +0100 Subject: [PATCH 12/18] Patch GUI and core to support Byzantine aggregator --- 
discojs/src/aggregator/get.ts | 33 ++++++++ discojs/src/default_tasks/cifar10.ts | 5 +- discojs/src/task/training_information.ts | 22 ++++- .../task_creation_form/TaskCreationForm.vue | 81 ++++++++++++++----- 4 files changed, 116 insertions(+), 25 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index 23b3bc884..b3ad00ec3 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -1,5 +1,6 @@ import type { DataType, Network, Task } from '../index.js' import { aggregator } from '../index.js' +import { ByzantineRobustAggregator } from './byzantine.js'; type AggregatorOptions = Partial<{ scheme: Task["trainingInformation"]["scheme"]; // if undefined, fallback on task.trainingInformation.scheme @@ -33,6 +34,38 @@ export function getAggregator( const scheme = options.scheme ?? task.trainingInformation.scheme switch (task.trainingInformation.aggregationStrategy) { + case 'byzantine': { + const { + clippingRadius = 1.0, + maxIterations = 1, + beta = 0.9, + } = task.trainingInformation as any; + + if (scheme === "decentralized") { + options = { + roundCutOff: undefined, + threshold: 1, + thresholdType: "relative", + ...options, + }; + } else { + options = { + roundCutOff: undefined, + threshold: 1, + thresholdType: "absolute", + ...options, + }; + } + + return new ByzantineRobustAggregator( + options.roundCutOff ?? 0, + options.threshold ?? 
1, + options.thresholdType, + clippingRadius, + maxIterations, + beta + ); + } case 'mean': if (scheme === 'decentralized') { // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 56ca523d9..9a3ac0e98 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -35,7 +35,10 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { IMAGE_W: 224, LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], scheme: 'decentralized', - aggregationStrategy: 'mean', + aggregationStrategy: 'byzantine', + clippingRadius: 1.0, + maxIterations: 1, + beta: 0.9, privacy: { clippingRadius: 20, noiseScale: 1 }, minNbOfParticipants: 3, maxShareValue: 100, diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index 145fb2238..992ec25a5 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -74,6 +74,12 @@ export namespace TrainingInformation { z.object({ aggregationStrategy: z.literal("mean"), }), + z.object({ + aggregationStrategy: z.literal("byzantine"), + clippingRadius: z.number().positive().optional().default(1.0), + maxIterations: z.number().int().positive().optional().default(1), + beta: z.number().min(0).max(1).optional().default(0.9), + }), z.object({ aggregationStrategy: z.literal("secure"), // Secure Aggregation: maximum absolute value of a number in a randomly generated share @@ -85,9 +91,21 @@ export namespace TrainingInformation { federated: z .object({ scheme: z.literal("federated"), - aggregationStrategy: z.literal("mean"), }) - .merge(nonLocalNetworkSchema), + .merge(nonLocalNetworkSchema) + .and( + z.union([ + z.object({ + aggregationStrategy: z.literal("mean"), + }), + z.object({ + aggregationStrategy: z.literal("byzantine"), + 
clippingRadius: z.number().positive().optional().default(1.0), + maxIterations: z.number().int().positive().optional().default(1), + beta: z.number().min(0).max(1).optional().default(0.9), + }), + ]), + ), local: z.object({ scheme: z.literal("local"), aggregationStrategy: z.literal("mean"), diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 56c286dbd..280ca4863 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -318,6 +318,7 @@ > + + + + + + + + + + + + + + + - - -
- - - -
-
@@ -642,6 +667,9 @@ const trainingInformationNetworks = z.union([ z.object({ aggregationStrategy: z.literal("mean"), }), + z.object({ + aggregationStrategy: z.literal("byzantine"), + }), z.object({ aggregationStrategy: z.literal("secure"), maxShareValue: z.number().positive().int(), @@ -651,9 +679,18 @@ const trainingInformationNetworks = z.union([ z .object({ scheme: z.literal("federated"), - aggregationStrategy: z.literal("mean"), }) - .merge(nonLocalNetworkSchema), + .merge(nonLocalNetworkSchema) + .and( + z.union([ + z.object({ + aggregationStrategy: z.literal("mean"), + }), + z.object({ + aggregationStrategy: z.literal("byzantine"), + }), + ]), + ), z.object({ scheme: z.literal("local"), aggregationStrategy: z.literal("mean"), From f15443e76d1e6c5b2a8d4dca0f9d87b0c3b85d38 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 28 Nov 2025 00:08:38 +0100 Subject: [PATCH 13/18] Correct federated part in Task creation form --- webapp/src/components/task_creation_form/TaskCreationForm.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index aafc3dc65..93c842d9f 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -678,8 +678,8 @@ const trainingInformationNetworks = z.union([ z .object({ scheme: z.literal("federated"), + ...nonLocalNetwork, }) - .merge(nonLocalNetworkSchema) .and( z.union([ z.object({ From 495b9debcddb08bf67a992bbf3add84244e1b898 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 28 Nov 2025 00:22:51 +0100 Subject: [PATCH 14/18] Fix linter errors --- discojs/src/aggregator/get.ts | 6 +++++- .../src/components/task_creation_form/TaskCreationForm.vue | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index b3ad00ec3..33e956d30 100644 --- 
a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -39,7 +39,11 @@ export function getAggregator( clippingRadius = 1.0, maxIterations = 1, beta = 0.9, - } = task.trainingInformation as any; + }:{ + clippingRadius?: number; + maxIterations?: number; + beta?: number; + } = task.trainingInformation; if (scheme === "decentralized") { options = { diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 93c842d9f..378dec658 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -593,7 +593,6 @@ const dataType = ref("image"); const scheme = ref("federated"); const aggregationStrategy = ref("mean"); const differentialPrivacy = ref(false); -const weightClipping = ref(false); const form = useTemplateRef("form"); From 8b242e2e164f449098ab8c2326e3693bb181400a Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Sun, 7 Dec 2025 18:00:14 +0100 Subject: [PATCH 15/18] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset <5735566+tharvik@users.noreply.github.com> --- discojs/src/aggregator/get.ts | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index 33e956d30..7c483cbdf 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -39,11 +39,7 @@ export function getAggregator( clippingRadius = 1.0, maxIterations = 1, beta = 0.9, - }:{ - clippingRadius?: number; - maxIterations?: number; - beta?: number; - } = task.trainingInformation; + } = task.trainingInformation; if (scheme === "decentralized") { options = { From 15fa3afe2ef3a497ba4548193422257826c12995 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 8 Dec 2025 16:09:33 +0100 Subject: [PATCH 16/18] Fix CR suggestions --- 
discojs/src/aggregator/get.ts | 66 ++++++++----------- discojs/src/task/training_information.ts | 24 +++---- .../task_creation_form/TaskCreationForm.vue | 36 ++++++++-- 3 files changed, 67 insertions(+), 59 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index 7c483cbdf..11b8cc3a0 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -32,61 +32,47 @@ export function getAggregator( options: AggregatorOptions = {}, ): aggregator.Aggregator { const scheme = options.scheme ?? task.trainingInformation.scheme + + // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% + + // If scheme == 'federated' then we only expect the server's contribution at each round + // so we set the aggregation threshold to 1 contribution + // If scheme == 'local' then we only expect our own contribution + + const networkOptions: Required = { + scheme, + roundCutOff: 0, + threshold: 1, + thresholdType: scheme === "decentralized" ? "relative" : "absolute", + ...options, // user overrides defaults + }; switch (task.trainingInformation.aggregationStrategy) { case 'byzantine': { - const { - clippingRadius = 1.0, - maxIterations = 1, - beta = 0.9, + const {clippingRadius = 1.0, maxIterations = 1, beta = 0.9, } = task.trainingInformation; - if (scheme === "decentralized") { - options = { - roundCutOff: undefined, - threshold: 1, - thresholdType: "relative", - ...options, - }; - } else { - options = { - roundCutOff: undefined, - threshold: 1, - thresholdType: "absolute", - ...options, - }; - } - return new ByzantineRobustAggregator( - options.roundCutOff ?? 0, - options.threshold ?? 1, - options.thresholdType, + networkOptions.roundCutOff ?? 0, + networkOptions.threshold ?? 
1, + networkOptions.thresholdType, clippingRadius, maxIterations, beta ); } case 'mean': - if (scheme === 'decentralized') { - // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% - options = { - roundCutOff: undefined, threshold: 1, thresholdType: 'relative', - ...options - } - } else { - // If scheme == 'federated' then we only expect the server's contribution at each round - // so we set the aggregation threshold to 1 contribution - // If scheme == 'local' then we only expect our own contribution - options = { - roundCutOff: undefined, threshold: 1, thresholdType: 'absolute', - ...options - } - } - return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType) + return new aggregator.MeanAggregator( + networkOptions.roundCutOff, + networkOptions.threshold, + networkOptions.thresholdType + ) case 'secure': if (scheme !== 'decentralized') { throw new Error('secure aggregation is currently supported for decentralized only') } - return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue) + return new aggregator.SecureAggregator( + task.trainingInformation.maxShareValue + ) } } diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index d52591de9..42f74e1c0 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -23,6 +23,14 @@ const nonLocalNetworkSchema = z.object({ minNbOfParticipants: z.number().positive().int(), }); +const byzantineSchema = z.object({ + aggregationStrategy: z.literal("byzantine"), + clippingRadius: z.number().positive().optional().default(1.0), + maxIterations: z.number().int().positive().optional().default(1), + beta: z.number().min(0).max(1).optional().default(0.9), + }); + + export namespace TrainingInformation { export const baseSchema = z.object({ // number of epochs to run training for @@ -38,6 +46,8 @@ export namespace 
TrainingInformation { tensorBackend: z.enum(["gpt", "tfjs"]), }); + + export const dataTypeToSchema = { image: z.object({ // classes, e.g. if two class of images, one with dogs and one with cats, then we would @@ -74,12 +84,7 @@ export namespace TrainingInformation { z.object({ aggregationStrategy: z.literal("mean"), }), - z.object({ - aggregationStrategy: z.literal("byzantine"), - clippingRadius: z.number().positive().optional().default(1.0), - maxIterations: z.number().int().positive().optional().default(1), - beta: z.number().min(0).max(1).optional().default(0.9), - }), + byzantineSchema, z.object({ aggregationStrategy: z.literal("secure"), // Secure Aggregation: maximum absolute value of a number in a randomly generated share @@ -98,12 +103,7 @@ export namespace TrainingInformation { z.object({ aggregationStrategy: z.literal("mean"), }), - z.object({ - aggregationStrategy: z.literal("byzantine"), - clippingRadius: z.number().positive().optional().default(1.0), - maxIterations: z.number().int().positive().optional().default(1), - beta: z.number().min(0).max(1).optional().default(0.9), - }), + byzantineSchema, ]), ), local: z.object({ diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 378dec658..600fbeee3 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -314,11 +314,17 @@ v-model="aggregationStrategy" name="trainingInformation.aggregationStrategy" as="select" - :disabled="scheme !== 'decentralized'" + :disabled="scheme == 'local'" > - - - + + + + + + + + + { event.preventDefault(); }; +const byzantineParams = z.object({ + clippingRadius: z + .number() + .positive("Clipping radius must be positive"), + maxIterations: z + .number() + .int("Max iterations must be an integer") + .positive("Max iterations must be > 0"), + beta: z + .number() + .min(0, "Momentum β must be ≥ 0") + 
.max(1, "Momentum β must be ≤ 1"), +}); + const nonLocalNetwork = { privacy: z .object({ @@ -667,6 +687,7 @@ const trainingInformationNetworks = z.union([ }), z.object({ aggregationStrategy: z.literal("byzantine"), + ...byzantineParams.shape, }), z.object({ aggregationStrategy: z.literal("secure"), @@ -686,6 +707,7 @@ const trainingInformationNetworks = z.union([ }), z.object({ aggregationStrategy: z.literal("byzantine"), + ...byzantineParams.shape, }), ]), ), From e86a9f127035d6d3955caeb1bc69c1d3d54442ca Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 16 Dec 2025 14:16:46 +0100 Subject: [PATCH 17/18] Update discojs/src/aggregator/get.ts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rousset <5735566+tharvik@users.noreply.github.com> --- discojs/src/aggregator/get.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index 11b8cc3a0..f95d64518 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -53,8 +53,8 @@ export function getAggregator( } = task.trainingInformation; return new ByzantineRobustAggregator( - networkOptions.roundCutOff ?? 0, - networkOptions.threshold ?? 
1, + networkOptions.roundCutOff, + networkOptions.threshold, networkOptions.thresholdType, clippingRadius, maxIterations, From ea67e5a1ea2bdfdfe4c8d9a539697f17a02f7acc Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 25 Dec 2025 13:26:05 +0100 Subject: [PATCH 18/18] Fix clipping radius repeating param --- discojs/src/aggregator/get.ts | 4 +-- discojs/src/default_tasks/cifar10.ts | 2 +- discojs/src/task/training_information.ts | 2 +- .../task_creation_form/TaskCreationForm.vue | 36 ++++++++++++++++--- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index f95d64518..ba3bd6fcd 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -49,14 +49,14 @@ export function getAggregator( switch (task.trainingInformation.aggregationStrategy) { case 'byzantine': { - const {clippingRadius = 1.0, maxIterations = 1, beta = 0.9, + const {byzantineClippingRadius = 1.0, maxIterations = 1, beta = 0.9, } = task.trainingInformation; return new ByzantineRobustAggregator( networkOptions.roundCutOff, networkOptions.threshold, networkOptions.thresholdType, - clippingRadius, + byzantineClippingRadius, maxIterations, beta ); diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 9a3ac0e98..ccc2dda38 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -36,7 +36,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], scheme: 'decentralized', aggregationStrategy: 'byzantine', - clippingRadius: 1.0, + byzantineClippingRadius: 10.0, maxIterations: 1, beta: 0.9, privacy: { clippingRadius: 20, noiseScale: 1 }, diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index 42f74e1c0..75c4ba865 100644 --- a/discojs/src/task/training_information.ts +++ 
b/discojs/src/task/training_information.ts @@ -25,7 +25,7 @@ const nonLocalNetworkSchema = z.object({ const byzantineSchema = z.object({ aggregationStrategy: z.literal("byzantine"), - clippingRadius: z.number().positive().optional().default(1.0), + byzantineClippingRadius: z.number().positive().optional().default(1.0), maxIterations: z.number().int().positive().optional().default(1), beta: z.number().min(0).max(1).optional().default(0.9), }); diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 600fbeee3..a37b2b060 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -346,7 +346,7 @@ label="Clipping radius (λ) for Centered Clipping" > + +
+ + + +
+
@@ -599,6 +618,7 @@ const dataType = ref("image"); const scheme = ref("federated"); const aggregationStrategy = ref("mean"); const differentialPrivacy = ref(false); +const weightClipping = ref(false); const form = useTemplateRef("form"); @@ -629,17 +649,23 @@ window.onbeforeunload = (event) => { }; const byzantineParams = z.object({ - clippingRadius: z + byzantineClippingRadius: z .number() - .positive("Clipping radius must be positive"), + .positive("Clipping radius must be positive") + .optional() + .default(1.0), maxIterations: z .number() .int("Max iterations must be an integer") - .positive("Max iterations must be > 0"), + .positive("Max iterations must be > 0") + .optional() + .default(1), beta: z .number() .min(0, "Momentum β must be ≥ 0") - .max(1, "Momentum β must be ≤ 1"), + .max(1, "Momentum β must be ≤ 1") + .optional() + .default(0.9), }); const nonLocalNetwork = {