Skip to content

Commit 7726dbc

Browse files
committed
Fix linter errors on Secure-history and Byzantine aggregator
1 parent 89fa8ed commit 7726dbc

File tree

6 files changed

+11
-29
lines changed

6 files changed

+11
-29
lines changed

discojs/src/aggregator/byzantine.spec.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import { expect } from "chai";
21
import { Set } from "immutable";
3-
import * as tf from "@tensorflow/tfjs";
2+
import { describe, expect, it } from "vitest";
43

54
import { WeightsContainer } from "../index.js";
65
import { ByzantineRobustAggregator } from "./byzantine.js";
@@ -88,7 +87,6 @@ describe("ByzantineRobustAggregator", () => {
8887
const id = "c1";
8988
agg.setNodes(Set.of(id));
9089

91-
// Round 0
9290
const p0 = agg.getPromiseForAggregation();
9391
agg.add(id, WeightsContainer.of([10]), 0);
9492
const out0 = await p0;

discojs/src/aggregator/byzantine.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import { AggregationStep } from "./aggregator.js";
44
import { MultiRoundAggregator, ThresholdType } from "./multiround.js";
55
import { WeightsContainer, client } from "../index.js";
66
import { aggregation } from "../index.js";
7-
import { Repeat } from "immutable";
87

98
/**
109
* Byzantine-robust aggregator using Centered Clipping (CC), based on the
@@ -84,8 +83,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator {
8483
v = this.prevAggregate;
8584
} else {
8685
// Use shape of the first contribution to create zero vector
87-
const sample = currentContributions.values().next().value;
88-
v = sample.map((t: any) => tf.zerosLike(t));
86+
const sample = currentContributions.values().next().value as WeightsContainer;
87+
v = sample.map((t: tf.Tensor) => tf.zerosLike(t));
8988
}
9089
// Step 2: Iterative Centered ClippingF
9190
for (let l = 0; l < this.maxIterations; l++) {
@@ -119,8 +118,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator {
119118
function euclideanNorm(w: WeightsContainer): tf.Scalar {
120119
// Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root.
121120
return tf.tidy(() => {
122-
const norms = w.weights.map(t => tf.sum(tf.square(t)) as tf.Scalar);
123-
const total = norms.reduce((a, b) => tf.add(a, b)) as tf.Scalar;
121+
const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t)));
122+
const total = norms.reduce((a, b) => tf.add(a, b));
124123
return tf.sqrt(total);
125124
});
126125
}

discojs/src/aggregator/mean.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@ import { AggregationStep } from "./aggregator.js";
33
import { MultiRoundAggregator, ThresholdType } from "./multiround.js";
44
import type { WeightsContainer, client } from "../index.js";
55
import { aggregation } from "../index.js";
6-
import createDebug from "debug"
7-
8-
const debug = createDebug("discojs:aggregator:mean");
9-
106

117
/**
128
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.

discojs/src/aggregator/multiround.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Aggregator, AggregationStep } from "./aggregator.js";
1+
import { Aggregator } from "./aggregator.js";
22
import createDebug from "debug";
33

44
export type ThresholdType = 'relative' | 'absolute';

discojs/src/aggregator/secure_history.spec.ts

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import { List, Set, Range, Map } from "immutable";
2-
import { assert, expect } from "chai";
2+
import { describe, expect, it, assert } from "vitest";
3+
34
import * as tf from "@tensorflow/tfjs";
45

56
import {
6-
aggregator as aggregators,
77
aggregation,
88
WeightsContainer,
99
} from "../index.js";
@@ -16,7 +16,6 @@ import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js";
1616
describe("Secure history aggregator", function () {
1717
const epsilon = 1e-4;
1818

19-
const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]);
2019
const secrets = List.of(
2120
WeightsContainer.of([1, 2, 3, -1], [-5, 6]),
2221
WeightsContainer.of([2, 3, 7, 1], [-10, 5]),
@@ -32,15 +31,6 @@ describe("Secure history aggregator", function () {
3231
});
3332
}
3433

35-
function buildPartialSums(
36-
allShares: List<List<WeightsContainer>>,
37-
): List<WeightsContainer> {
38-
return Range(0, secrets.size)
39-
.map((idx) => allShares.map((shares) => shares.get(idx)))
40-
.map((shares) => aggregation.sum(shares as List<WeightsContainer>))
41-
.toList();
42-
}
43-
4434
it("recovers secrets from shares", () => {
4535
const recovered = buildShares().map((shares) => aggregation.sum(shares));
4636
assert.isTrue(
@@ -60,9 +50,9 @@ describe("Secure history aggregator", function () {
6050

6151
const sharesRound0 = buildShares();
6252

63-
let partialSums = Range(0, nodes.size).map((receiverIdx) => {
53+
const partialSums = Range(0, nodes.size).map((receiverIdx) => {
6454
const receivedShares = sharesRound0.map(shares => shares.get(receiverIdx)!);
65-
return aggregation.sum(receivedShares as List<WeightsContainer>);
55+
return aggregation.sum(receivedShares);
6656
}).toList();
6757

6858
// Add one total contribution per node

discojs/src/aggregator/secure_history.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import type { WeightsContainer, client } from "../index.js";
1+
import type { WeightsContainer } from "../index.js";
22
import { SecureAggregator } from "./secure.js";
3-
import * as tf from "@tensorflow/tfjs";
43
import { aggregation } from "../index.js";
54

65
/**

0 commit comments

Comments
 (0)