Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
24f2cda
Added MSE metric for time series forecasting models.
Jasmine-Yuting-Zhang Mar 17, 2026
8232690
Added config file for TimesFM model with EV availability forecasting.
Jasmine-Yuting-Zhang Mar 17, 2026
a19e5ac
ruff format . .
Jasmine-Yuting-Zhang Mar 17, 2026
f8e1c3a
Added ev_charging datasource and the data normalization for time seri…
Jasmine-Yuting-Zhang Mar 17, 2026
536f846
Added ev availability datasource in the registry.
Jasmine-Yuting-Zhang Mar 17, 2026
907268a
Added TimesFM model and TimeSeriesUtils for time-series forecasting t…
Jasmine-Yuting-Zhang Mar 17, 2026
6fb4af7
Refactored timeseries models from HuggingFace.
Jasmine-Yuting-Zhang Mar 23, 2026
ed2e026
Updated config files for time-series models.
Jasmine-Yuting-Zhang Mar 23, 2026
2ed6e0f
Added two new config files for testing data influence, may clean up l…
Jasmine-Yuting-Zhang Mar 29, 2026
c757fd4
Updated prediction length for pretrained model.
Jasmine-Yuting-Zhang Mar 30, 2026
c2a3c13
Added TimesFM transformer model from HuggingFace.
Jasmine-Yuting-Zhang Mar 30, 2026
2c0c5ff
Merged remote-tracking branch 'origin/main' into TimeFM-time-series.
Jasmine-Yuting-Zhang Apr 9, 2026
d45f4b2
Document the DiLoCo implementation contract.
baochunli Apr 29, 2026
0378a57
Implemented DiLoCo outer aggregation.
baochunli Apr 29, 2026
6287b0f
Added exact local step limits for trainers.
baochunli Apr 29, 2026
b91d97d
Fixed local step counting for accumulation.
baochunli Apr 29, 2026
6846f50
Added DiLoCo parameter eligibility policy.
baochunli Apr 29, 2026
d122831
Handled adapter payload names in DiLoCo eligibility.
baochunli Apr 29, 2026
a9d8f3b
Avoided DiLoCo adapter alias overmatching.
baochunli Apr 29, 2026
679c9f6
Persisted in-process optimizer state for DiLoCo.
baochunli Apr 29, 2026
114f2a7
Wired DiLoCo server selection.
baochunli Apr 29, 2026
da0f341
Persisted optimizer state across train subprocesses.
baochunli Apr 29, 2026
c730732
Hardened subprocess optimizer state handoff.
baochunli Apr 29, 2026
33129d0
Added DiLoCo payload safety coverage.
baochunli Apr 29, 2026
21bf980
Added round-aware local-step sampling.
baochunli Apr 29, 2026
f6e8196
Handled non-materializable local-step samplers.
baochunli Apr 29, 2026
711fdb1
Added exact DiLoCo smoke configuration.
baochunli Apr 29, 2026
082aaf1
Added end-to-end DiLoCo validation coverage.
baochunli Apr 29, 2026
1313d30
Restored optimizer state after moving models to device.
baochunli Apr 29, 2026
f359d78
Logged DiLoCo outer optimizer application.
baochunli Apr 30, 2026
c365751
Added DiLoCo comparison configs and step-based scheduling.
baochunli Apr 30, 2026
bcc8073
Added MNIST DiLoCo comparison configs.
baochunli Apr 30, 2026
6ffb475
Aligned DiLoCo comparison budgets.
baochunli Apr 30, 2026
bd817d3
Added MSE in metrics for time series forecasting.
Jasmine-Yuting-Zhang May 10, 2026
ef6c606
Added functions for customizing the time series models from configura…
Jasmine-Yuting-Zhang May 10, 2026
8e30a87
Added functions to save the personalized models.
Jasmine-Yuting-Zhang May 10, 2026
6b791cf
Merge remote-tracking branch 'origin/diloco-faithful-implementation' …
Jasmine-Yuting-Zhang May 10, 2026
2c3f2c3
Added config file for TimesFM25 with diloco.
Jasmine-Yuting-Zhang May 10, 2026
dda867e
Updated DiLoCo steps to match the FedAvg.
Jasmine-Yuting-Zhang May 11, 2026
fb5d231
Ruff format .
Jasmine-Yuting-Zhang May 11, 2026
34cf705
Updated documents for time series models, including the models and a …
Jasmine-Yuting-Zhang May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions configs/CIFAR10/diloco_resnet18.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
# DiLoCo-specific server: aggregates client updates and applies an outer
# optimizer once per synchronization round.
type = "diloco"
address = "127.0.0.1"
port = 8021

