-
Notifications
You must be signed in to change notification settings - Fork 30
Byzantine robust aggregator - initial implementation #913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mina5rovic
wants to merge
22
commits into
develop
Choose a base branch
from
byzantine-robust-aggregator
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
a9ecfc3
Byzantine robust aggregator - initial implementation
mina5rovic f36ec7d
Apply suggestions from code review
mina5rovic cdb4e7d
Multiround aggregator added and mean refactored
mina5rovic cb95c74
Added comments and descriptions, as well as history map.
mina5rovic 33c120c
debugging tests
mina5rovic c282dc6
Byzantine correct aggreagation
mina5rovic eb79636
Secure aggr with momentums
mina5rovic 96a5e41
Secure-history aggr, test fix
mina5rovic 89fa8ed
Docstrings and comments added to secure history aggregator
mina5rovic 7726dbc
Fix linter errors on Secure-history and Byzantine aggregator
mina5rovic c2f5e66
Apply suggestions from code review
mina5rovic 34ab126
Patch GUI and core to support Byzantine aggregator
mina5rovic 3a44bb3
Merge branch 'byzantine-robust-aggregator' of github.com:epfml/disco …
mina5rovic 45410f9
Merge branch 'develop' into byzantine-robust-aggregator
mina5rovic f15443e
Correct federated mart in Task creation form
mina5rovic 495b9de
Fix linter errors
mina5rovic 8b242e2
Apply suggestions from code review
mina5rovic 15fa3af
Fix CR suggestions
mina5rovic 2de3f82
Merge branch 'develop' of github.com:epfml/disco into byzantine-robus…
mina5rovic e86a9f1
Update discojs/src/aggregator/get.ts
mina5rovic 32a3a26
Merge branch 'byzantine-robust-aggregator' of github.com:epfml/disco …
mina5rovic ea67e5a
Fix clipping radius repeating param
mina5rovic File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| import { Set } from "immutable"; | ||
| import { describe, expect, it } from "vitest"; | ||
|
|
||
| import { WeightsContainer } from "../index.js"; | ||
| import { ByzantineRobustAggregator } from "./byzantine.js"; | ||
|
|
||
| // Helper to convert WeightsContainer → number[][] for easy assertions | ||
| async function WSIntoArrays(ws: WeightsContainer): Promise<number[][]> { | ||
| return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); | ||
| } | ||
|
|
||
| describe("ByzantineRobustAggregator", () => { | ||
| it("throws on invalid constructor parameters", () => { | ||
| expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 0, 1, 0.5)).to.throw(); | ||
| expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 0, 0.5)).to.throw(); | ||
| expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1.1, 0.5)).to.throw(); | ||
| expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1, 1.5)).to.throw(); | ||
| }); | ||
|
|
||
| it("performs basic mean when clippingRadius is large and beta = 0", async () => { | ||
| const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0); | ||
| const [id1, id2] = ["c1", "c2"]; | ||
| agg.setNodes(Set.of(id1, id2)); | ||
|
|
||
| const p = agg.getPromiseForAggregation(); | ||
| agg.add(id1, WeightsContainer.of([1], [2]), 0); | ||
| agg.add(id2, WeightsContainer.of([3], [4]), 0); | ||
|
|
||
| const out = await p; | ||
| const arr = await WSIntoArrays(out); | ||
| expect(arr).to.deep.equal([[2], [3]]); | ||
| }); | ||
|
|
||
| it("clips a single outlier with small radius", async () => { | ||
| const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); | ||
| const [c1, c2, bad] = ["c1", "c2", "bad"]; | ||
| agg.setNodes(Set.of(c1, c2, bad)); | ||
|
|
||
| const p = agg.getPromiseForAggregation(); | ||
| agg.add(c1, WeightsContainer.of([1]), 0); | ||
| agg.add(c2, WeightsContainer.of([1]), 0); | ||
| agg.add(bad, WeightsContainer.of([100]), 0); | ||
|
|
||
| const out = await p; | ||
| const arr = await WSIntoArrays(out); | ||
| expect(arr[0][0]).to.be.closeTo(1, 1e-6); | ||
| }); | ||
|
|
||
| it("applies multiple clipping iterations (maxIterations > 1)", async () => { | ||
| const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1.0, 3, 0); | ||
| const [c1, bad] = ["c1", "bad"]; | ||
| agg.setNodes(Set.of(c1, bad)); | ||
|
|
||
| const p = agg.getPromiseForAggregation(); | ||
| agg.add(c1, WeightsContainer.of([0]), 0); | ||
| agg.add(bad, WeightsContainer.of([10]), 0); | ||
|
|
||
| const out = await p; | ||
| const arr = await WSIntoArrays(out); | ||
| expect(arr[0][0]).to.be.lessThan(1); // clipped closer to 0 | ||
| }); | ||
|
|
||
| it("uses momentum when beta > 0", async () => { | ||
| const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5); | ||
| const [c1, c2] = ["c1", "c2"]; | ||
| agg.setNodes(Set.of(c1, c2)); | ||
|
|
||
| const p1 = agg.getPromiseForAggregation(); | ||
| agg.add(c1, WeightsContainer.of([2]), 0); | ||
| agg.add(c2, WeightsContainer.of([2]), 0); | ||
| const out1 = await p1; | ||
| const arr1 = await WSIntoArrays(out1); | ||
| expect(arr1[0][0]).to.equal(2); | ||
|
|
||
| const p2 = agg.getPromiseForAggregation(); | ||
| agg.add(c1, WeightsContainer.of([4]), 1); | ||
| agg.add(c2, WeightsContainer.of([4]), 1); | ||
| const out2 = await p2; | ||
| const arr2 = await WSIntoArrays(out2); | ||
|
|
||
| // With momentum = 0.5, result = 0.5 * prev + 0.5 * current = 3.0 | ||
| expect(arr2[0][0]).to.be.closeTo(3, 1e-6); | ||
| }); | ||
|
|
||
| it("respects roundCutoff — ignores old contributions", async () => { | ||
| const agg = new ByzantineRobustAggregator(1, 1, 'absolute', 1e6, 1, 0); | ||
| const id = "c1"; | ||
| agg.setNodes(Set.of(id)); | ||
|
|
||
| const p0 = agg.getPromiseForAggregation(); | ||
| agg.add(id, WeightsContainer.of([10]), 0); | ||
| const out0 = await p0; | ||
| const arr0 = await WSIntoArrays(out0); | ||
| expect(arr0[0][0]).to.equal(10); | ||
|
|
||
| // Round 2 with cutoff=1 → contributions from round 0 should be discarded | ||
| const p2 = agg.getPromiseForAggregation(); | ||
| agg.add(id, WeightsContainer.of([20]), 2); | ||
| const out2 = await p2; | ||
| const arr2 = await WSIntoArrays(out2); | ||
| expect(arr2[0][0]).to.equal(20); | ||
| }); | ||
| }); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| import { Map } from "immutable"; | ||
| import * as tf from '@tensorflow/tfjs'; | ||
| import { AggregationStep } from "./aggregator.js"; | ||
| import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; | ||
| import { WeightsContainer, client } from "../index.js"; | ||
| import { aggregation } from "../index.js"; | ||
|
|
||
| /** | ||
| * Byzantine-robust aggregator using Centered Clipping (CC), based on the | ||
| * "Learning from History for Byzantine Robust Optimization" paper: https://arxiv.org/abs/2012.10333 | ||
| * | ||
| * This class implements a gradient aggregation rule that clips updates | ||
| * in an iterative fashion to mitigate the influence of Byzantine nodes, as well as momentum calculations. | ||
| */ | ||
| export class ByzantineRobustAggregator extends MultiRoundAggregator { | ||
| private readonly clippingRadius: number; | ||
| private readonly maxIterations: number; | ||
| private readonly beta: number; | ||
| private historyMomentums: Map<client.NodeID, WeightsContainer> = Map(); | ||
| private prevAggregate: WeightsContainer | null = null; | ||
|
|
||
| /** | ||
| @property clippingRadius The clipping threshold (λ) used to limit the influence of outlier updates. | ||
| * - Type: `number` | ||
| * - Determines the maximum norm allowed for the difference between a client update and the current estimate. | ||
| * - Used in the Centered Clipping step to compute a scaling factor for updates. | ||
| * - Smaller values clip more aggressively. | ||
| * - Default value is 1.0. | ||
| * | ||
| * @property maxIterations The number of iterations (L) to run the Centered Clipping update loop. | ||
| * - Type: `number` | ||
| * - Controls how many refinement steps are used to compute the final aggregate `v`. | ||
| * - Default value is 1. | ||
| * * @property beta The momentum coefficient used to smooth the aggregation over multiple rounds. | ||
| * - Type: `number` | ||
| * - Must be between 0 and 1. | ||
| * - Used to compute the exponential moving average of past aggregates (i.e., momentum vector). | ||
| * The update typically looks like: `v_t = beta * v_{t-1} + (1 - beta) * g_t`, where `g_t` is the current clipped average. | ||
| * - A higher beta gives more weight to past rounds (more smoothing), while a lower beta makes the aggregator more responsive to new updates. | ||
| */ | ||
|
|
||
|
|
||
| constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius = 1.0, maxIterations = 1, beta = 0.9) { | ||
| super(roundCutoff, threshold, thresholdType); | ||
| if (clippingRadius <= 0) throw new Error("Clipping radius needs to be positive number > 0."); | ||
| if (maxIterations < 1) throw new Error("There must be at least one iteration for clipping."); | ||
| if (!Number.isInteger(maxIterations)) throw new Error("Number of iterations must be an integer."); | ||
| if ((beta < 0) || (beta > 1)) throw new Error("Beta must be between 0 and 1, since it is coeficient."); | ||
| this.clippingRadius = clippingRadius; | ||
| this.maxIterations = maxIterations; | ||
| this.beta = beta; | ||
| } | ||
|
|
||
| override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { | ||
| this.log( | ||
| this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, | ||
| nodeId, | ||
| ); | ||
|
|
||
| const prevMomentum = this.historyMomentums.get(nodeId); | ||
| const newMomentum = prevMomentum | ||
| ? contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) | ||
| : contribution; // no scaling on first momentum | ||
|
|
||
| this.historyMomentums = this.historyMomentums.set(nodeId, newMomentum); | ||
| this.contributions = this.contributions.setIn([0, nodeId], newMomentum); | ||
| } | ||
|
|
||
| override aggregate(): WeightsContainer { | ||
| const currentContributions = this.contributions.get(0); | ||
| if (!currentContributions) throw new Error("aggregating without any contribution"); | ||
|
|
||
| this.log(AggregationStep.AGGREGATE); | ||
|
|
||
| // If clipping radius is infinite, fall back to simple mean | ||
| if (!isFinite(this.clippingRadius)) { | ||
| return aggregation.avg(currentContributions.values()); | ||
| } | ||
|
|
||
| // Step 1: Initialize v to average of previous aggregations | ||
| let v: WeightsContainer; | ||
| if (this.prevAggregate) { | ||
| v = this.prevAggregate; | ||
| } else { | ||
| // Use shape of the first contribution to create zero vector | ||
| const first = currentContributions.values().next(); | ||
| if (first.done) throw new Error("zero sized contribution") | ||
| v = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); | ||
| } | ||
| // Step 2: Iterative Centered Clipping | ||
| for (let l = 0; l < this.maxIterations; l++) { | ||
| const clippedDiffs = Array.from(currentContributions.values()).map(m => { | ||
| const diff = m.sub(v); | ||
| const norm = tf.tidy(() => euclideanNorm(diff)); | ||
| const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); | ||
| const clipped = diff.mul(scale); | ||
| norm.dispose(); scale.dispose(); | ||
| return clipped; | ||
| }); | ||
|
|
||
| const avgClip = aggregation.avg(clippedDiffs); | ||
| const newV = v.add(avgClip); | ||
| clippedDiffs.forEach(d => d.dispose()); | ||
| v.dispose(); // Safe if v is no longer needed | ||
| v = newV; | ||
| } | ||
| // Step 3: Update momentum history | ||
| this.prevAggregate = v; | ||
| return v; | ||
| } | ||
|
|
||
|
|
||
| override makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer> { | ||
| // Communicate our local weights to every other node, be it a peer or a server | ||
| return this.nodes.toMap().map(() => weights); | ||
| } | ||
| } | ||
|
|
||
| function euclideanNorm(w: WeightsContainer): tf.Scalar { | ||
| // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. | ||
| return tf.tidy(() => { | ||
| const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); | ||
| const total = norms.reduce((a, b) => tf.add(a, b)); | ||
| return tf.sqrt(total); | ||
| }); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| export { Aggregator, AggregationStep } from './aggregator.js' | ||
| export { MeanAggregator } from './mean.js' | ||
| export { SecureAggregator } from './secure.js' | ||
| export { ByzantineRobustAggregator } from './byzantine.js' | ||
|
|
||
| export { getAggregator } from './get.js' |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.