[server.diloco]
# Outer (server-side) optimizer applied to the averaged client deltas.
# Nesterov momentum with lr 0.7 / momentum 0.9 matches the DiLoCo paper's
# recommended outer configuration.
outer_optimizer = "nesterov"
outer_learning_rate = 0.7
outer_momentum = 0.9
# "uniform": every participating client contributes equally to the average.
aggregation_weighting = "uniform"
# Apply the outer optimizer to the model parameters.
# NOTE(review): the set of accepted values depends on the DiLoCo server
# implementation — confirm there before changing.
apply_outer_optimizer_to = "parameters"

[data]

# The training and testing dataset
datasource = "Torchvision"
dataset_name = "CIFAR10"
download = true

# Number of samples in each partition
partition_size = 1000

# IID or non-IID?
sampler = "iid"

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.9

# Number of local optimizer steps per DiLoCo synchronization.
# This is H in the DiLoCo paper; with rounds = 20 the total inner-step
# budget is 20 * 500 = 10,000 steps.
local_steps_per_round = 500
# Keep the inner (AdamW) optimizer state across rounds instead of
# re-initializing it each round — presumably required for a faithful
# DiLoCo inner loop; confirm against the trainer implementation.
preserve_optimizer_state = true

# DiLoCo paper inner-optimizer settings.
# 5 epochs over a 1000-sample partition at batch size 10 gives
# 100 steps/epoch, i.e. 500 optimizer steps per round (= H above).
epochs = 5
batch_size = 10
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

# The machine learning model
model_name = "resnet_18"

[algorithm]

# Weight extraction and model update path reused by DiLoCo.
type = "fedavg"

[parameters]

[parameters.optimizer]
# Inner-optimizer (AdamW) hyperparameters.
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
# Warmup schedule for the LambdaLR scheduler; the "it" suffix presumably
# denotes iterations (1000 warmup iterations) — confirm against the
# scheduler-parsing code.
warmup_steps = "1000it"
68 changes: 68 additions & 0 deletions configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
# Plain FedAvg server (no `type = "diloco"`): this config is the FedAvg
# baseline for comparison against configs/CIFAR10/diloco_resnet18.toml.
address = "127.0.0.1"
port = 8022

[data]

# The training and testing dataset
datasource = "Torchvision"
dataset_name = "CIFAR10"
download = true

# Number of samples in each partition
partition_size = 1000

# IID or non-IID?
sampler = "iid"

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.9

# Match the original FedAvg local training shape while keeping 500 optimizer
# steps per round, equal to DiLoCo's H.
# (5 epochs * 1000 samples / batch 10 = 500 steps/round; 20 rounds gives the
# same 10,000-step total budget as the DiLoCo run.)
epochs = 5
batch_size = 10
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

# The machine learning model
model_name = "resnet_18"

[algorithm]

# Aggregation algorithm
type = "fedavg"

[parameters]

[parameters.optimizer]
# Local (AdamW) optimizer hyperparameters, kept identical to the DiLoCo
# config so the comparison isolates the aggregation strategy.
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
# Warmup for the LambdaLR scheduler; "it" presumably denotes iterations —
# confirm against the scheduler-parsing code.
warmup_steps = "1000it"
75 changes: 75 additions & 0 deletions configs/MNIST/diloco_lenet5.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
# DiLoCo-specific server: aggregates client updates and applies an outer
# optimizer once per synchronization round.
type = "diloco"
address = "127.0.0.1"
port = 8001
# Fixed seed for reproducibility of client selection/sampling.
random_seed = 1
simulate_wall_time = true

[server.diloco]
# Outer (server-side) optimizer applied to the averaged client deltas.
# Nesterov momentum with lr 0.7 / momentum 0.9 matches the DiLoCo paper's
# recommended outer configuration.
outer_optimizer = "nesterov"
outer_learning_rate = 0.7
outer_momentum = 0.9
# "uniform": every participating client contributes equally to the average.
aggregation_weighting = "uniform"
# Apply the outer optimizer to the model parameters.
# NOTE(review): the set of accepted values depends on the DiLoCo server
# implementation — confirm there before changing.
apply_outer_optimizer_to = "parameters"

[data]
# Shared MNIST IID data settings; this file only overrides partition_size.
include = "mnist_iid.toml"
partition_size = 1000

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.99

# The machine learning model
model_name = "lenet5"

# Number of local optimizer steps per DiLoCo synchronization.
# This is H in the DiLoCo paper; with rounds = 20 the total inner-step
# budget is 20 * 500 = 10,000 steps.
local_steps_per_round = 500
# Keep the inner (AdamW) optimizer state across rounds instead of
# re-initializing it each round — presumably required for a faithful
# DiLoCo inner loop; confirm against the trainer implementation.
preserve_optimizer_state = true

# DiLoCo paper inner-optimizer settings.
epochs = 5
batch_size = 32
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

[algorithm]

# Weight extraction and model update path reused by DiLoCo.
type = "fedavg"

[parameters]

[parameters.model]
num_classes = 10

[parameters.optimizer]
# Inner-optimizer (AdamW) hyperparameters.
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
# Warmup schedule for the LambdaLR scheduler; the "it" suffix presumably
# denotes iterations (1000 warmup iterations) — confirm against the
# scheduler-parsing code.
warmup_steps = "1000it"
66 changes: 66 additions & 0 deletions configs/MNIST/fedavg_lenet5_diloco_comparison.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
# Plain FedAvg server (no `type = "diloco"`): this config is the FedAvg
# baseline for comparison against configs/MNIST/diloco_lenet5.toml.
address = "127.0.0.1"
port = 8002
# Same seed as the DiLoCo config so both runs see comparable sampling.
random_seed = 1
simulate_wall_time = true

[data]
# Shared MNIST IID data settings; this file only overrides partition_size.
include = "mnist_iid.toml"
partition_size = 1000

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 63

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.99

# The machine learning model
model_name = "lenet5"

# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run.
# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per
# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching
# DiLoCo's 20 * H=500 = 10,000-step total budget.
epochs = 5
batch_size = 32
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

[algorithm]

# Aggregation algorithm
type = "fedavg"

[parameters]

[parameters.model]
num_classes = 10

[parameters.optimizer]
# Local (AdamW) optimizer hyperparameters, kept identical to the DiLoCo
# config so the comparison isolates the aggregation strategy.
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
# Warmup for the LambdaLR scheduler; "it" presumably denotes iterations —
# confirm against the scheduler-parsing code.
warmup_steps = "1000it"
93 changes: 93 additions & 0 deletions configs/TimeSeries/patchtsmixer_ev_charging.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Federated Learning with PatchTSMixer for EV Charging Prediction
#
# Task: Given the past 28 days (672 h) of a user's EV charging behaviour,
# predict whether they will be charging in each of the next 168 hours.
#
# Dataset: "EV Charging Reports" – AdO1 garage, 4 users
# https://data.mendeley.com/datasets/jbks2rcwyj/1
#
# Federated setup: 4 clients, one user each. All clients participate every round.
#
# Model: PatchTSMixer (trained from scratch)
# - uses all 6 input features jointly via mix_channel mode
# - predicts only the is_charging channel
#
# Usage:
# uv run plato.py -c configs/TimeSeries/patchtsmixer_ev_charging.toml

[clients]
type = "simple"
total_clients = 4
per_round = 4
# Clients evaluate test accuracy locally each round.
do_test = true

[server]
address = "127.0.0.1"
port = 8000
simulate_wall_time = false
# Where server checkpoints and saved models are written.
checkpoint_path = "checkpoints/timeseries/patchtsmixer_ev"
model_path = "models/timeseries/patchtsmixer_ev"

[data]
datasource = "EVCharging"

# Local CSV export of the Mendeley dataset referenced above.
datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv"

garage = "AdO1" # garage id

# Explicit user IDs to include — one client per user.
users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"]
# Each client trains on all of its own user's data (no subsampling).
sampler = "all_inclusive"
random_seed = 42

[trainer]
type = "HuggingFace"
rounds = 100
max_concurrency = 4
# "scratch" suffix: the model is randomly initialized, not loaded from a
# pretrained checkpoint.
model_name = "patchtsmixer_scratch"
model_type = "patchtsmixer"
model_task = "forecasting"

context_length = 672 # 4 × 7 × 24
prediction_length = 168 # 7 × 24

# Number of input channels: is_charging, energy_scaled,
# hour_sin, hour_cos, dow_sin, dow_cos
num_input_channels = 6

# Predict and evaluate only the is_charging channel (index 0)
prediction_channel_indices = [0]

# PatchTSMixer architecture hyperparameters.
# 672-step context with patch_length = patch_stride = 8 yields
# non-overlapping patches.
patch_length = 8
patch_stride = 8
d_model = 64
num_layers = 4
expansion_factor = 2
dropout = 0.1
head_dropout = 0.1

# Mix all channels so the model can use time features jointly.
mode = "mix_channel"
gated_attn = true
scaling = "std"

# Sliding-window stride for dataset creation
stride = 1 # advance 1 hour at a time to maximize training windows

# Local training settings per round.
epochs = 10
batch_size = 16
optimizer = "Adam"

# Chronological train/val split fractions; the remaining 15% is
# presumably held out for testing — confirm against the datasource code.
train_ratio = 0.70
val_ratio = 0.15

[algorithm]
type = "fedavg"

[parameters]
[parameters.optimizer]
lr = 0.0005
weight_decay = 1e-4

[results]
# Columns recorded in the per-round results CSV; "mse" is the time-series
# forecasting metric added for this task.
types = "round, elapsed_time, mse"
Loading
Loading