diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml new file mode 100644 index 000000000..ed407000c --- /dev/null +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -0,0 +1,79 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8021 + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.9 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 5 +batch_size = 10 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml new file mode 100644 index 000000000..26f32d0ce --- /dev/null +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -0,0 +1,68 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8022 + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.9 + +# Match the original FedAvg local training shape while keeping 500 optimizer +# steps per round, equal to DiLoCo's H. +epochs = 5 +batch_size = 10 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/diloco_lenet5.toml b/configs/MNIST/diloco_lenet5.toml new file mode 100644 index 000000000..53eff9305 --- /dev/null +++ b/configs/MNIST/diloco_lenet5.toml @@ -0,0 +1,75 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? 
+do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8001 +random_seed = 1 +simulate_wall_time = true + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.99 + +# The machine learning model +model_name = "lenet5" + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 5 +batch_size = 32 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml new file mode 100644 index 000000000..e223915bb --- /dev/null +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -0,0 +1,66 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8002 +random_seed = 1 +simulate_wall_time = true + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 63 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.99 + +# The machine learning model +model_name = "lenet5" + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per +# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching +# DiLoCo's 20 * H=500 = 10,000-step total budget. +epochs = 5 +batch_size = 32 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/TimeSeries/patchtsmixer_ev_charging.toml b/configs/TimeSeries/patchtsmixer_ev_charging.toml new file mode 100644 index 000000000..fc3e5be7b --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_ev_charging.toml @@ -0,0 +1,93 @@ +# Federated Learning with PatchTSMixer for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. 
+# +# Model: PatchTSMixer (trained from scratch) +# - uses all 6 input features jointly via mix_channel mode +# - predicts only the is_charging channel +# +# Usage: +# uv run plato.py -c configs/TimeSeries/patchtsmixer_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer_ev" +model_path = "models/timeseries/patchtsmixer_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "patchtsmixer_scratch" +model_type = "patchtsmixer" +model_task = "forecasting" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Predict and evaluate only the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +patch_stride = 8 +d_model = 64 +num_layers = 4 +expansion_factor = 2 +dropout = 0.1 +head_dropout = 0.1 + +# Mix all channels so the model can use time features jointly. +mode = "mix_channel" +gated_attn = true +scaling = "std" + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximize training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging.toml b/configs/TimeSeries/timesfm25_ev_charging.toml new file mode 100644 index 000000000..20ae2b57e --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging.toml @@ -0,0 +1,89 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm_ev" +model_path = "models/timeseries/timesfm_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. 
+users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximizes training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml new file mode 100644 index 000000000..56895618e --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml @@ -0,0 +1,89 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 128 hours. +# +# Dataset: "EV Charging Reports" – single-user run for Bl2-5 +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 1 client, using only user Bl2-5. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_bl2_5_only" +model_path = "models/timeseries/timesfm25_ev_bl2_5_only" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "Bl2" + +# Explicit user IDs to include — one client per user. 
+users = ["Bl2-5"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 +prediction_length = 128 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml new file mode 100644 index 000000000..33b72ac62 --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml @@ -0,0 +1,90 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – mixed high-data users across garages +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_top4_mixed" +model_path = "models/timeseries/timesfm25_ev_top4_mixed" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +# Use explicit users across the whole dataset, not just a single garage. +garage = "all" + +# Explicit user IDs to include — one client per user. 
+users = ["Bl2-5", "AsO2-1", "Bl2-1", "AdO1-3"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 +prediction_length = 168 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml new file mode 100644 index 000000000..36e450d78 --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml @@ -0,0 +1,107 @@ +# Federated Learning with TimesFM2.5 + DiLoCo for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 128 hours. +# +# Dataset: "EV Charging Reports" – mixed high-data users across garages +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# DiLoCo uses server.type = "diloco" with algorithm.type = "fedavg" so the +# standard FedAvg weight extraction/update path is reused for local training. +# local_steps_per_round counts optimizer steps, not epochs; optimizer state is +# preserved locally under model_path between DiLoCo synchronizations. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_top4_mixed_diloco" +model_path = "models/timeseries/timesfm25_ev_top4_mixed_diloco" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +# Use explicit users across the whole dataset, not just a single garage. +garage = "all" + +# Explicit user IDs to include — one client per user. +users = ["Bl2-5", "AsO2-1", "Bl2-1", "AdO1-3"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 10 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" + +context_length = 672 +prediction_length = 128 # TimesFM2.5 transformers horizon_length is fixed at 128 steps. 
+ +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 1500 +preserve_optimizer_state = true + +epochs = 10 +batch_size = 16 +optimizer = "AdamW" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm_ev_charging.toml b/configs/TimeSeries/timesfm_ev_charging.toml new file mode 100644 index 000000000..7ced5a721 --- /dev/null +++ b/configs/TimeSeries/timesfm_ev_charging.toml @@ -0,0 +1,89 @@ +# Federated Learning with TimesFM for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm_ev" +model_path = "models/timeseries/timesfm_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. 
+users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "google/timesfm-2.0-500m-pytorch" +model_type = "timesfm" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximizes training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm_transformers_bl1.toml b/configs/TimeSeries/timesfm_transformers_bl1.toml new file mode 100644 index 000000000..85d360dab --- /dev/null +++ b/configs/TimeSeries/timesfm_transformers_bl1.toml @@ -0,0 +1,77 @@ +# Federated Learning with TimesFM 2.5 (HuggingFace transformers) for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Model: google/timesfm-2.5-200m-transformers +# Uses Timesfm2P5ModelForPrediction from the transformers library. +# Channel-independent: each of the 6 input features is processed as a +# separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_transformers_ev_charging.toml + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25t_ev" +model_path = "models/timeseries/timesfm25t_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "Bl2" + +# Explicit user IDs to include — one client per user. 
+users = ["Bl2-1"] +sampler = "all_inclusive" + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" + +context_length = 672 # 4 × 7 × 24 (model supports up to 16384) +prediction_length = 128 # model horizon_length is fixed at 128 steps + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximise training windows + +epochs = 5 +batch_size = 32 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/docs/docs/configurations/data.md b/docs/docs/configurations/data.md index 2173be10e..ff2b1c481 100644 --- a/docs/docs/configurations/data.md +++ b/docs/docs/configurations/data.md @@ -5,6 +5,7 @@ - `Torchvision`: including torchvision datasets such as MNIST, FashionMNIST, EMNIST, CIFAR10, CIFAR100, CelebA, or STL10 (requires `dataset_name`) - `CINIC10` - `FEMNIST`: Federated EMNIST + - `EVCharging`: per-user EV charging time-series forecasting windows - `TinyImageNet` - `Purchase` - `Texas` @@ -57,6 +58,20 @@ !!! example "test_path" Where the test dataset is located. +!!! tip "EVCharging time-series datasource" + `EVCharging` builds per-user hourly time series from the [_Residential electric vehicle charging datasets from apartment buildings_](https://data.mendeley.com/datasets/jbks2rcwyj/1/files/2e3b8ced-9887-4a91-b721-8e510e18a127) [doi: 10.17632/jbks2rcwyj.1]. Each client receives one configured user, so use `sampler = "all_inclusive"` rather than class-label partitioning. + + ```toml + [data] + datasource = "EVCharging" + datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + garage = "AdO1" # use "all" for users across garages + users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] + sampler = "all_inclusive" + ``` + + The datasource creates `past_values` / `future_values` sliding-window samples. The input feature order is `is_charging`, `energy_scaled`, `hour_sin`, `hour_cos`, `dow_sin`, and `dow_cos`; the reference configs forecast only `is_charging`. + !!! example "sampler" How to divide the entire dataset to the clients. The following options are available: diff --git a/docs/docs/configurations/results.md b/docs/docs/configurations/results.md index 56537231c..dea8d2d7b 100644 --- a/docs/docs/configurations/results.md +++ b/docs/docs/configurations/results.md @@ -6,6 +6,8 @@ - `round` - `accuracy` - `accuracy_std` + - `mse` + - `mse_std` - `elapsed_time` - `comm_time` - `processing_time` @@ -19,6 +21,7 @@ !!! note "Note" Use commas to separate them. The default is `round, accuracy, elapsed_time`. + Time-series configs commonly use `round, elapsed_time, mse` instead. !!! note "Structured evaluators" When `[evaluation]` is configured, Plato automatically appends any new `evaluation_*` columns that appear at runtime. You do **not** need to predeclare every Lighteval task metric in `results.types`, although predeclaring the summary columns can keep the CSV order stable. 
diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 2bb800237..cef578eb9 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -8,6 +8,7 @@ - `fedavg_personalized` a Federated Averaging server that supports all-purpose personalized federated learning by controlling when and which group of clients are to perform local personalization. - `fedavg_mpc_additive` a Federated Averaging server that reconstructs additive MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_additive` processor. - `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor. + - `diloco` a FedAvg-compatible server that applies DiLoCo outer aggregation. Use it with `algorithm.type = "fedavg"` and configure the outer optimizer under `[server.diloco]`. - `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use. - `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset. - `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights. @@ -124,6 +125,37 @@ Default value: `100` +!!! example "diloco" + Settings for `server.type = "diloco"`. DiLoCo reuses `algorithm.type = "fedavg"` for client weight extraction and global model loading, while the DiLoCo server turns client deltas into an outer-gradient update. + + ```toml + [server] + type = "diloco" + + [algorithm] + type = "fedavg" + + [server.diloco] + outer_optimizer = "nesterov" + outer_learning_rate = 0.7 + outer_momentum = 0.9 + aggregation_weighting = "uniform" + apply_outer_optimizer_to = "parameters" + ``` + + `aggregation_weighting = "uniform"` matches balanced IID worker smoke runs. `aggregation_weighting = "num_samples"` matches Plato's traditional sample-weighted FedAvg behavior. With outer SGD and `outer_learning_rate = 1.0`, uniform weighting is equivalent to uniform model averaging; with `num_samples`, it is equivalent to Plato-style sample-weighted FedAvg. + + `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. 
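    The outer update these settings control can be sketched in a few lines. The helper below is a hypothetical standalone illustration (not Plato's server code), assuming uniform weighting, state dicts of floating tensors, and the Nesterov formulation used by PyTorch's SGD:

    ```python
    import torch

    def diloco_outer_step(global_before, client_states, momentum_buffers,
                          lr=0.7, momentum=0.9):
        """One DiLoCo outer step: average clients, form the outer gradient, apply Nesterov."""
        new_global = {}
        for name, before in global_before.items():
            # Uniform average of the clients' post-training weights.
            avg_after = torch.stack([state[name] for state in client_states]).mean(dim=0)
            # DiLoCo's outer gradient is the negated FedAvg-style delta.
            outer_grad = before - avg_after
            # Nesterov momentum on the outer gradient.
            buf = momentum * momentum_buffers.get(name, torch.zeros_like(outer_grad)) + outer_grad
            momentum_buffers[name] = buf
            new_global[name] = before - lr * (outer_grad + momentum * buf)
        return new_global
    ```

    With `outer_optimizer = "sgd"`, `outer_momentum = 0.0`, and `outer_learning_rate = 1.0`, the update above reduces to plain uniform model averaging, which is the FedAvg-equivalence case described here.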
+ + Runnable comparison configurations are available for MNIST/LeNet and CIFAR-10/ResNet-18: + + ```bash + uv run python plato.py --config configs/MNIST/diloco_lenet5.toml + uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml + ``` + + These configurations validate DiLoCo mechanics in Plato; they are not C4/model/pretraining reproductions of the DiLoCo paper. + !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/configurations/trainer.md b/docs/docs/configurations/trainer.md index de9eb715e..9c8852971 100644 --- a/docs/docs/configurations/trainer.md +++ b/docs/docs/configurations/trainer.md @@ -5,7 +5,7 @@ - `composable` the strategy-based trainer that exposes loss, optimiser, scheduler, data-loader, model-update, and testing strategies directly. - `timm_basic` a basic trainer with the [timm](https://timm.fast.ai/) learning rate scheduler. - `diff_privacy` a trainer that supports local differential privacy in its training loop by adding noise to the gradients during each step of training. - - `HuggingFace` a trainer for Hugging Face causal language models and tokenizers. + - `HuggingFace` a trainer for Hugging Face causal language models, tokenizers, and time-series models. - `nanochat` a trainer for Nanochat language-model workloads. - `lerobot` a trainer for LeRobot / SmolVLA workloads. - `split_learning` a trainer that supports the split learning framework. @@ -56,6 +56,20 @@ !!! example "epochs" The total number of epochs in local training in each communication round. +!!! example "local_steps_per_round" + The DiLoCo local work value `H`, counted as completed client-local optimizer steps between synchronizations. + + `H` is not an epoch count, raw dataloader batch count, or gradient-accumulation micro-batch count. When gradient accumulation is enabled, only batches that trigger `optimizer.step()` increment `H`. + + `H` may be smaller than one epoch. In that case, local training stops mid-epoch after exactly `H` optimizer steps while still running normal trainer cleanup, callback completion, state persistence, and reporting. + + Small-`H` DiLoCo runs use round-aware sampling where supported so a logical client does not replay the same first `H` batches every round. Trainers or samplers that cannot count optimizer steps or advance the local stream faithfully must fail or warn clearly instead of silently approximating DiLoCo. + +!!! example "preserve_optimizer_state" + Whether client-local optimizer and scheduler state should persist across a logical client's local train runs. + + DiLoCo should set this to `true` with a stateful inner optimizer such as `AdamW`. Optimizer and scheduler state remains client-local and is not transmitted in client-server payloads. + !!! example "batch_size" The size of the mini-batch of data in each step (iteration) of the training loop. 
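How `epochs`, `batch_size`, gradient accumulation, and `local_steps_per_round` interact can be sketched with a plain training loop. This is hypothetical illustration code, not Plato's trainer: only completed `optimizer.step()` calls count toward `H`, and the loop returns mid-epoch once `H` steps are done.

```python
def train_locally(model, optimizer, make_loader, loss_fn,
                  epochs, local_steps_per_round, accumulation_steps=1):
    """Run at most `local_steps_per_round` optimizer steps, possibly stopping mid-epoch."""
    completed_steps = 0
    for _ in range(epochs):
        for i, (inputs, targets) in enumerate(make_loader()):
            loss = loss_fn(model(inputs), targets) / accumulation_steps
            loss.backward()
            # Micro-batches that do not trigger optimizer.step() never increment H.
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                completed_steps += 1
                if completed_steps >= local_steps_per_round:
                    return completed_steps  # stop after exactly H optimizer steps
    return completed_steps
```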
@@ -165,6 +179,8 @@ - `cnn_encoder` (for generating various encoders by extracting from CNN models such as ResNet models) - `general_multilayer` (for generating a multi-layer perceptron using a provided configuration) - `huggingface` (for [HuggingFace](https://huggingface.co/models) causal language models) + - `timesfm` (for Hugging Face TimesFM time-series forecasting models) + - `patchtsmixer` (for Hugging Face PatchTSMixer time-series models) - `torch_hub` (for models from [PyTorch Hub](https://pytorch.org/hub/)) - `vit` (for Vision Transformer models from [HuggingFace](https://huggingface.co/models), [Tokens-to-Token ViT](https://github.com/yitu-opensource/T2T-ViT), and [Deep Vision Transformer](https://github.com/zhoudaquan/dvit_repo)) - `smolvla` (for LeRobot / SmolVLA robotics policies) @@ -190,6 +206,8 @@ - `multilayer` - `nanochat` - `smolvla` + - `timesfm` + - `patchtsmixer` !!! note "Note" If the `model_type` above specified a model repository, supply the name of the model, such as `gpt2`, `HuggingFaceTB/SmolLM2-135M`, or `smolvla`, here. @@ -200,3 +218,18 @@ An optional tokenizer identifier to use instead of `trainer.model_name`. This is mainly useful for Hugging Face language-model workloads where the tokenizer/chat template comes from a separate repository. + +!!! tip "HuggingFace time-series models" + Set `trainer.type = "HuggingFace"` with `model_type = "timesfm"` or `model_type = "patchtsmixer"` to use Plato's time-series collator and MSE testing strategy instead of the tokenizer-based language-model path. + + Common fields include: + + - `context_length`: number of historical time steps in each input window. + - `prediction_length`: number of future time steps to forecast. + - `num_input_channels`: number of features in `past_values`. + - `prediction_channel_indices`: channels to keep/evaluate from model output. + - `stride`: sliding-window stride used by the datasource. + - `train_ratio` and `val_ratio`: temporal split ratios for per-user windows. + - `freq`: TimesFM frequency token (`0` for high-frequency/hourly data). + + TimesFM models are wrapped channel-independently for multivariate tensors; PatchTSMixer configs can use model options such as `mode = "mix_channel"` to mix features jointly. See the [TimesFM case study](../examples/case-studies/6. Time-Series Forecasting with TimesFM.md) for complete configs. diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md new file mode 100644 index 000000000..f4b8ce405 --- /dev/null +++ b/docs/docs/development/diloco.md @@ -0,0 +1,237 @@ +# DiLoCo Design Contract + +This note defines what Plato calls faithful DiLoCo in the current +implementation. + +Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo +training loop inside Plato's federated runtime. It does not mean reproducing +the paper's exact C4 dataset, model scale, tokenizer, hardware topology, +pretraining duration, or final benchmark numbers. 
+ +## Example Configurations + +Plato includes MNIST/LeNet and CIFAR-10/ResNet-18 comparison configurations +for checking DiLoCo against matched FedAvg runs: + +```bash +uv run python plato.py --config configs/MNIST/diloco_lenet5.toml +uv run python plato.py --config configs/MNIST/fedavg_lenet5_diloco_comparison.toml +uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml +uv run python plato.py --config configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml +``` + +These examples validate Plato's DiLoCo mechanics without reproducing the C4 +dataset, tokenizer, language-model scale, hardware topology, pretraining +duration, or final benchmark numbers from the paper. + +## Algorithm Contract + +DiLoCo has two optimizer levels: + +- The client-local inner optimizer trains each selected logical client for + exactly `H` local optimizer steps between synchronizations. +- The server-side outer optimizer updates the global model from the averaged + outer gradient. + +Plato's FedAvg-style model delta is: + +```text +plato_delta = client_after - global_before +``` + +DiLoCo's outer gradient is: + +```text +outer_gradient = global_before - client_after = -plato_delta +``` + +The DiLoCo server must still return a Plato-compatible model delta because +`algorithm.update_weights()` adds the returned delta to the current global +model. For example, outer SGD with learning rate `1.0` returns the averaged +Plato delta and is equivalent to FedAvg only when the same averaging rule is +used. + +The outer optimizer runs on the server. Clients run only the inner optimizer +and send model weights or weight-equivalent updates. Client-local optimizer and +scheduler state persists per logical client and is never sent to the server. + +## Local Work `H` + +`H` means client-local optimizer steps between synchronizations. It is not: + +- epochs, +- raw dataloader batches, or +- gradient-accumulation micro-batches. + +When gradient accumulation is enabled, `H` counts completed optimizer steps. +Raw batches that do not trigger `optimizer.step()` do not increment `H`. + +`H` may be smaller than one epoch. Faithful DiLoCo must therefore stop local +training mid-epoch after exactly `H` optimizer steps. This early stop must +still run normal trainer cleanup, state persistence, callback completion, and +reporting paths. It must not perform an extra final optimizer step. + +Small-`H` training must not repeatedly replay the same first `H` batches only +because the train loader is recreated each round. The implementation must use +round-aware resampling or an equivalent persistent sampling stream so each +logical client's local data stream advances across rounds in a reproducible +way. + +## State Ownership + +Server-owned state: + +- the global model, +- outer optimizer momentum or other outer optimizer state, +- aggregation metadata needed to update the global model. + +Client-owned state: + +- inner optimizer state, such as AdamW first and second moments, +- scheduler state and global/local optimizer-step counters, +- sampler or dataloader stream position needed for small-`H` continuity. + +Client-owned optimizer and scheduler state must not appear in client-server +payloads. It must remain local to the logical client, including when training +uses subprocesses. + +## Parameter And Buffer Policy + +By default, the outer optimizer applies only to trainable floating parameters. +This matches the algorithm definition, which optimizes model parameters. 
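As a rough illustration of this eligibility split (trainable floating parameters versus floating and non-floating buffers), a hypothetical helper could partition a model's state as follows; the actual policy lives in the DiLoCo server implementation:

```python
import torch

def partition_state(model: torch.nn.Module):
    """Split model state into outer-optimized parameters and synchronized-only buffers."""
    outer_optimized, averaged_buffers, conservative_buffers = [], [], []
    trainable = {name for name, param in model.named_parameters() if param.requires_grad}
    for name, tensor in model.state_dict().items():
        if name in trainable and torch.is_floating_point(tensor):
            outer_optimized.append(name)       # receives outer momentum / Nesterov
        elif torch.is_floating_point(tensor):
            averaged_buffers.append(name)      # e.g. BatchNorm running stats: averaged, no momentum
        else:
            conservative_buffers.append(name)  # integer/bool buffers: conservative FedAvg-style handling
    return outer_optimized, averaged_buffers, conservative_buffers
```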
+ +Floating buffers, such as batch normalization running statistics, are +synchronized without outer momentum by default. They use the selected averaging +rule but do not receive server-side momentum or Nesterov treatment. + +Non-floating buffers use conservative FedAvg-style behavior, including casting +or rounding as needed to preserve the buffer's dtype-compatible semantics. + +The implementation may offer `apply_outer_optimizer_to = "all_floating"` for +experiments, but the default must remain `parameters`. + +## Configuration Contract + +The faithful initial mode uses these configuration names and defaults: + +```toml +[server] +type = "diloco" + +[algorithm] +type = "fedavg" + +[trainer] +local_steps_per_round = H +preserve_optimizer_state = true +optimizer = "AdamW" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" # or "num_samples" +apply_outer_optimizer_to = "parameters" # or "all_floating" +``` + +`algorithm.type = "fedavg"` is intentional. Plato should reuse the existing +FedAvg weight extraction, delta computation, and global model loading path, +while `server.type = "diloco"` selects the server-side DiLoCo aggregation and +outer optimizer behavior. + +`aggregation_weighting = "uniform"` matches the balanced worker setting most +closely. `aggregation_weighting = "num_samples"` matches Plato's traditional +sample-weighted FedAvg behavior. FedAvg equivalence for outer SGD with learning +rate `1.0` is valid only when both runs use the same weighting rule. + +Unsupported modes must fail clearly. They must not silently fall back to an +approximate DiLoCo variant. Examples include trainer backends that cannot count +local optimizer steps exactly, execution paths that cannot preserve +client-local optimizer and scheduler state, samplers that cannot advance the +small-`H` local data stream across rounds, or payload paths that would send +optimizer state to the server. Experimental combinations that are allowed but +not faithful must warn clearly. + +## Implementation Sequence + +Dependency graph: + +```text +D1 +|-- D2 --> D3 +|-- D4 --> D5 +|-- D6 --> D7 +|-- D8 --> D9 +`-- D10 --> D11 + +D3, D5, D7, D9, D11 --> D12 --> D13 +``` + +Tasks: + +```yaml +- id: D1 + depends_on: [] + task: Document the exact DiLoCo contract and unsupported modes. + +- id: D2 + depends_on: [D1] + task: Add red tests for server-side outer gradient sign, weighting, and + FedAvg equivalence under matching weighting. + +- id: D3 + depends_on: [D2] + task: Implement DiLoCo server aggregation and outer optimizer state for SGD, + momentum SGD, and Nesterov. + +- id: D4 + depends_on: [D1] + task: Add red tests for exact local optimizer-step counting and `H` smaller + than one epoch. + +- id: D5 + depends_on: [D4] + task: Implement `trainer.local_steps_per_round` with mid-epoch termination + after exactly `H` optimizer steps. + +- id: D6 + depends_on: [D1] + task: Add red tests for per-client optimizer and scheduler state + persistence. + +- id: D7 + depends_on: [D6] + task: Persist client-local optimizer and scheduler state without sending it + to the server. + +- id: D8 + depends_on: [D1] + task: Add red tests for round-aware small-`H` sampling. + +- id: D9 + depends_on: [D8] + task: Implement round-aware resampling or an equivalent persistent sampling + stream for each logical client. + +- id: D10 + depends_on: [D1] + task: Add red tests for parameter and buffer eligibility. 
+ +- id: D11 + depends_on: [D10] + task: Implement the default trainable-parameter-only outer optimizer policy + and conservative buffer synchronization. + +- id: D12 + depends_on: [D3, D5, D7, D9, D11] + task: Wire exact DiLoCo configuration, examples, and user-facing + documentation. + +- id: D13 + depends_on: [D12] + task: Add end-to-end faithful-mode validation coverage. +``` + +Every implementation task should use red/green test-driven development. Add +the failing tests that describe the contract first, then implement the smallest +runtime change that makes those tests pass. diff --git a/docs/docs/examples/Getting Started.md b/docs/docs/examples/Getting Started.md index c8e5a157a..7109746e2 100644 --- a/docs/docs/examples/Getting Started.md +++ b/docs/docs/examples/Getting Started.md @@ -54,3 +54,7 @@ Plato supports both Linux with NVIDIA GPUs and macOS with M1/M2/M4/M4 GPUs. It w - [Server-side Lighteval for SmolLM2](case-studies/4. Server-side Lighteval for SmolLM2.md) - [SmolVLA Trainer with LeRobot](case-studies/3. SmolVLA Trainer with LeRobot.md) + +- [Nanochat in Plato](case-studies/5. Nanochat in Plato.md) + +- [Time-Series Forecasting with TimesFM](case-studies/6. Time-Series Forecasting with TimesFM.md) diff --git a/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md b/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md new file mode 100644 index 000000000..05d62a0d8 --- /dev/null +++ b/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md @@ -0,0 +1,210 @@ +# Time-Series Forecasting with TimesFM + +Plato includes a reference workflow for **federated time-series forecasting** with Hugging Face time-series models. The initial case study predicts EV charging availability: each client owns one user's charging history and trains on sliding windows from that user's hourly sequence. + +Reference files: + +- `configs/TimeSeries/timesfm25_ev_charging.toml` +- `configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml` +- `configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml` +- `configs/TimeSeries/patchtsmixer_ev_charging.toml` +- `plato/datasources/ev_charging.py` + +## Dataset preparation + +The configs use the [_Residential electric vehicle charging datasets from apartment buildings_](https://data.mendeley.com/datasets/jbks2rcwyj/1/files/2e3b8ced-9887-4a91-b721-8e510e18a127) [doi: 10.17632/jbks2rcwyj.1]. + +Download `dataset1_ev_charging_reports.csv` and place it at the path used by the configs, for example: + +```text +runtime/data/ado1/dataset1_ev_charging_reports.csv +``` + +The dataset is not bundled with Plato. The datasource expects the raw semicolon-separated CSV and performs the preprocessing at runtime. 
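To sanity-check the download before launching a run, the raw CSV can be loaded directly. This snippet is only a sketch: the separator, decimal convention, datetime format, and column names are taken from the `EVCharging` datasource docstring, and the path matches the reference configs.

```python
import pandas as pd

df = pd.read_csv(
    "runtime/data/ado1/dataset1_ev_charging_reports.csv",
    sep=";",       # the raw file is semicolon-separated
    decimal=",",   # El_kWh uses a comma as the decimal separator
)
df["Start_plugin"] = pd.to_datetime(df["Start_plugin"], format="%d.%m.%Y %H:%M")

# One row per charging session; confirm the garage and users referenced by the configs exist.
print(df["Garage_ID"].value_counts().head())
print(df.loc[df["Garage_ID"] == "AdO1", "User_ID"].unique())
```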
+ +## EVCharging datasource behavior + +Use the datasource with: + +```toml +[data] +datasource = "EVCharging" +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" +garage = "AdO1" +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +``` + +The datasource: + +- filters to one garage, or uses the whole CSV when `garage = "all"`; +- preserves the configured `users` order; +- maps client IDs to users in that order (`client_id = 1` selects the first user); +- builds a continuous hourly grid for each user's active date range; +- marks `is_charging = 1` when a charging session overlaps an hour; +- accumulates `energy_kwh` over active charging hours; +- scales energy with the training-window maximum; +- adds cyclic time features for hour-of-day and day-of-week; +- splits valid sliding-window starts into train / validation / test windows. + +The model input feature order is: + +```text +is_charging, energy_scaled, hour_sin, hour_cos, dow_sin, dow_cos +``` + +The reference configs forecast only the first channel, `is_charging`. + +## Model choices + +### TimesFM + +Select TimesFM through the HuggingFace trainer: + +```toml +[trainer] +type = "HuggingFace" +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" +context_length = 672 +prediction_length = 128 +num_input_channels = 6 +prediction_channel_indices = [0] +freq = 0 +``` + +Supported reference variants include: + +- `google/timesfm-2.0-500m-pytorch` +- `google/timesfm-2.5-200m-pytorch` +- `google/timesfm-2.5-200m-transformers` + +The TimesFM reference configs use `prediction_length = 128` because the selected TimesFM checkpoints expose a fixed 128-step native horizon. + +### PatchTSMixer + +PatchTSMixer is useful as a smaller scratch baseline: + +```toml +[trainer] +type = "HuggingFace" +model_name = "patchtsmixer_scratch" +model_type = "patchtsmixer" +model_task = "forecasting" +context_length = 672 +prediction_length = 168 +num_input_channels = 6 +prediction_channel_indices = [0] +mode = "mix_channel" +``` + +Unlike the TimesFM wrapper's channel-independent path, the reference PatchTSMixer config uses `mode = "mix_channel"` so the model can use the time features jointly. + +## Run the reference configs + +Install the normal Plato environment first: + +```bash +uv sync +``` + +Then run one of the configs from the repository root. + +TimesFM 2.5 on the four AdO1 users: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging.toml +``` + +TimesFM 2.5 on selected high-data users across garages: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml +``` + +PatchTSMixer scratch baseline: + +```bash +uv run plato.py --config configs/TimeSeries/patchtsmixer_ev_charging.toml +``` + +Single-client TimesFM 2.5 transformers smoke run: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm_transformers_bl1.toml +``` + +## DiLoCo variant + +The branch also includes a TimesFM 2.5 + DiLoCo config: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml +``` + +The config uses: + +```toml +[server] +type = "diloco" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[trainer] +local_steps_per_round = 1500 +preserve_optimizer_state = true +``` + +`local_steps_per_round` is counted in completed optimizer steps, not epochs. 
See the [DiLoCo design contract](../../development/diloco.md) for the mechanics behind this server type. + +## Result logging + +The time-series configs use MSE as the scalar test metric: + +```toml +[results] +types = "round, elapsed_time, mse" +``` + +A lower MSE is better. + +## Troubleshooting + +### Dataset file not found + +Make sure `data.datasource_path` points to the downloaded `dataset1_ev_charging_reports.csv`. Relative paths are resolved from the Plato repository root when using the reference commands above. + +### User not found + +If a configured user is missing, check the `garage` setting. Users from multiple garages require: + +```toml +garage = "all" +``` + +### TimesFM class not available + +TimesFM 2.5 requires a recent `transformers` version that exposes `TimesFm2_5ModelForPrediction`. If model import fails, update the environment and verify the class can be imported before launching a long run. + +### Metric looks like accuracy in old scripts + +For time-series runs, use configs that include: + +```toml +[results] +types = "round, elapsed_time, mse" +``` + +The server and client logs label the primary metric as MSE when the active trainer testing strategy reports `metric_name = "mse"`. + +## Related documentation + +- [Data configuration](../../configurations/data.md) +- [Trainer configuration](../../configurations/trainer.md) +- [Results logging](../../configurations/results.md) +- [DiLoCo design contract](../../development/diloco.md) diff --git a/docs/docs/index.md b/docs/docs/index.md index 1a105dbcf..4c44fd94f 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -39,6 +39,7 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a - **[SmolVLA Trainer with LeRobot](examples/case-studies/3. SmolVLA Trainer with LeRobot.md)** - **[Server-side Lighteval for SmolLM2](examples/case-studies/4. Server-side Lighteval for SmolLM2.md)** - **[Nanochat in Plato](examples/case-studies/5. Nanochat in Plato.md)** + - **[Time-Series Forecasting with TimesFM](examples/case-studies/6. Time-Series Forecasting with TimesFM.md)** ## Configuration Settings diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index fa5bc9908..4da2e01b2 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -72,6 +72,7 @@ nav: - SmolVLA Trainer with LeRobot: examples/case-studies/3. SmolVLA Trainer with LeRobot.md - Server-side Lighteval for SmolLM2: examples/case-studies/4. Server-side Lighteval for SmolLM2.md - Nanochat in Plato: examples/case-studies/5. Nanochat in Plato.md + - Time-Series Forecasting with TimesFM: examples/case-studies/6. 
Time-Series Forecasting with TimesFM.md - Configuration Settings: - Overview: configurations/overview.md - General: configurations/general.md @@ -88,7 +89,9 @@ nav: - Servers: references/servers.md - Trainers: references/trainers.md - Evaluators: references/evaluators.md - - Developer's Guide: development.md + - Developer's Guide: + - Overview: development.md + - DiLoCo Design Contract: development/diloco.md - Deployment Guide: deployment.md - Digital Research Alliance of Canada: ccdb.md - Miscellaneous Notes: misc.md diff --git a/examples/model_search/fedrlnas/Darts/architect.py b/examples/model_search/fedrlnas/Darts/architect.py index 22660d3a8..721bb5f53 100644 --- a/examples/model_search/fedrlnas/Darts/architect.py +++ b/examples/model_search/fedrlnas/Darts/architect.py @@ -140,10 +140,12 @@ def _parse(weights): weight_matrix = weights[start:end].copy() edges = sorted( range(i + 2), - key=lambda x, wm=weight_matrix: -max( - wm[x][k] - for k in range(len(wm[x])) - if k != PRIMITIVES.index("none") + key=lambda x, wm=weight_matrix: ( + -max( + wm[x][k] + for k in range(len(wm[x])) + if k != PRIMITIVES.index("none") + ) ), )[:2] for j in edges: diff --git a/examples/model_search/fedrlnas/Darts/operations.py b/examples/model_search/fedrlnas/Darts/operations.py index 48bba9a03..92f4ff99f 100644 --- a/examples/model_search/fedrlnas/Darts/operations.py +++ b/examples/model_search/fedrlnas/Darts/operations.py @@ -11,9 +11,9 @@ 3, stride=stride, padding=1, count_include_pad=False ), "max_pool_3x3": lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), - "skip_connect": lambda C, stride, affine: Identity() - if stride == 1 - else FactorizedReduce(C, C, affine=affine), + "skip_connect": lambda C, stride, affine: ( + Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine) + ), "sep_conv_3x3": lambda C, stride, affine: SepConv( C, C, 3, stride, 1, affine=affine ), diff --git a/examples/model_search/pfedrlnas/DARTS/Darts/architect.py b/examples/model_search/pfedrlnas/DARTS/Darts/architect.py index f9e5d5e6e..108cae151 100644 --- a/examples/model_search/pfedrlnas/DARTS/Darts/architect.py +++ b/examples/model_search/pfedrlnas/DARTS/Darts/architect.py @@ -193,10 +193,12 @@ def _parse(weights): weight_matrix = weights[start:end].copy() edges = sorted( range(i + 2), - key=lambda x, wm=weight_matrix: -max( - wm[x][k] - for k in range(len(wm[x])) - if k != PRIMITIVES.index("none") + key=lambda x, wm=weight_matrix: ( + -max( + wm[x][k] + for k in range(len(wm[x])) + if k != PRIMITIVES.index("none") + ) ), )[:2] for j in edges: diff --git a/examples/model_search/pfedrlnas/DARTS/Darts/operations.py b/examples/model_search/pfedrlnas/DARTS/Darts/operations.py index 48bba9a03..92f4ff99f 100644 --- a/examples/model_search/pfedrlnas/DARTS/Darts/operations.py +++ b/examples/model_search/pfedrlnas/DARTS/Darts/operations.py @@ -11,9 +11,9 @@ 3, stride=stride, padding=1, count_include_pad=False ), "max_pool_3x3": lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), - "skip_connect": lambda C, stride, affine: Identity() - if stride == 1 - else FactorizedReduce(C, C, affine=affine), + "skip_connect": lambda C, stride, affine: ( + Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine) + ), "sep_conv_3x3": lambda C, stride, affine: SepConv( C, C, 3, stride, 1, affine=affine ), diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py index 
00b2281e7..e8b6a44f7 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py @@ -77,8 +77,8 @@ def validate( top5_list.append(acc5) head_dim = 8 - func = ( - lambda x: x[0] ** 2 + func = lambda x: ( + x[0] ** 2 * ( x[1] ** 2 * 6 + x[1] ** 2 * 8 diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py index b31694e0b..54db8a1e9 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py @@ -434,8 +434,8 @@ def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flo return cfg def _sample_active_subnet(self, min_net=False, max_net=False): - sample_cfg = ( - lambda candidates, sample_min, sample_max: min(candidates) + sample_cfg = lambda candidates, sample_min, sample_max: ( + min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates)) ) @@ -461,8 +461,8 @@ def _sample_active_subnet(self, min_net=False, max_net=False): def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False): cfg = copy.deepcopy(cfg) - pick_another = ( - lambda x, candidates: x + pick_another = lambda x, candidates: ( + x if len(candidates) == 1 else random.choice([v for v in candidates if v != x]) ) diff --git a/examples/server_aggregation/feddf/feddf_algorithm.py b/examples/server_aggregation/feddf/feddf_algorithm.py index 57164529e..ecab8cfe6 100644 --- a/examples/server_aggregation/feddf/feddf_algorithm.py +++ b/examples/server_aggregation/feddf/feddf_algorithm.py @@ -37,7 +37,9 @@ def aggregate_teacher_logits( "FedDF teacher weighting must be either 'uniform' or 'samples'." ) - total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates) + total_samples = sum( + getattr(update.report, "num_samples", 0) for update in updates + ) use_uniform_average = weighting_name == "uniform" or total_samples <= 0 aggregated = torch.zeros_like(first_logits, dtype=torch.float32) @@ -91,7 +93,9 @@ def distill_weights( inputs.append(extract_batch_inputs(example)) proxy_inputs = torch.stack(inputs) - distillation_dataset = TensorDataset(proxy_inputs, teacher_logits.detach().cpu()) + distillation_dataset = TensorDataset( + proxy_inputs, teacher_logits.detach().cpu() + ) dataloader = DataLoader( distillation_dataset, batch_size=distillation_batch_size, diff --git a/plato/algorithms/fedavg.py b/plato/algorithms/fedavg.py index e1e0b3d97..7328b3c8f 100644 --- a/plato/algorithms/fedavg.py +++ b/plato/algorithms/fedavg.py @@ -22,13 +22,13 @@ class Algorithm(base.Algorithm): def _as_state_mapping(weights: Any, context: str) -> Mapping[str, torch.Tensor]: """Validate and cast a state-dict-like payload.""" if not isinstance(weights, Mapping): - raise TypeError(f"{context} must be a mapping of parameter names to tensors.") + raise TypeError( + f"{context} must be a mapping of parameter names to tensors." + ) return weights @staticmethod - def _to_transport_tensor( - tensor: torch.Tensor, tensor_name: str - ) -> torch.Tensor: + def _to_transport_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor: """ Convert a tensor to a wire-safe representation for payload transport. 
@@ -88,7 +88,9 @@ def _compute_tensor_delta( if baseline_weight.dtype == torch.bool: return current_casted.to(torch.int8) - baseline_weight.to(torch.int8) - if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + if torch.is_floating_point(baseline_weight) or torch.is_complex( + baseline_weight + ): return current_casted.to(baseline_weight.dtype) - baseline_weight return current_casted.to(torch.int64) - baseline_weight.to(torch.int64) @@ -111,7 +113,9 @@ def _apply_tensor_delta( delta_integral = delta.to(torch.int8) return (baseline_weight.to(torch.int8) + delta_integral).ne(0) - if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + if torch.is_floating_point(baseline_weight) or torch.is_complex( + baseline_weight + ): return baseline_weight + delta.to(baseline_weight.dtype) if torch.is_floating_point(delta): @@ -156,7 +160,9 @@ def _estimate_payload_size_bytes(weights: Mapping[str, torch.Tensor]) -> int: size_bytes += tensor.numel() * tensor.element_size() return size_bytes - def _assert_payload_size(self, weights: Mapping[str, torch.Tensor], source: str) -> None: + def _assert_payload_size( + self, weights: Mapping[str, torch.Tensor], source: str + ) -> None: """Enforce an optional payload-size safeguard.""" limit_mb = self._resolve_payload_limit_mb() if limit_mb is None: @@ -175,10 +181,15 @@ def _resolve_adapter_parameter_names( ) -> list[str] | None: """Resolve parameter names to exchange for adapter-only finetuning.""" finetune_mode = getattr(target_model, "plato_finetune_mode", None) - if not isinstance(finetune_mode, str) or finetune_mode.strip().lower() != "adapter": + if ( + not isinstance(finetune_mode, str) + or finetune_mode.strip().lower() != "adapter" + ): return None - trainable_names_attr = getattr(target_model, "plato_trainable_parameter_names", None) + trainable_names_attr = getattr( + target_model, "plato_trainable_parameter_names", None + ) names_from_attr = ( [ name @@ -224,7 +235,9 @@ def compute_weight_deltas( unknown_keys = set(weight_mapping).difference(baseline_mapping) if unknown_keys: unknown = ", ".join(sorted(unknown_keys)) - raise KeyError(f"Received weights include unexpected parameter(s): {unknown}.") + raise KeyError( + f"Received weights include unexpected parameter(s): {unknown}." + ) delta = OrderedDict() for name, current_weight in weight_mapping.items(): @@ -308,7 +321,9 @@ def load_weights(self, weights): unknown_keys = set(weights_mapping).difference(current_state) if unknown_keys: unknown = ", ".join(sorted(unknown_keys)) - raise KeyError(f"Inbound weights include unexpected parameter(s): {unknown}.") + raise KeyError( + f"Inbound weights include unexpected parameter(s): {unknown}." 
+ ) merged_state = OrderedDict(current_state.items()) for name, incoming_tensor in weights_mapping.items(): diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index deebe91c0..eede64400 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -339,7 +339,14 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: if context.sio is not None: await context.sio.disconnect() - if hasattr(Config().trainer, "target_perplexity"): + metric_name = getattr( + getattr(context.trainer, "testing_strategy", None), + "metric_name", + "accuracy", + ) + if metric_name == "mse": + LOGGER.info("[%s] Test MSE: %.6f", context, accuracy) + elif hasattr(Config().trainer, "target_perplexity"): LOGGER.info("[%s] Test perplexity: %.2f", context, accuracy) else: LOGGER.info("[%s] Test accuracy: %.2f%%", context, 100 * accuracy) diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py new file mode 100644 index 000000000..d992814b4 --- /dev/null +++ b/plato/datasources/ev_charging.py @@ -0,0 +1,444 @@ +""" +EV Charging Datasource for Federated Time-Series Forecasting. + +Dataset: "EV Charging Reports" – Mendeley Data (dataset1_ev_charging_reports.csv) + https://data.mendeley.com/datasets/jbks2rcwyj/1 + +Raw CSV format: + session_ID ; Garage_ID ; User_ID ; User_type ; Shared_ID ; + Start_plugin ; Start_plugin_hour ; End_plugout ; End_plugout_hour ; + El_kWh ; Duration_hours ; month_plugin ; weekdays_plugin ; + Plugin_category ; Duration_category + - Datetimes use DD.MM.YYYY HH:MM format. + - El_kWh uses a comma as the decimal separator. + +Preprocessing pipeline +----------------------- +1. Filter to the requested garage (default "AdO1", which has 4 private users), + or use all garages when ``garage = "all"``. +2. For each user, build a continuous hourly grid from the first to the last + session hour in the dataset. +3. For every hour, mark is_charging = 1 if the user had an active session, + else 0; accumulate energy_kwh proportionally over session hours. +4. Scale energy_kwh ∈ [0, 1] using the training-split maximum. +5. Add cyclic time encodings: + hour_sin = sin(2π · hour / 24) hour_cos = cos(2π · hour / 24) + dow_sin = sin(2π · dow / 7) dow_cos = cos(2π · dow / 7) +6. Split temporally: 70 % train, 15 % val, 15 % test. +7. Build sliding-window samples: + past_values : (context_length, 6) : all features + future_values : (prediction_length, 1) : is_charging only + +Federated split +--------------- +Each client sees only its own user's data. 
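As a reading aid for steps 6-7 above, here is a minimal sketch of the sliding-window sample construction over a toy hourly array; the array contents and total length are placeholders, but the shapes match what _EVChargingDataset is described as returning:

import numpy as np

T, ctx, pred, stride = 1000, 168, 168, 24            # toy length; ctx/pred follow the config defaults
features = np.random.rand(T, 6).astype(np.float32)   # is_charging, energy_scaled, hour_sin/cos, dow_sin/cos

window = ctx + pred
starts = list(range(0, T - window + 1, stride))       # valid window start indices
n_train = max(1, int(len(starts) * 0.70))             # temporal split over windows, 70 % train
train_starts, rest = starts[:n_train], starts[n_train:]

s = train_starts[0]
past_values = features[s : s + ctx]                   # (168, 6): all six features
future_values = features[s + ctx : s + window, :1]    # (168, 1): is_charging only
assert past_values.shape == (ctx, 6) and future_values.shape == (pred, 1)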
+ +TOML configuration +------------------ +[data] +datasource = "EVCharging" +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" +garage = "AdO1" # optional; use "all" for cross-garage user lists +num_users = 4 # optional + +[trainer] +context_length = 168 # 7 * 24 h +prediction_length = 168 # 7 * 24 h +train_ratio = 0.70 +val_ratio = 0.15 +stride = 1 # slide 1 hour at a time +""" + +from __future__ import annotations + +import logging +import os + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from plato.config import Config + +# Exact column names from the Mendeley CSV +_CSV_SEP = ";" +_GARAGE_COL = "Garage_ID" +_USER_COL = "User_ID" +_START_COL = "Start_plugin" +_END_COL = "End_plugout" +_ENERGY_COL = "El_kWh" +_DT_FORMAT = "%d.%m.%Y %H:%M" + + +# Preprocessing helpers +def _parse_european_float(series: pd.Series) -> pd.Series: + """Replace comma decimal separator and coerce to float.""" + return ( + series.astype(str) + .str.replace(",", ".", regex=False) + .str.strip() + .pipe(pd.to_numeric, errors="coerce") + .fillna(0.0) + ) + + +def _build_hourly_series( + df: pd.DataFrame, + garage: str | None, + num_users: int, + user_ids: list[str] | None = None, +) -> dict[str, pd.DataFrame]: + """ + Build per-user hourly DataFrames from raw session records. + + Parameters: + user_ids : explicit list of User_ID strings to include. + num_users : max number of users to take alphabetically when ``user_ids`` + is not given. + + Returns: + dict mapping user_id (str) -> pd.DataFrame with hourly index and columns: + is_charging (0/1 float), energy_kwh (float >= 0) + """ + garage_name = None if garage is None else str(garage).strip() + use_all_garages = not garage_name or garage_name.lower() in {"all", "*", "any"} + + if use_all_garages: + df = df.copy() + else: + available_garages = sorted( + df[_GARAGE_COL].astype(str).str.strip().dropna().unique() + ) + mask = df[_GARAGE_COL].astype(str).str.strip() == garage_name + df = df[mask].copy() + if df.empty: + raise ValueError( + f"No records found for garage '{garage_name}'. " + f"Available: {available_garages}" + ) + + # Parse datetimes + df[_START_COL] = pd.to_datetime(df[_START_COL].str.strip(), format=_DT_FORMAT) + df[_END_COL] = pd.to_datetime(df[_END_COL].str.strip(), format=_DT_FORMAT) + df[_ENERGY_COL] = _parse_european_float(df[_ENERGY_COL]) + + # Drop invalid rows + df = df.dropna(subset=[_START_COL, _END_COL]) + df = df[df[_END_COL] > df[_START_COL]] + + # Resolve user list + available = sorted(df[_USER_COL].dropna().unique()) + if user_ids is not None: + # Explicit list from config — validate each entry + missing = [u for u in user_ids if u not in available] + if missing: + scope = "all garages" if use_all_garages else f"garage '{garage_name}'" + raise ValueError( + f"Users not found in {scope}: {missing}. Available: {available}" + ) + users = list(user_ids) # preserve config order + else: + users = available[:num_users] + scope = "all garages" if use_all_garages else f"garage '{garage_name}'" + logging.info("EVCharging: %s -> users %s", scope, users) + + result: dict[str, pd.DataFrame] = {} + for user in users: + udf = df[df[_USER_COL] == user] + + # Per-user hourly index: only spans that user's own activity window. + # Using a global index would pad every user with the same number of + # zero-charging hours, giving all clients identical dataset sizes. 
+ user_start = udf[_START_COL].min().floor("h") + user_end = udf[_END_COL].max().ceil("h") + hourly_index = pd.date_range(user_start, user_end, freq="h") + + is_charging = pd.Series(0.0, index=hourly_index) + energy_kwh = pd.Series(0.0, index=hourly_index) + + for _, row in udf.iterrows(): + # All hours touched by this session + session_hours = pd.date_range( + row[_START_COL].floor("h"), + row[_END_COL].floor("h"), + freq="h", + ) + valid_hours = session_hours[session_hours.isin(hourly_index)] + if valid_hours.empty: + continue + is_charging[valid_hours] = 1.0 + energy_per_hour = float(row[_ENERGY_COL]) / max(len(valid_hours), 1) + energy_kwh[valid_hours] += energy_per_hour + + user_df = pd.DataFrame( + {"is_charging": is_charging, "energy_kwh": energy_kwh}, + index=hourly_index, + ) + user_df.index.name = "timestamp" + result[user] = user_df + + return result + + +def _add_time_features(df: pd.DataFrame) -> pd.DataFrame: + """Append cyclic hour-of-day and day-of-week columns.""" + hour = df.index.hour.astype(float) + dow = df.index.dayofweek.astype(float) + df = df.copy() + df["hour_sin"] = np.sin(2 * np.pi * hour / 24) + df["hour_cos"] = np.cos(2 * np.pi * hour / 24) + df["dow_sin"] = np.sin(2 * np.pi * dow / 7) + df["dow_cos"] = np.cos(2 * np.pi * dow / 7) + return df + + +# Ordered feature columns fed into the model +_FEATURE_COLS = [ + "is_charging", + "energy_scaled", + "hour_sin", + "hour_cos", + "dow_sin", + "dow_cos", +] + + +# Torch Dataset +class _EVChargingDataset(Dataset): + """Sliding-window samples for one user / one split. + + Each sample: + past_values : FloatTensor (context_length, 6) + future_values : FloatTensor (prediction_length, 1) ← is_charging only + """ + + def __init__( + self, + data: np.ndarray, # shape (T, 6), already normalized + context_length: int, + prediction_length: int, + stride: int = 1, + starts: list[int] | None = None, # explicit window start indices + ): + super().__init__() + self.data = torch.FloatTensor(data) + self.context_length = context_length + self.prediction_length = prediction_length + if starts is not None: + # Caller already computed and partitioned the valid starts. + self.indices = starts + else: + total = context_length + prediction_length + max_start = len(data) - total + if max_start < 0: + logging.warning( + "EVCharging: data has only %d steps but needs %d " + "(context=%d + prediction=%d) — dataset will be empty.", + len(data), + total, + context_length, + prediction_length, + ) + self.indices = [] + else: + self.indices = list(range(0, max_start + 1, stride)) + + def __len__(self) -> int: + return len(self.indices) + + def __getitem__(self, idx: int) -> dict: + s = self.indices[idx] + e_ctx = s + self.context_length + e_pred = e_ctx + self.prediction_length + return { + "past_values": self.data[s:e_ctx], # (ctx, 6) + "future_values": self.data[e_ctx:e_pred, :1], # (pred, 1) + } + + +# Plato DataSource +class DataSource: + """EV Charging DataSource for Plato federated learning. + + Each instance represents ONE user (client_id selects the user, 0-indexed + over the alphabetically sorted user list for the requested garage). 
+ + Typical config (timesfm_ev_charging.toml): + + [data] + datasource = "EVCharging" + datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + garage = "AdO1" + num_users = 4 + + [trainer] + context_length = 168 + prediction_length = 168 + train_ratio = 0.70 + val_ratio = 0.15 + stride = 24 + """ + + def __init__(self, client_id: int = 0, **kwargs): + cfg = Config() + data_cfg = cfg.data + trainer_cfg = cfg.trainer + + # Locate CSV + csv_path = kwargs.get( + "datasource_path", + getattr(data_cfg, "datasource_path", None), + ) + if csv_path is None: + raise ValueError( + "EVCharging requires 'datasource_path' in [data] config, " + 'e.g. datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv"' + ) + if not os.path.isabs(csv_path): + csv_path = os.path.join(os.getcwd(), csv_path) + if not os.path.exists(csv_path): + raise FileNotFoundError( + f"EV charging CSV not found: {csv_path}\n" + "Download from https://data.mendeley.com/datasets/jbks2rcwyj/1" + ) + + garage_cfg = kwargs.get("garage", getattr(data_cfg, "garage", "AdO1")) + garage = None if garage_cfg is None else str(garage_cfg) + garage_name = None if garage is None else garage.strip() + + # Config: users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] + user_ids_cfg = kwargs.get("users", getattr(data_cfg, "users", None)) + if user_ids_cfg is not None: + user_ids: list[str] | None = [str(u) for u in user_ids_cfg] + num_users = len(user_ids) + else: + user_ids = None + num_users = int(kwargs.get("num_users", getattr(data_cfg, "num_users", 4))) + + # Window / split settings + self.context_length = int(getattr(trainer_cfg, "context_length", 168)) + self.prediction_length = int(getattr(trainer_cfg, "prediction_length", 168)) + train_ratio = float(getattr(trainer_cfg, "train_ratio", 0.70)) + val_ratio = float(getattr(trainer_cfg, "val_ratio", 0.15)) + stride = int(getattr(trainer_cfg, "stride", 1)) + + # Load and preprocess + logging.info("EVCharging: loading %s", csv_path) + raw_df = pd.read_csv(csv_path, sep=_CSV_SEP, low_memory=False) + + user_series = _build_hourly_series( + raw_df, garage=garage, num_users=num_users, user_ids=user_ids + ) + + # Preserve config-specified order when user_ids is given + if user_ids is not None: + users = [u for u in user_ids if u in user_series] + else: + users = sorted(user_series.keys()) + + user_index = max(0, client_id - 1) + + if user_index >= len(users): + scope = ( + "all garages" + if not garage_name or garage_name.lower() in {"all", "*", "any"} + else f"garage '{garage_name}'" + ) + raise ValueError( + f"client_id={client_id} out of range; " + f"found {len(users)} users in {scope}: {users}" + ) + + user_key = users[user_index] + logging.info("EVCharging: client_id=%d -> user '%s'", client_id, user_key) + + user_df = _add_time_features(user_series[user_key]) + raw_array = user_df[ + [ + "is_charging", + "energy_kwh", + "hour_sin", + "hour_cos", + "dow_sin", + "dow_cos", + ] + ].values.astype(np.float32) + + # Split window indices, not raw hours + window = self.context_length + self.prediction_length + all_starts = list(range(0, len(raw_array) - window + 1, stride)) + n_windows = len(all_starts) + + n_train_w = max(1, int(n_windows * train_ratio)) + n_val_w = max(0, int(n_windows * val_ratio)) + train_starts = all_starts[:n_train_w] + val_starts = all_starts[n_train_w : n_train_w + n_val_w] + test_starts = all_starts[n_train_w + n_val_w :] + + # Energy scaling + if train_starts: + train_end = min(len(user_df), train_starts[-1] + window) + else: + train_end = max(1, 
int(len(user_df) * train_ratio)) + energy_max = float(user_df["energy_kwh"].iloc[:train_end].max()) or 1.0 + + full_array = raw_array.copy() + full_array[:, 1] = full_array[:, 1] / energy_max # -> energy_scaled in [0, 1] + + # Keep the full normalized array for inference scripts + self.normalized_data = full_array + self.timestamps = user_df.index + self.user_id = user_key + self.feature_columns = list(_FEATURE_COLS) + self.split_window_starts = { + "train": list(train_starts), + "val": list(val_starts), + "test": list(test_starts), + } + + self._train_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=train_starts, + ) + self._val_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=val_starts, + ) + self._test_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=test_starts, + ) + + logging.info( + "EVCharging user '%s': %d train / %d val / %d test windows", + user_key, + len(self._train_set), + len(self._val_set), + len(self._test_set), + ) + + # Plato DataSource interface + def get_train_set(self) -> _EVChargingDataset: + return self._train_set + + def get_val_set(self) -> _EVChargingDataset: + return self._val_set + + def get_test_set(self) -> _EVChargingDataset: + return self._test_set + + def num_train_examples(self) -> int: + return len(self._train_set) + + def num_test_examples(self) -> int: + return len(self._test_set) diff --git a/plato/datasources/huggingface.py b/plato/datasources/huggingface.py index 49d9582c3..7240561b8 100644 --- a/plato/datasources/huggingface.py +++ b/plato/datasources/huggingface.py @@ -207,7 +207,9 @@ def __init__(self, **kwargs): if isinstance(tokenizer_name, str) and tokenizer_name else Config().trainer.model_name ) - auth_token = getattr(getattr(Config(), "parameters", None), "huggingface_token", None) + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) config_kwargs = { "cache_dir": Config().params["model_path"], "revision": "main", @@ -325,7 +327,11 @@ def preprocess_corpus_lm(self, dataset_split): ) configured_block_size = getattr(Config().data, "block_size", None) - block_size = configured_block_size if configured_block_size is not None else self.block_size + block_size = ( + configured_block_size + if configured_block_size is not None + else self.block_size + ) block_size = int(block_size) if block_size > 1024: logging.warning( @@ -364,9 +370,7 @@ def _build_chat_labels( if self.label_strategy == "full_sequence": return list(input_ids) if self.label_strategy != "assistant_only": - raise ValueError( - f"Unsupported chat label strategy: {self.label_strategy}" - ) + raise ValueError(f"Unsupported chat label strategy: {self.label_strategy}") if not hasattr(self.tokenizer, "apply_chat_template"): raise AttributeError( diff --git a/plato/datasources/lerobot.py b/plato/datasources/lerobot.py index ef2cd6654..679cb3e02 100644 --- a/plato/datasources/lerobot.py +++ b/plato/datasources/lerobot.py @@ -97,7 +97,7 @@ def _import_lerobot() -> tuple[Any, Any]: raise ImportError( "LeRobot datasource requires optional LeRobot / SmolVLA robotics dependencies. " "Install the robotics stack in the active environment before using " - '"data.datasource = \"LeRobot\"". ' + '"data.datasource = "LeRobot"". 
' ) from exc return LeRobotDataset, LeRobotDatasetMetadata @@ -362,7 +362,9 @@ def _resolve_task_name(row: Mapping[str, Any], tasks_lookup: Any) -> str | None: return None -def _resolve_episode_tasks(metadata: Any, episodes: Sequence[int]) -> dict[int, str | None]: +def _resolve_episode_tasks( + metadata: Any, episodes: Sequence[int] +) -> dict[int, str | None]: episode_tasks = {episode: None for episode in episodes} episode_rows = _episode_rows(getattr(metadata, "episodes", None)) tasks_lookup = _to_plain(getattr(metadata, "tasks", None)) @@ -473,10 +475,14 @@ def _resolve_episode_split( episode_set = set(int(episode) for episode in all_episodes) if explicit_train is None and explicit_test is None: - return _split_episodes(all_episodes, episode_tasks, train_ratio, seed, task_aware) + return _split_episodes( + all_episodes, episode_tasks, train_ratio, seed, task_aware + ) train_episodes = [ - int(episode) for episode in (explicit_train or []) if int(episode) in episode_set + int(episode) + for episode in (explicit_train or []) + if int(episode) in episode_set ] test_episodes = [ int(episode) @@ -569,7 +575,9 @@ def _resolve_total_clients(config: Any) -> int: return total_clients -def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: +def _filter_constructor_kwargs( + dataset_cls: Any, kwargs: Mapping[str, Any] +) -> dict[str, Any]: try: signature = inspect.signature(dataset_cls.__init__) except (TypeError, ValueError): @@ -582,9 +590,7 @@ def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> d if accepts_var_kwargs: return dict(kwargs) - valid_parameters = { - name for name in signature.parameters.keys() if name != "self" - } + valid_parameters = {name for name in signature.parameters.keys() if name != "self"} filtered = {key: value for key, value in kwargs.items() if key in valid_parameters} dropped = sorted(set(kwargs.keys()) - set(filtered.keys())) @@ -646,8 +652,7 @@ def __init__(self, client_id: int = 0, **kwargs): repo_id = str(dataset_cfg.pop("repo_id", "")).strip() if not repo_id: raise ValueError( - "LeRobot datasource requires " - '"parameters.dataset.repo_id" to be set.' + 'LeRobot datasource requires "parameters.dataset.repo_id" to be set.' 
) train_split_raw = dataset_cfg.pop("train_split", _DEFAULT_TRAIN_SPLIT) diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index 0a609d5e5..990b32ed5 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -8,6 +8,7 @@ from plato.config import Config from plato.datasources import ( cinic10, + ev_charging, feature, femnist, huggingface, @@ -32,7 +33,11 @@ "Nanochat": nanochat, } -registered_partitioned_datasources = {"FEMNIST": femnist, "LeRobot": lerobot} +registered_partitioned_datasources = { + "FEMNIST": femnist, + "LeRobot": lerobot, + "EVCharging": ev_charging, # per-user split; client_id selects the user +} _datasource_aliases = { "STL10": ("Torchvision", {"dataset_name": "STL10"}), diff --git a/plato/evaluators/lighteval.py b/plato/evaluators/lighteval.py index 2c623b4ba..bd51e1986 100644 --- a/plato/evaluators/lighteval.py +++ b/plato/evaluators/lighteval.py @@ -98,9 +98,7 @@ def __exit__(self, exc_type, exc, exc_tb) -> None: def advance(self, message: str) -> None: self._current += 1 - logging.info( - "[Lighteval] %s (%d/%d).", message, self._current, self.total - ) + logging.info("[Lighteval] %s (%d/%d).", message, self._current, self.total) if self._bar is not None: self._bar.set_postfix_str(message) self._bar.update(1) diff --git a/plato/evaluators/lighteval_tasks.py b/plato/evaluators/lighteval_tasks.py index dc9731852..e4fcf33a1 100644 --- a/plato/evaluators/lighteval_tasks.py +++ b/plato/evaluators/lighteval_tasks.py @@ -15,7 +15,10 @@ def piqa_hf_prompt(line, task_name: str | None = None): query = "The following are multiple choice questions (with answers) about common sense.\n" query += f"Question: {line['goal']}\n" query += "".join( - [f"{key}. {choice}\n" for key, choice in zip(letters, [line["sol1"], line["sol2"]])] + [ + f"{key}. {choice}\n" + for key, choice in zip(letters, [line["sol1"], line["sol2"]]) + ] ) query += "Answer: " diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index d22292315..961a20374 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -319,8 +319,14 @@ def _resolve_tokenizer(model) -> Any: def _safe_evaluate_task( - model, tokenizer, data, device, task_meta, label, - evaluate_task_fn, evaluate_example_fn, + model, + tokenizer, + data, + device, + task_meta, + label, + evaluate_task_fn, + evaluate_example_fn, ): """Wrap upstream ``evaluate_task`` so that examples whose tokenized prompts exceed the model's ``max_seq_len`` are gracefully skipped @@ -446,8 +452,14 @@ def run_core_evaluation( data = data[:max_per_task] accuracy = _safe_evaluate_task( - model, eval_tokenizer, data, model_device, task_meta, label, - evaluate_task, evaluate_example, + model, + eval_tokenizer, + data, + model_device, + task_meta, + label, + evaluate_task, + evaluate_example, ) if accuracy is None: # All examples were skipped (too long for model's max_seq_len). 
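The _safe_evaluate_task helper above is described as skipping examples whose tokenized prompts exceed the model's max_seq_len; the exact mechanism is not shown here, so the following is only a hedged sketch of that filtering idea, with a hypothetical tokenizer and example format:

def filter_examples_by_length(examples, tokenizer, max_seq_len):
    """Split examples into (kept, skipped) by tokenized prompt length (illustrative only)."""
    kept, skipped = [], []
    for example in examples:
        prompt = example.get("prompt", "")            # hypothetical field name
        if len(tokenizer.encode(prompt)) <= max_seq_len:
            kept.append(example)
        else:
            skipped.append(example)
    return kept, skipped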
diff --git a/plato/evaluators/runner.py b/plato/evaluators/runner.py index e9b284f7d..7f2657f93 100644 --- a/plato/evaluators/runner.py +++ b/plato/evaluators/runner.py @@ -26,7 +26,9 @@ def _configured_evaluator_type() -> str | None: evaluator_type = evaluation_cfg.get("type") else: evaluator_type = getattr(evaluation_cfg, "type", None) - return evaluator_type if isinstance(evaluator_type, str) and evaluator_type else None + return ( + evaluator_type if isinstance(evaluator_type, str) and evaluator_type else None + ) def _evaluation_fail_on_error() -> bool: diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 4a8d3aafb..929b3ae4f 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -5,11 +5,41 @@ from __future__ import annotations import logging -from typing import Any, Dict +from typing import Any, Callable, Dict +import torch +import torch.nn as nn +import torch.nn.functional as F from transformers import AutoConfig, AutoModelForCausalLM from plato.config import Config +from plato.utils.timeseries_utils import is_timeseries_model + +try: + from transformers import ( + PatchTSMixerConfig, + PatchTSMixerForPrediction, + PatchTSMixerForPretraining, + PatchTSMixerForRegression, + PatchTSMixerForTimeSeriesClassification, + ) +except ImportError: + PatchTSMixerConfig = None + PatchTSMixerForPrediction = None + PatchTSMixerForTimeSeriesClassification = None + PatchTSMixerForRegression = None + PatchTSMixerForPretraining = None + +try: + from transformers import TimesFmConfig, TimesFmModelForPrediction +except ImportError: + TimesFmConfig = None + TimesFmModelForPrediction = None + +try: + from transformers import TimesFm2_5ModelForPrediction +except ImportError: + TimesFm2_5ModelForPrediction = None try: from peft import LoraConfig, get_peft_model @@ -35,18 +65,351 @@ def _lora_config_dict(lora_config: Any) -> dict[str, Any]: raise TypeError("Unsupported LoRA configuration format.") +class _TimesFmOutput: + """Output container compatible with the Plato time-series training/testing pipeline.""" + + def __init__(self, loss=None, prediction_outputs=None): + self.loss = loss + self.prediction_outputs = prediction_outputs + + +class TimesFmMultivariateWrapper(nn.Module): + """Wraps TimesFmModelForPrediction for batched, multivariate time series. + + TimesFM is natively univariate (each call takes a list of 1-D tensors). + This wrapper accepts a standard batched tensor of shape + ``(batch, context_length)`` or ``(batch, context_length, channels)`` + and handles the reshaping transparently so the rest of the Plato pipeline + (collators, training/testing strategies) needs no changes. + + For multivariate input, every channel is processed independently through + the same TimesFM model (channel-independent forecasting). The outputs are + recombined into ``(batch, prediction_length, channels)``. + + If ``future_values`` is provided the wrapper computes MSE loss against it + and stores it in ``.loss``. ``prediction_outputs`` always holds the mean + predictions in ``(batch, prediction_length, out_channels)`` form. + + Args: + model: An instantiated ``TimesFmModelForPrediction``. + prediction_length: Number of future steps to keep. Predictions are + truncated to this length when the model's ``horizon_length`` + differs from the configured ``prediction_length``. + default_freq: Default frequency token (0 = high/hourly, + 1 = medium/daily-weekly, 2 = low/monthly-yearly). 
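To make the shape contract concrete, this toy sketch reproduces the channel-independent reshaping described above, with a random tensor standing in for the per-series TimesFM forecasts (all sizes are illustrative):

import torch

batch, ctx, channels, horizon = 4, 168, 6, 168
past_values = torch.randn(batch, ctx, channels)

# (batch, ctx, ch) -> (batch * ch, ctx): each channel becomes its own univariate series.
pv_2d = past_values.permute(0, 2, 1).reshape(batch * channels, ctx)

# Stand-in for the per-series mean predictions returned by the model.
raw = torch.randn(batch * channels, horizon)

# Recombine into (batch, horizon, ch) for the rest of the Plato pipeline.
mean_preds = raw.reshape(batch, channels, horizon).permute(0, 2, 1)
assert mean_preds.shape == (batch, horizon, channels)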
+ """ + + def __init__( + self, + model: "TimesFmModelForPrediction", + prediction_length: int | None = None, + default_freq: int = 0, + use_transformers_api: bool = False, + ): + super().__init__() + self.model = model + self.prediction_length = prediction_length + self.default_freq = default_freq + self.use_transformers_api = use_transformers_api + + def forward( + self, + past_values: torch.Tensor, + future_values: torch.Tensor | None = None, + freq: int | list | torch.Tensor | None = None, + return_dict: bool = True, # accepted for API compat, ignored internally + **kwargs, + ) -> _TimesFmOutput: + if not isinstance(past_values, torch.Tensor): + raise TypeError("past_values must be a torch.Tensor") + + if past_values.dim() == 3: + # Multivariate path + batch, ctx, channels = past_values.shape + # (batch, ctx, ch) -> (batch*ch, ctx) + pv_2d = past_values.permute(0, 2, 1).reshape(batch * channels, ctx) + past_list = [pv_2d[i] for i in range(pv_2d.size(0))] + + if self.use_transformers_api: + outputs = self.model(past_values=past_list, forecast_context_len=ctx) + else: + freq_list = self._build_freq_list(freq, batch, channels) + outputs = self.model(past_values=past_list, freq=freq_list) + + # (batch*ch, horizon) -> (batch, horizon, ch) + raw = outputs.mean_predictions + horizon = raw.shape[-1] + mean_preds = raw.reshape(batch, channels, horizon).permute(0, 2, 1) + + else: + # Univariate path + batch = past_values.size(0) + ctx = past_values.size(1) + past_list = [past_values[i] for i in range(batch)] + + if self.use_transformers_api: + outputs = self.model(past_values=past_list, forecast_context_len=ctx) + else: + freq_list = self._build_freq_list(freq, batch, channels=1) + outputs = self.model(past_values=past_list, freq=freq_list) + + mean_preds = outputs.mean_predictions.unsqueeze(-1) # (batch, horizon, 1) + + # Truncate to configured prediction_length + if self.prediction_length is not None: + mean_preds = mean_preds[:, : self.prediction_length, :] + + # Compute MSE loss when targets are provided + loss = None + if future_values is not None: + fv = future_values + if fv.dim() == 2: + fv = fv.unsqueeze(-1) # (batch, pred) -> (batch, pred, 1) + min_len = min(mean_preds.shape[1], fv.shape[1]) + loss = F.mse_loss( + mean_preds[:, :min_len, : fv.shape[-1]], + fv[:, :min_len, :], + ) + + return _TimesFmOutput(loss=loss, prediction_outputs=mean_preds) + + def _build_freq_list( + self, + freq: int | list | torch.Tensor | None, + batch: int, + channels: int, + ) -> list[int]: + n = batch * channels + if freq is None: + return [self.default_freq] * n + if isinstance(freq, int): + return [freq] * n + if isinstance(freq, torch.Tensor): + freq = freq.tolist() + # freq is list of length batch; expand for each channel + return [int(f) for f in freq for _ in range(channels)] + + +# --------------------------------------------------------------------------- +# Time-series model loaders +# +# To add a new HuggingFace time series model: +# 1. Implement a loader function with signature: +# def _load_(resolved_model_name, cache_dir, **kwargs) -> nn.Module +# 2. Register it below in _TIMESERIES_LOADERS. +# 3. Add the model type string to TIMESERIES_MODEL_TYPES in +# plato/utils/timeseries_utils.py. 
+# --------------------------------------------------------------------------- + + +def _create_timesfm_from_config(trainer_config, prediction_length: int) -> nn.Module: + """Instantiate a fresh TimesFmModelForPrediction from TOML trainer settings.""" + if TimesFmConfig is None or TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." + ) + context_length = getattr(trainer_config, "context_length", 512) + config = TimesFmConfig( + context_length=context_length, + horizon_length=prediction_length, + patch_length=getattr(trainer_config, "patch_length", 32), + num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), + hidden_size=getattr(trainer_config, "hidden_size", 1280), + intermediate_size=getattr(trainer_config, "intermediate_size", 1280), + num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), + head_dim=getattr(trainer_config, "head_dim", 80), + attention_dropout=getattr(trainer_config, "dropout", 0.0), + ) + return TimesFmModelForPrediction(config) + + +def _load_timesfm(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: + """Load or create a TimesFM model wrapped for batched multivariate use. + + Model class selection is based on version, not checkpoint format: + + - TimesFM 2.5 (``2.5`` in the name): ``TimesFm2_5ModelForPrediction``. + Both ``*-pytorch`` and ``*-transformers`` checkpoints are supported by + this class. The forward API differs by suffix: + - ``*-transformers``: uses ``forecast_context_len`` + - ``*-pytorch`` (and others): uses ``freq`` + - TimesFM 1.0 / unspecified: ``TimesFmModelForPrediction`` with ``freq``. + Falls back to config-based creation if the checkpoint is not found. + """ + name_lower = resolved_model_name.lower() + is_v25 = "2.5" in name_lower + # Controls which forward kwarg to use, not which class to load. + use_transformers_api = "transformers" in name_lower + + trainer_config = Config().trainer + prediction_length = getattr(trainer_config, "prediction_length", 128) + default_freq = getattr(trainer_config, "freq", 0) + + if is_v25: + if TimesFm2_5ModelForPrediction is None: + raise ImportError( + "TimesFm2_5ModelForPrediction is not available. " + "Ensure you have a recent transformers version installed." + ) + logging.info( + "Loading pretrained TimesFM 2.5 model: %s", + resolved_model_name, + ) + inner = TimesFm2_5ModelForPrediction.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained TimesFM 2.5 model") + else: + if TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." 
+ ) + try: + logging.info( + "Attempting to load pretrained TimesFM model: %s", + resolved_model_name, + ) + inner = TimesFmModelForPrediction.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained TimesFM model") + except (OSError, ValueError, Exception): + logging.info( + "TimesFM model '%s' not found as pretrained, creating from config", + resolved_model_name, + ) + inner = _create_timesfm_from_config(trainer_config, prediction_length) + + return TimesFmMultivariateWrapper( + model=inner, + prediction_length=prediction_length, + default_freq=default_freq, + use_transformers_api=use_transformers_api, + ) + + +def _load_patchtsmixer(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: + """Load or create a PatchTSMixer model.""" + if PatchTSMixerForPrediction is None: + raise ImportError( + "PatchTSMixer models are not available. " + "Ensure you have transformers>=4.35.0 installed." + ) + + trainer_config = Config().trainer + model_task = ( + kwargs.get("model_task") + or getattr(trainer_config, "model_task", None) + or getattr(trainer_config, "task_type", "forecasting") + ) + + task_models = { + "classification": PatchTSMixerForTimeSeriesClassification, + "regression": PatchTSMixerForRegression, + "pretraining": PatchTSMixerForPretraining, + "forecasting": PatchTSMixerForPrediction, + } + model_class = task_models.get(model_task, PatchTSMixerForPrediction) + + try: + logging.info( + "Attempting to load pretrained PatchTSMixer model: %s", + resolved_model_name, + ) + model = model_class.from_pretrained(resolved_model_name, cache_dir=cache_dir) + logging.info("Successfully loaded pretrained model") + except (OSError, ValueError, Exception): + logging.info( + "Model '%s' not found as pretrained, creating from config settings", + resolved_model_name, + ) + scaling_param = getattr(trainer_config, "scaling", "std") + if isinstance(scaling_param, str) and scaling_param.lower() == "none": + scaling_param = None + + config = PatchTSMixerConfig( + context_length=getattr(trainer_config, "context_length", 512), + prediction_length=getattr(trainer_config, "prediction_length", 96), + num_input_channels=getattr(trainer_config, "num_input_channels", 7), + patch_length=getattr(trainer_config, "patch_length", 8), + patch_stride=getattr(trainer_config, "patch_stride", 8), + d_model=getattr(trainer_config, "d_model", 64), + num_layers=getattr(trainer_config, "num_layers", 8), + expansion_factor=getattr(trainer_config, "expansion_factor", 2), + dropout=getattr(trainer_config, "dropout", 0.2), + head_dropout=getattr(trainer_config, "head_dropout", 0.2), + mode=getattr(trainer_config, "mode", "common_channel"), + gated_attn=getattr(trainer_config, "gated_attn", True), + scaling=scaling_param, + prediction_channel_indices=getattr( + trainer_config, "prediction_channel_indices", None + ), + ) + + if model_task == "classification": + config.num_labels = getattr(trainer_config, "num_classes", 2) + model = PatchTSMixerForTimeSeriesClassification(config) + elif model_task == "regression": + config.num_targets = getattr(trainer_config, "num_targets", 1) + model = PatchTSMixerForRegression(config) + elif model_task == "pretraining": + model = PatchTSMixerForPretraining(config) + else: + model = PatchTSMixerForPrediction(config) + + return model + + +# Registry mapping model_type (lowercase) -> loader function. 
+# This is the only place that needs updating when a new HF time series model +# is added (along with TIMESERIES_MODEL_TYPES in timeseries_utils.py). +_TIMESERIES_LOADERS: Dict[str, Callable[..., nn.Module]] = { + "timesfm": _load_timesfm, + "patchtsmixer": _load_patchtsmixer, +} + + class Model: - """The CausalLM model loaded from HuggingFace.""" + """The HuggingFace model factory supporting various model types.""" + + @staticmethod + def _get_timeseries_model( + resolved_model_name: str, cache_dir: str, model_type: str = "", **kwargs + ) -> nn.Module: + """Unified entry point for all HuggingFace time series models. + + Dispatches to the appropriate loader in ``_TIMESERIES_LOADERS`` based + on ``model_type`` or a substring match in ``resolved_model_name``. + """ + model_type_lower = model_type.lower() + model_name_lower = resolved_model_name.lower() + + loader = None + for ts_type, ts_loader in _TIMESERIES_LOADERS.items(): + if model_type_lower == ts_type or ts_type in model_name_lower: + loader = ts_loader + break + + if loader is None: + raise ValueError( + f"No time series loader found for model '{resolved_model_name}' " + f"(type='{model_type}'). " + "Register a loader in _TIMESERIES_LOADERS in plato/models/huggingface.py " + "and add the type to TIMESERIES_MODEL_TYPES in plato/utils/timeseries_utils.py." + ) + + return loader(resolved_model_name, cache_dir, **kwargs) @staticmethod def get(model_name=None, **kwargs): # pylint: disable=unused-argument - """Returns a named model from HuggingFace.""" - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } + """Returns a named model from HuggingFace. + Two paths: + - Time series models -> ``_get_timeseries_model()`` + - All other models -> ``AutoModelForCausalLM`` (with optional LoRA) + """ resolved_model_name = ( model_name if isinstance(model_name, str) and model_name @@ -55,6 +418,25 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument if not isinstance(resolved_model_name, str) or not resolved_model_name: raise ValueError("A valid HuggingFace model name must be provided.") + cache_dir = Config().params["model_path"] + "/huggingface" + + model_type = ( + kwargs.get("model_type") + or getattr(getattr(Config(), "trainer", None), "model_type", None) + or "" + ) + + if is_timeseries_model(model_name=resolved_model_name, model_type=model_type): + return Model._get_timeseries_model( + resolved_model_name, cache_dir, model_type=model_type, **kwargs + ) + + # NLP / CausalLM path + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } config = AutoConfig.from_pretrained(resolved_model_name, **config_kwargs) model = AutoModelForCausalLM.from_pretrained( @@ -70,7 +452,6 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument "The 'peft' package is required for LoRA fine-tuning. " "Install it by running `uv add peft`." 
) - params_dict = _lora_config_dict(lora_params) logging.info("Configuring LoRA with parameters: %s", params_dict) lora_cfg = LoraConfig(**params_dict) diff --git a/plato/models/registry.py b/plato/models/registry.py index 1dbda56cf..b66c0edbb 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -44,6 +44,8 @@ "vit": vit.Model, "nanochat": nanochat.Model, "smolvla": smolvla.Model, + "timesfm": huggingface.Model, + "patchtsmixer": huggingface.Model, } registered_mlx_models = {} diff --git a/plato/models/smolvla.py b/plato/models/smolvla.py index ec0865aae..6760c29ba 100644 --- a/plato/models/smolvla.py +++ b/plato/models/smolvla.py @@ -43,7 +43,7 @@ def _import_smolvla_policy() -> type[Any]: except ImportError as exc: # pragma: no cover - environment dependent raise ImportError( "SmolVLA requires optional LeRobot robotics dependencies. " - "Install the robotics stack in the active environment before using `model_type = \"smolvla\"`." + 'Install the robotics stack in the active environment before using `model_type = "smolvla"`.' ) from exc return SmolVLAPolicy @@ -273,7 +273,9 @@ def get(model_name: str | None = None, **kwargs: Any) -> nn.Module: setattr(policy, "plato_policy_path", policy_path) setattr(policy, "plato_finetune_mode", finetune_mode) setattr(policy, "plato_adapter_patterns", tuple(adapter_patterns)) - setattr(policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"]) + setattr( + policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"] + ) setattr(policy, "plato_trainable_parameter_count", trainable_count) setattr( policy, diff --git a/plato/servers/diloco.py b/plato/servers/diloco.py new file mode 100644 index 000000000..bedbb5f05 --- /dev/null +++ b/plato/servers/diloco.py @@ -0,0 +1,50 @@ +"""FedAvg-compatible server using DiLoCo aggregation.""" + +from plato.config import Config +from plato.servers import fedavg +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy + + +class Server(fedavg.Server): + """Federated learning server with server-side DiLoCo outer aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = DiLoCoAggregationStrategy( + **self._aggregation_config() + ) + + super().__init__( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) + + @staticmethod + def _aggregation_config() -> dict: + """Read optional DiLoCo aggregation settings from [server.diloco].""" + config = getattr(Config().server, "diloco", None) + if config is None: + return {} + + keys = ( + "outer_optimizer", + "outer_learning_rate", + "outer_momentum", + "aggregation_weighting", + "apply_outer_optimizer_to", + ) + return {key: getattr(config, key) for key in keys if hasattr(config, key)} diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 5f862fc35..2e6ebc733 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -63,6 +63,56 @@ def __init__( self.clients_per_round, ) + def _primary_metric_name(self) -> str: + """Return the name of the primary testing metric.""" + trainer = getattr(self, "trainer", None) + testing_strategy = getattr(trainer, "testing_strategy", None) + metric_name = getattr(testing_strategy, "metric_name", None) + + if 
isinstance(metric_name, str) and metric_name: + metric_name = metric_name.lower() + if metric_name != "accuracy": + return metric_name + + if hasattr(Config().trainer, "target_perplexity"): + return "perplexity" + + return "accuracy" + + def _log_average_client_metric(self, metric_value: float) -> None: + """Log the client-aggregated testing metric with the appropriate label.""" + metric_name = self._primary_metric_name() + + if metric_name == "mse": + logging.info("[%s] Average client MSE: %.6f.", self, metric_value) + elif metric_name == "perplexity": + logging.info("[%s] Average client perplexity: %.2f.", self, metric_value) + else: + logging.info( + "[%s] Average client accuracy: %.2f%%.", self, 100 * metric_value + ) + + def _log_global_metric(self, metric_value: float) -> None: + """Log the server-tested metric with the appropriate label.""" + metric_name = self._primary_metric_name() + + if metric_name == "mse": + logging.info( + fonts.colourize(f"[{self}] Global model MSE: {metric_value:.6f}\n") + ) + elif metric_name == "perplexity": + logging.info( + fonts.colourize( + f"[{self}] Global model perplexity: {metric_value:.2f}\n" + ) + ) + else: + logging.info( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * metric_value:.2f}%\n" + ) + ) + def configure(self) -> None: """ Booting the federated learning server by setting up the data, model, and @@ -133,7 +183,8 @@ def configure(self) -> None: accuracy_csv_file = ( f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv" ) - accuracy_headers = ["round", "client_id", "accuracy"] + metric_name = self._primary_metric_name() + accuracy_headers = ["round", "client_id", metric_name] csv_processor.initialize_csv( accuracy_csv_file, accuracy_headers, Config().params["result_path"] ) @@ -222,13 +273,20 @@ async def _process_reports(self): # Use delta aggregation (default path) # Computes the weight deltas by comparing the weights received with # the current global model weights - deltas_received = algorithm.compute_weight_deltas( - baseline_weights, weights_received + delta_updates, delta_weights_received = self._weight_updates_and_payloads( + self.updates, weights_received + ) + deltas_received = ( + algorithm.compute_weight_deltas( + baseline_weights, delta_weights_received + ) + if delta_weights_received + else [] ) # Runs a framework-agnostic server aggregation algorithm, such as # the federated averaging algorithm logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid()) - deltas = await self.aggregate_deltas(self.updates, deltas_received) + deltas = await self.aggregate_deltas(delta_updates, deltas_received) # Updates the existing model weights from the provided deltas updated_weights = algorithm.update_weights(deltas) # Loads the new model weights @@ -243,9 +301,7 @@ async def _process_reports(self): if hasattr(Config().server, "do_test") and not Config().server.do_test: # Compute the average accuracy from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy - ) + self._log_average_client_metric(self.accuracy) else: # Testing the updated model directly at the server logging.info("[%s] Started model testing.", self) @@ -270,17 +326,9 @@ async def _process_reports(self): ) ) elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) + self._log_global_metric(self.accuracy) else: - 
logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) self.clients_processed() self.callback_handler.call_event("on_clients_processed", self) @@ -299,6 +347,20 @@ def _should_prefer_weight_aggregation(self) -> bool: and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas ) + @staticmethod + def _weight_updates_and_payloads(updates, weights_received): + """Return update/payload pairs whose reports contain model weights.""" + delta_updates = [] + delta_weights_received = [] + + for update, weights in zip(updates, weights_received): + if getattr(update.report, "type", "weights") != "weights": + continue + delta_updates.append(update) + delta_weights_received.append(weights) + + return delta_updates, delta_weights_received + def clients_processed(self) -> None: """Additional work to be performed after client reports have been processed.""" @@ -345,6 +407,11 @@ def get_logged_items(self) -> dict: if hasattr(self, "_core_metric"): logged["core_metric"] = self._core_metric + metric_name = self._primary_metric_name() + if metric_name != "accuracy": + logged[metric_name] = self.accuracy + logged[f"{metric_name}_std"] = self.accuracy_std + logged.update(evaluation_logging.extract_logged_items(self.trainer)) return logged diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index eaba3caf6..fe865dd90 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -222,11 +222,7 @@ async def _process_reports(self): self.average_accuracy, self.std_accuracy, ) = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", - self, - 100 * self.average_accuracy, - ) + self._log_average_client_metric(self.average_accuracy) elif Config().is_central_server() and Config().clients.do_test: # Compute the average accuracy from client reports total_samples = sum(update.report.num_samples for update in self.updates) @@ -238,11 +234,7 @@ async def _process_reports(self): / total_samples ) - logging.info( - "[%s] Average client accuracy: %.2f%%.", - self, - 100 * self.average_accuracy, - ) + self._log_average_client_metric(self.average_accuracy) if ( Config().is_central_server() @@ -268,18 +260,8 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) - elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) elif ( Config().is_edge_server() and hasattr(Config().server, "edge_do_test") @@ -304,18 +286,8 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) - elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) else: self.accuracy = self.average_accuracy self.accuracy_std = self.std_accuracy diff --git a/plato/servers/pfedgraph.py b/plato/servers/pfedgraph.py index f6d694bd4..73aa57d84 100644 --- a/plato/servers/pfedgraph.py +++ b/plato/servers/pfedgraph.py @@ -4,9 +4,12 @@ from __future__ import 
annotations +import logging +from pathlib import Path from typing import Any, Sequence from plato.config import Config +from plato.serialization.safetensor import serialize_tree from plato.servers import fedavg from plato.servers.strategies.aggregation.pfedgraph import ( PFedGraphAggregationStrategy, @@ -79,3 +82,40 @@ def customize_server_payload(self, payload: Any) -> Any: if client_id in self.client_models: return self.client_models[client_id] return payload + + def _client_model_path(self, client_id: int) -> Path: + """Return the output path for a saved client-specific pFedGraph model.""" + + model_name = ( + Config().trainer.model_name + if hasattr(Config().trainer, "model_name") + else "custom" + ) + return ( + Path(Config().params["model_path"]) + / f"{model_name}_client_{client_id}.safetensors" + ) + + def save_client_models(self) -> None: + """Persist the latest pFedGraph client-specific models.""" + + if not self.client_models: + return + + for client_id, client_model in sorted(self.client_models.items()): + model_path = self._client_model_path(client_id) + model_path.parent.mkdir(parents=True, exist_ok=True) + with model_path.open("wb") as model_file: + model_file.write(serialize_tree(client_model)) + logging.info( + "[%s] Saved pFedGraph client #%d model to %s.", + self, + client_id, + model_path, + ) + + def server_will_close(self) -> None: + """Save pFedGraph client-specific models before server shutdown.""" + + self.save_client_models() + super().server_will_close() diff --git a/plato/servers/registry.py b/plato/servers/registry.py index 2211b8972..b26465a8c 100644 --- a/plato/servers/registry.py +++ b/plato/servers/registry.py @@ -10,6 +10,7 @@ from plato.config import Config from plato.servers import ( + diloco, fedavg, fedavg_cs, fedavg_gan, @@ -30,6 +31,7 @@ registered_servers = { "fedavg": fedavg.Server, "fedavg_lora": fedavg.Server, + "diloco": diloco.Server, "fedavg_cross_silo": fedavg_cs.Server, "fedavg_gan": fedavg_gan.Server, "fedavg_personalized": fedavg_personalized.Server, diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 49fdc2d4c..72ffe688a 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -91,7 +91,7 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.phase = "gradient" elif report.type == "weights": - logging.warning("[%s] Weights received, start testing accuracy.", self) + logging.warning("[%s] Weights received, start testing.", self) weights = update.payload # The weights after cut layer are not trained by clients @@ -112,17 +112,9 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): ) ) else: - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.test_accuracy) else: - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.test_accuracy) self.phase = "prompt" # Change client in next round self.next_client = True diff --git a/plato/servers/strategies/aggregation/__init__.py b/plato/servers/strategies/aggregation/__init__.py index f44c420a4..fe4174402 100644 --- a/plato/servers/strategies/aggregation/__init__.py +++ b/plato/servers/strategies/aggregation/__init__.py @@ -4,6 +4,7 @@ Each strategy is defined in its own module for clarity. 
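The DiLoCo aggregation strategy introduced below converts the averaged client delta into an outer gradient and applies an SGD or Nesterov-momentum outer step before handing a server delta back to algorithm.update_weights(). A minimal per-tensor sketch of that update rule, using plain tensors rather than the strategy's tree-walking helpers:

import torch

outer_lr, mu = 0.7, 0.9                  # outer_learning_rate, outer_momentum
avg_delta = torch.randn(10)              # weighted mean of (client_after - global_before)
momentum = torch.zeros(10)               # per-parameter momentum state, kept across rounds

outer_grad = -avg_delta                  # negated delta is the DiLoCo outer gradient
momentum = mu * momentum + outer_grad    # momentum accumulation
direction = outer_grad + mu * momentum   # Nesterov look-ahead; "sgdm" uses `momentum` directly
server_delta = -outer_lr * direction     # added to the global model by update_weights()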
""" +from plato.servers.strategies.aggregation.diloco import DiLoCoAggregationStrategy from plato.servers.strategies.aggregation.fedasync import FedAsyncAggregationStrategy from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy from plato.servers.strategies.aggregation.fedbuff import FedBuffAggregationStrategy @@ -16,6 +17,7 @@ __all__ = [ "FedAvgAggregationStrategy", + "DiLoCoAggregationStrategy", "FedBuffAggregationStrategy", "FedNovaAggregationStrategy", "FedAsyncAggregationStrategy", diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py new file mode 100644 index 000000000..e1d35aa5f --- /dev/null +++ b/plato/servers/strategies/aggregation/diloco.py @@ -0,0 +1,496 @@ +""" +DiLoCo aggregation strategy. + +The strategy consumes Plato-style client deltas (`client_after - global_before`), +converts them to DiLoCo outer gradients, and returns Plato-compatible server +deltas for `algorithm.update_weights()` to add to the global model. +""" + +from __future__ import annotations + +import asyncio +import copy +import logging +import numbers +from collections.abc import Callable, Mapping +from types import SimpleNamespace +from typing import Any, cast + +import numpy as np + +from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy +from plato.servers.strategies.base import ServerContext + +try: # pragma: no cover - optional dependency + import torch +except ImportError: # pragma: no cover + torch = cast(Any, None) + + +class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): + """Aggregate client deltas with a server-side DiLoCo outer optimizer.""" + + _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} + _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + _SUPPORTED_APPLY_POLICIES = {"parameters", "all_floating"} + + def __init__( + self, + outer_optimizer: str = "nesterov", + outer_learning_rate: float = 0.7, + outer_momentum: float = 0.9, + aggregation_weighting: str = "uniform", + apply_outer_optimizer_to: str = "parameters", + ): + super().__init__() + self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) + self.outer_learning_rate = self._validate_learning_rate(outer_learning_rate) + self.outer_momentum = self._validate_momentum(outer_momentum) + self.aggregation_weighting = self._validate_weighting_mode( + aggregation_weighting + ) + self.apply_outer_optimizer_to = self._validate_apply_policy( + apply_outer_optimizer_to + ) + self.momentum_state: dict[str, Any] = {} + + async def aggregate_deltas( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + context: ServerContext, + ) -> dict: + """Aggregate deltas and apply the configured DiLoCo outer optimizer.""" + eligible = self._eligible_updates(updates, deltas_received) + if not eligible: + self._remove_stale_momentum(set()) + return self._empty_delta(context, self._first_delta(deltas_received)) + + weights = self._aggregation_weights(eligible) + if not weights: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta: Any = None + for (_, delta, _), weight in zip(eligible, weights): + avg_delta = self._accumulate_weighted(avg_delta, delta, weight, context) + await asyncio.sleep(0) + + if avg_delta is None: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) + optimizer_paths = self._outer_optimizer_paths(avg_delta, context) + 
server_delta, active_paths = self._apply_outer_optimizer( + avg_delta, optimizer_paths + ) + logging.info( + "[Server] DiLoCo outer optimizer applied: optimizer=%s " + "outer_lr=%g outer_momentum=%g weighting=%s apply_to=%s " + "eligible_updates=%d optimized_tensors=%d.", + self.outer_optimizer, + self.outer_learning_rate, + self.outer_momentum, + self.aggregation_weighting, + self.apply_outer_optimizer_to, + len(eligible), + len(optimizer_paths), + ) + self._remove_stale_momentum(active_paths) + + return self._match_reference_structure(server_delta, eligible[0][1]) + + @classmethod + def _validate_outer_optimizer(cls, value: str) -> str: + optimizer = str(value).lower() + if optimizer not in cls._SUPPORTED_OPTIMIZERS: + supported = ", ".join(sorted(cls._SUPPORTED_OPTIMIZERS)) + raise ValueError( + f"Invalid outer_optimizer '{value}'. Supported values: {supported}." + ) + return optimizer + + @staticmethod + def _validate_learning_rate(value: float) -> float: + learning_rate = float(value) + if learning_rate < 0: + raise ValueError("outer_learning_rate must be nonnegative.") + return learning_rate + + @staticmethod + def _validate_momentum(value: float) -> float: + momentum = float(value) + if not 0 <= momentum < 1: + raise ValueError("outer_momentum must be in the range [0, 1).") + return momentum + + @classmethod + def _validate_weighting_mode(cls, value: str) -> str: + weighting = str(value).lower() + if weighting not in cls._SUPPORTED_WEIGHTING_MODES: + supported = ", ".join(sorted(cls._SUPPORTED_WEIGHTING_MODES)) + raise ValueError( + "Invalid aggregation_weighting " + f"'{value}'. Supported values: {supported}." + ) + return weighting + + @classmethod + def _validate_apply_policy(cls, value: str) -> str: + policy = str(value).lower() + if policy not in cls._SUPPORTED_APPLY_POLICIES: + supported = ", ".join(sorted(cls._SUPPORTED_APPLY_POLICIES)) + raise ValueError( + "Invalid apply_outer_optimizer_to " + f"'{value}'. Supported values: {supported}." 
+ ) + return policy + + def _eligible_updates( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + ) -> list[tuple[SimpleNamespace, dict, float]]: + eligible: list[tuple[SimpleNamespace, dict, float]] = [] + for update, delta in zip(updates, deltas_received): + if getattr(update.report, "type", "weights") == "features": + continue + + num_samples = self._num_samples(update) + if num_samples <= 0: + continue + + eligible.append((update, delta, num_samples)) + + return eligible + + @staticmethod + def _num_samples(update: SimpleNamespace) -> float: + try: + return float(update.report.num_samples) + except (AttributeError, TypeError, ValueError): + return 0.0 + + def _aggregation_weights( + self, eligible: list[tuple[SimpleNamespace, dict, float]] + ) -> list[float]: + if not eligible: + return [] + + if self.aggregation_weighting == "uniform": + return [1.0 / len(eligible)] * len(eligible) + + total_samples = sum(num_samples for _, _, num_samples in eligible) + if total_samples <= 0: + return [] + + return [num_samples / total_samples for _, _, num_samples in eligible] + + def _outer_optimizer_paths( + self, avg_delta: Any, context: ServerContext + ) -> set[str]: + if self.apply_outer_optimizer_to == "all_floating": + return self._floating_leaf_paths(avg_delta) + + floating_paths = self._floating_leaf_paths(avg_delta) + trainable_parameter_names = self._trainable_parameter_names( + context, floating_paths + ) + return floating_paths.intersection(trainable_parameter_names) + + def _apply_outer_optimizer( + self, avg_delta: Any, optimizer_paths: set[str] + ) -> tuple[Any, set[str]]: + active_paths: set[str] = set() + + server_delta = self._map_tree( + avg_delta, + lambda value, path: self._apply_outer_optimizer_leaf( + value, path, optimizer_paths, active_paths + ), + ) + return server_delta, active_paths + + def _apply_outer_optimizer_leaf( + self, + avg_delta: Any, + path: str, + optimizer_paths: set[str], + active_paths: set[str], + ) -> Any: + if path not in optimizer_paths: + return avg_delta + + outer_gradient = self._scale_tree(avg_delta, -1.0) + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate) + + return self._apply_momentum_leaf(outer_gradient, path, active_paths) + + def _apply_momentum_leaf( + self, outer_gradient: Any, path: str, active_paths: set[str] + ) -> Any: + active_paths.add(path) + previous = self.momentum_state.get(path) + if previous is not None and not self._is_compatible(previous, outer_gradient): + previous = None + + if previous is None: + momentum = self._clone_tree(outer_gradient) + else: + momentum = self._add_values( + self._scale_tree(previous, self.outer_momentum), + outer_gradient, + ) + + self.momentum_state[path] = self._clone_tree(momentum) + + if self.outer_optimizer == "nesterov": + direction = self._add_values( + outer_gradient, + self._scale_tree(momentum, self.outer_momentum), + ) + else: + direction = momentum + + return self._scale_tree(direction, -self.outer_learning_rate) + + def _remove_stale_momentum(self, active_paths: set[str]) -> None: + if self.outer_optimizer == "sgd": + self.momentum_state.clear() + return + + for path in list(self.momentum_state): + if path not in active_paths: + del self.momentum_state[path] + + def _trainable_parameter_names( + self, context: ServerContext, payload_paths: set[str] | None = None + ) -> set[str]: + model = self._model_from_context(context) + adapter_names = self._adapter_names(model) + trainable_names: set[str] = set() + + for 
name, parameter in model.named_parameters(): + if getattr(parameter, "requires_grad", False) and self._is_floating_value( + parameter + ): + trainable_names.update( + self._payload_name_candidates(name, adapter_names, payload_paths) + ) + + return trainable_names + + @staticmethod + def _adapter_names(model: Any) -> set[str]: + adapter_names = {"default"} + + peft_config = getattr(model, "peft_config", None) + if isinstance(peft_config, Mapping): + adapter_names.update(str(name) for name in peft_config) + + active_adapter = getattr(model, "active_adapter", None) + if isinstance(active_adapter, str): + adapter_names.add(active_adapter) + + active_adapters = getattr(model, "active_adapters", None) + if callable(active_adapters): + try: + adapter_names.update(str(name) for name in active_adapters()) + except TypeError: + pass + elif isinstance(active_adapters, (list, tuple, set)): + adapter_names.update(str(name) for name in active_adapters) + + return adapter_names + + @classmethod + def _payload_name_candidates( + cls, + parameter_name: str, + adapter_names: set[str], + payload_paths: set[str] | None, + ) -> set[str]: + candidates = {parameter_name} + if payload_paths is not None and parameter_name in payload_paths: + return candidates + + parts = parameter_name.split(".") + for index, part in enumerate(parts): + if part not in adapter_names: + continue + + candidate = ".".join(parts[:index] + parts[index + 1 :]) + if payload_paths is None or candidate in payload_paths: + candidates.add(candidate) + + return candidates + + @staticmethod + def _model_from_context(context: ServerContext) -> Any: + trainer = getattr(context, "trainer", None) + model = getattr(trainer, "model", None) if trainer is not None else None + if model is None or not hasattr(model, "named_parameters"): + raise AttributeError( + "DiLoCo apply_outer_optimizer_to='parameters' requires " + "context.trainer.model with named_parameters()." 
+ ) + return model + + def _floating_leaf_paths(self, value: Any) -> set[str]: + return self._collect_leaf_paths( + value, lambda leaf, _: self._is_floating_value(leaf) + ) + + def _collect_leaf_paths( + self, + value: Any, + predicate: Callable[[Any, str], bool], + path: str = "", + ) -> set[str]: + if isinstance(value, Mapping): + paths: set[str] = set() + for key, item in value.items(): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, key) + ) + ) + return paths + + if isinstance(value, list): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + if isinstance(value, tuple): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + return {path} if predicate(value, path) else set() + + @staticmethod + def _is_floating_value(value: Any) -> bool: + if torch is not None and isinstance(value, torch.Tensor): + return torch.is_floating_point(value) + + if isinstance(value, np.ndarray): + return np.issubdtype(value.dtype, np.floating) + + return isinstance(value, numbers.Real) and not isinstance( + value, (numbers.Integral, bool) + ) + + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: + zero_delta = self._zero_delta(context, reference_delta) + if zero_delta is not None: + return zero_delta + + if reference_delta is None: + return {} + + return self._scale_tree(reference_delta, 0.0) + + @staticmethod + def _first_delta(deltas_received: list[dict]) -> dict | None: + return deltas_received[0] if deltas_received else None + + def _map_tree(self, value: Any, leaf_fn: Callable[[Any, str], Any], path="") -> Any: + if isinstance(value, Mapping): + return { + key: self._map_tree(item, leaf_fn, self._join_path(path, key)) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ] + + if isinstance(value, tuple): + return tuple( + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ) + + return leaf_fn(value, path) + + def _scale_tree(self, value: Any, scalar: float) -> Any: + if isinstance(value, Mapping): + return {key: self._scale_tree(item, scalar) for key, item in value.items()} + + if isinstance(value, list): + return [self._scale_tree(item, scalar) for item in value] + + if isinstance(value, tuple): + return tuple(self._scale_tree(item, scalar) for item in value) + + return value * scalar + + @staticmethod + def _add_values(left: Any, right: Any) -> Any: + return left + right + + def _clone_tree(self, value: Any) -> Any: + if isinstance(value, Mapping): + return {key: self._clone_tree(item) for key, item in value.items()} + + if isinstance(value, list): + return [self._clone_tree(item) for item in value] + + if isinstance(value, tuple): + return tuple(self._clone_tree(item) for item in value) + + if torch is not None and isinstance(value, torch.Tensor): + return value.detach().clone() + + if isinstance(value, np.ndarray): + return value.copy() + + try: + return copy.deepcopy(value) + except TypeError: + return value + + @staticmethod + def _is_compatible(left: Any, right: Any) -> bool: + if torch is not None and isinstance(left, torch.Tensor): + return ( + isinstance(right, torch.Tensor) + and left.shape == right.shape + and 
left.dtype == right.dtype + ) + + if isinstance(left, np.ndarray): + return ( + isinstance(right, np.ndarray) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + left_shape = getattr(left, "shape", None) + right_shape = getattr(right, "shape", None) + if left_shape is not None or right_shape is not None: + return left_shape == right_shape and getattr( + left, "dtype", None + ) == getattr(right, "dtype", None) + + return isinstance(left, numbers.Number) and isinstance(right, numbers.Number) + + @staticmethod + def _join_path(prefix: str, key: Any) -> str: + key_text = str(key) + return key_text if not prefix else f"{prefix}.{key_text}" diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 1e98d1128..2645b7904 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -168,6 +168,7 @@ def __init__( self.current_epoch = 0 self.training_start_time = time.time() self.model_state_dict = None + self._preserved_optimizer_states: dict[int, dict[str, Any]] = {} def _require_model(self) -> nn.Module: """Return the underlying model, ensuring it is available.""" @@ -177,6 +178,300 @@ def _require_model(self) -> nn.Module: ) return cast(nn.Module, self.model) + @staticmethod + def _local_steps_per_round(config: dict[str, Any]) -> int | None: + """Return the optional local optimizer-step limit for one train run.""" + value = config.get("local_steps_per_round") + if value is None: + return None + + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + raise ValueError( + "trainer.local_steps_per_round must be a positive integer." + ) + + return value + + def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> bool: + """Record one completed optimizer step and report whether H was reached.""" + if local_steps_per_round is None: + return False + + completed_steps = int(self.context.state.get("local_optimizer_steps", 0)) + 1 + self.context.state["local_optimizer_steps"] = completed_steps + return completed_steps >= local_steps_per_round + + @staticmethod + def _preserve_optimizer_state(config: dict[str, Any]) -> bool: + """Return whether optimizer state should survive local train runs.""" + return bool(config.get("preserve_optimizer_state", False)) + + @staticmethod + def _step_lr_scheduler_per_optimizer_step(config: dict[str, Any]) -> bool: + """Return whether LR scheduling should follow optimizer steps.""" + if config.get("local_steps_per_round") is None: + return False + + return getattr(Config().server, "type", None) == "diloco" + + def _step_lr_scheduler_after_optimizer_step( + self, step_lr_per_optimizer_step: bool + ) -> None: + """Advance step-based LR schedules after one completed optimizer step.""" + if step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + + @staticmethod + def _parameter_signature(name: str | None, parameter: torch.Tensor): + """Build a compatibility signature for one model parameter.""" + return (name, tuple(parameter.shape), str(parameter.dtype)) + + @classmethod + def _model_parameter_signature(cls, model: nn.Module): + """Return parameter names, shapes, dtypes, and order for a model.""" + return tuple( + cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + ) + + @classmethod + def _optimizer_parameter_signature( + cls, model: nn.Module, optimizer: torch.optim.Optimizer + ): + """Return optimizer parameter group ordering with model metadata.""" + named_parameters = { + id(parameter): 
cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + } + + group_signatures = [] + for group in optimizer.param_groups: + group_signatures.append( + tuple( + named_parameters.get( + id(parameter), + cls._parameter_signature(None, parameter), + ) + for parameter in group.get("params", []) + ) + ) + + return tuple(group_signatures) + + @staticmethod + def _scheduler_type(scheduler: Any | None) -> type | None: + """Return the scheduler type used for compatibility checks.""" + if scheduler is None: + return None + return type(scheduler) + + def _preserved_state_is_compatible( + self, + payload: dict[str, Any], + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Any | None, + ) -> bool: + """Return whether a cached optimizer bundle matches this train run.""" + if payload.get("optimizer_type") is not type(optimizer): + return False + + if payload.get("scheduler_type") is not self._scheduler_type(scheduler): + return False + + if payload.get("model_parameters") != self._model_parameter_signature(model): + return False + + if payload.get("optimizer_parameters") != self._optimizer_parameter_signature( + model, optimizer + ): + return False + + if not callable(getattr(optimizer, "load_state_dict", None)): + return False + + if payload.get("scheduler_state") is not None and not callable( + getattr(scheduler, "load_state_dict", None) + ): + return False + + return True + + def _restore_preserved_optimizer_state(self) -> None: + """Restore compatible optimizer and scheduler state for this client.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None or self.optimizer is None: + return + + model = self._require_model() + if not self._preserved_state_is_compatible( + payload, model, self.optimizer, self.lr_scheduler + ): + logging.info( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + return + + try: + scheduler_state = payload.get("scheduler_state") + if scheduler_state is not None: + self.lr_scheduler.load_state_dict(copy.deepcopy(scheduler_state)) + + self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) + except Exception as error: + logging.warning( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + error, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + self.optimizer = self.optimizer_strategy.create_optimizer( + model, self.context + ) + self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( + self.optimizer, self.context + ) + + def _save_preserved_optimizer_state(self) -> None: + """Save optimizer and scheduler state locally for this logical client.""" + if self.optimizer is None: + return + + model = self._require_model() + scheduler_state = None + if self.lr_scheduler is not None: + state_dict_fn = getattr(self.lr_scheduler, "state_dict", None) + if callable(state_dict_fn): + scheduler_state = copy.deepcopy(state_dict_fn()) + + self._preserved_optimizer_states[self.client_id] = { + "optimizer_type": type(self.optimizer), + "optimizer_state": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_type": self._scheduler_type(self.lr_scheduler), + "scheduler_state": scheduler_state, + "model_parameters": self._model_parameter_signature(model), + "optimizer_parameters": self._optimizer_parameter_signature( + model, 
self.optimizer + ), + } + + def _optimizer_state_filename(self, run_id: str) -> str: + """Return the local optimizer-state handoff filename.""" + model_name = Config().trainer.model_name + return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + + def _optimizer_state_output_filename(self, run_id: str) -> str: + """Return a unique subprocess optimizer-state output filename.""" + model_name = Config().trainer.model_name + token = time.time_ns() + return f"{model_name}_{self.client_id}_{run_id}_{os.getpid()}_{token}.optim.pkl" + + def _optimizer_state_path(self, filename: str) -> str: + """Return the local optimizer-state handoff path.""" + return os.path.join(Config().params["model_path"], filename) + + def _save_preserved_optimizer_state_file(self, filename: str) -> bool: + """Persist preserved optimizer state for subprocess handoff.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None: + return False + + model_path = Config().params["model_path"] + os.makedirs(model_path, exist_ok=True) + state_path = self._optimizer_state_path(filename) + tmp_path = f"{state_path}.{os.getpid()}.tmp" + + try: + with open(tmp_path, "wb") as state_file: + pickle.dump(copy.deepcopy(payload), state_file) + os.replace(tmp_path, state_path) + return True + except Exception as error: + if os.path.exists(tmp_path): + os.remove(tmp_path) + logging.warning( + "[Client #%d] Failed to persist optimizer state to %s: %s", + self.client_id, + state_path, + error, + ) + return False + + def _load_preserved_optimizer_state_file( + self, filename: str, *, clear_on_missing: bool = False + ) -> bool: + """Load preserved optimizer state from a subprocess handoff file.""" + state_path = self._optimizer_state_path(filename) + if not os.path.exists(state_path): + if clear_on_missing: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.info( + "[Client #%d] No persisted optimizer state found at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return False + + try: + with open(state_path, "rb") as state_file: + payload = pickle.load(state_file) + except Exception as error: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding unreadable optimizer state at %s; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + state_path, + error, + ) + return False + + if not isinstance(payload, dict): + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding invalid optimizer state at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return False + + self._preserved_optimizer_states[self.client_id] = payload + return True + + def _remove_preserved_optimizer_state_file(self, filename: str) -> None: + """Remove a local optimizer-state sidecar if it exists.""" + state_path = self._optimizer_state_path(filename) + try: + os.remove(state_path) + except FileNotFoundError: + return + except OSError as error: + logging.warning( + "[Client #%d] Failed to remove optimizer state at %s: %s", + self.client_id, + state_path, + error, + ) + + def _finish_subprocess_optimizer_state( + self, input_filename: str, output_filename: str + ) -> None: + """Load the child output sidecar and promote it for the next round.""" + loaded = self._load_preserved_optimizer_state_file( + output_filename, clear_on_missing=True + ) + if loaded: + 
self._save_preserved_optimizer_state_file(input_filename) + self._remove_preserved_optimizer_state_file(output_filename) + else: + self._remove_preserved_optimizer_state_file(input_filename) + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -384,8 +679,25 @@ def simulate_sleep_time(self): def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" + preserve_optimizer_state = self._preserve_optimizer_state(config) + if preserve_optimizer_state: + optimizer_state_filename = config.get( + "_optimizer_state_input_filename", + self._optimizer_state_filename(config["run_id"]), + ) + optimizer_state_output_filename = config.get( + "_optimizer_state_output_filename", + optimizer_state_filename, + ) + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) + self.train_model(config, trainset, sampler, **kwargs) + if preserve_optimizer_state: + self._save_preserved_optimizer_state_file(optimizer_state_output_filename) + model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" self.save_model(filename) @@ -397,6 +709,16 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + preserve_optimizer_state = self._preserve_optimizer_state(config) + if not preserve_optimizer_state: + self._preserved_optimizer_states.pop(self.client_id, None) + + local_steps_per_round = self._local_steps_per_round(config) + self.context.state["local_optimizer_steps"] = 0 + if local_steps_per_round is None: + self.context.state.pop("local_steps_per_round", None) + else: + self.context.state["local_steps_per_round"] = local_steps_per_round # Ensure training step strategy respects higher-order gradient settings if self.training_step_strategy is not None: @@ -476,24 +798,28 @@ def train_model(self, config, trainset, sampler, **kwargs): self.context.state["grad_accum_loss_total"] = 0.0 self.context.state["grad_accum_loss_count"] = 0 - # Create optimizer using strategy + # Move the model before optimizer state restore so PyTorch maps restored + # state tensors onto the same device as the optimizer parameters. 
model = self._require_model() + model.to(self.device) + model.train() + + # Create optimizer using strategy self.optimizer = self.optimizer_strategy.create_optimizer(model, self.context) # Create LR scheduler using strategy self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( self.optimizer, self.context ) - - # Move model to device - model = self._require_model() - model.to(self.device) - model.train() + if preserve_optimizer_state: + self._restore_preserved_optimizer_state() # Training epochs total_epochs = config["epochs"] + step_lr_per_optimizer_step = self._step_lr_scheduler_per_optimizer_step(config) tic = time.perf_counter() training_stop_requested = False + local_step_limit_reached = False try: total_batches = len(self.train_loader) except (TypeError, AttributeError): @@ -551,7 +877,16 @@ def compute_loss(outputs, labels_inner): ) # Track loss - self._loss_tracker.update(loss, labels.size(0)) + if labels is not None: + batch_size = labels.size(0) + else: + first_val = ( + next(iter(examples.values())) + if hasattr(examples, "values") + else examples + ) + batch_size = first_val.size(0) if hasattr(first_val, "size") else 1 + self._loss_tracker.update(loss, batch_size) # Store last loss in context self.context.state["last_loss"] = loss.item() @@ -564,6 +899,12 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) # Strategy hook: after_step self.model_update_strategy.after_step(self.context) @@ -591,7 +932,7 @@ def compute_loss(outputs, labels_inner): ): self._handle_control_log() - if control_actions.get("stop_training"): + if control_actions.get("stop_training") or local_step_limit_reached: training_stop_requested = True break @@ -601,7 +942,11 @@ def compute_loss(outputs, labels_inner): finalize_loss = None finalize_step_done = False finalize_callable = getattr(self.training_step_strategy, "finalize", None) - if batches_seen and callable(finalize_callable): + if ( + batches_seen + and callable(finalize_callable) + and not local_step_limit_reached + ): finalize_loss = finalize_callable( model=model, optimizer=self.optimizer, @@ -613,6 +958,10 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + self._step_lr_scheduler_after_optimizer_step(step_lr_per_optimizer_step) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) self.model_update_strategy.after_step(self.context) self.callback_handler.call_event( "on_train_step_end", @@ -652,11 +1001,15 @@ def compute_loss(outputs, labels_inner): # No batches remain, but respect control flag. 
pass + if local_step_limit_reached: + training_stop_requested = True + self.context.state.pop("is_last_batch", None) self.context.state.pop("hf_optimizer_step_index", None) # LR scheduler step - self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + if not step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) # Handle optimizer params state update if needed if hasattr(self.optimizer, "params_state_update"): @@ -701,6 +1054,9 @@ def compute_loss(outputs, labels_inner): # Callbacks: train run end self.callback_handler.call_event("on_train_run_end", self, config) + if preserve_optimizer_state: + self._save_preserved_optimizer_state() + def train(self, trainset, sampler, **kwargs) -> float: """ The main training loop in a federated learning workload. @@ -721,6 +1077,21 @@ def train(self, trainset, sampler, **kwargs) -> float: if "max_concurrency" in config: tic = time.perf_counter() + preserve_optimizer_state = self._preserve_optimizer_state(config) + optimizer_state_filename = None + optimizer_state_output_filename = None + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename( + config["run_id"] + ) + optimizer_state_output_filename = self._optimizer_state_output_filename( + config["run_id"] + ) + config = { + **config, + "_optimizer_state_input_filename": optimizer_state_filename, + "_optimizer_state_output_filename": optimizer_state_output_filename, + } if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -773,6 +1144,14 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error + if ( + optimizer_state_filename is not None + and optimizer_state_output_filename is not None + ): + self._finish_subprocess_optimizer_state( + optimizer_state_filename, optimizer_state_output_filename + ) + toc = time.perf_counter() self.pause_training() else: diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index 734b857b2..8d4c7e592 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -38,6 +38,7 @@ TrainingContext, TrainingStepStrategy, ) +from plato.utils.timeseries_utils import is_timeseries_model class HuggingFaceBatch(dict): @@ -124,7 +125,10 @@ def __call__( if not example_list: raise ValueError("HuggingFace collator received an empty batch.") - feature_rows = [{k: v for k, v in example.items() if k != "labels"} for example in example_list] + feature_rows = [ + {k: v for k, v in example.items() if k != "labels"} + for example in example_list + ] padding_side = getattr(self.tokenizer, "padding_side", "right") batch = self.tokenizer.pad( @@ -161,6 +165,82 @@ def __call__( return batch, labels +class TimeSeriesCollateWrapper: + """Collate function for time-series datasets that return tensor dicts. + + Stacks per-sample dicts (e.g. ``{"past_values": ..., "future_values": ...}``) + into a batched ``HuggingFaceBatch``. Labels are always ``None`` because the + model computes its own loss from ``future_values``. 
+ """ + + def __call__(self, examples: Iterable[dict]) -> tuple[HuggingFaceBatch, None]: + example_list = list(examples) + if not example_list: + raise ValueError("TimeSeriesCollateWrapper received an empty batch.") + + keys = example_list[0].keys() + batch = HuggingFaceBatch( + { + key: torch.stack([torch.as_tensor(ex[key]) for ex in example_list]) + for key in keys + } + ) + return batch, None + + +class TimeSeriesTestingStrategy(TestingStrategy): + """Evaluates time-series models and reports mean MSE loss.""" + + metric_name = "mse" + + def __init__(self, collate_fn: TimeSeriesCollateWrapper): + self.collate_fn = collate_fn + + def test_model(self, model, config, testset, sampler, context: TrainingContext): + batch_size = config.get("batch_size", 1) + + if sampler is not None: + if isinstance(sampler, torch.utils.data.Sampler): + sampler_obj = sampler + elif isinstance(sampler, (list, range)): + sampler_obj = torch.utils.data.SubsetRandomSampler(sampler) + elif hasattr(sampler, "get"): + sampler_obj = sampler.get() + else: + sampler_obj = sampler + else: + sampler_obj = None + + data_loader = torch.utils.data.DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + sampler=sampler_obj, + collate_fn=self.collate_fn, + ) + + model.to(context.device) + model.eval() + + total_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): + for batch_inputs, _ in data_loader: + batch_inputs = batch_inputs.to(context.device) + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + loss = _resolve_hf_loss(outputs, labels=None) + total_loss += loss.item() + num_batches += 1 + + model.train() + + if num_batches == 0: + return float("inf") + return total_loss / num_batches + + def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): """ Resolve a loss tensor from HuggingFace model outputs. 
@@ -520,45 +600,7 @@ def __init__(self, model=None, callbacks=None): self.training_args = cast(TrainingArguments, training_args) model_name = Config().trainer.model_name - tokenizer_name = getattr(Config().trainer, "tokenizer_name", model_name) - if not isinstance(tokenizer_name, str) or not tokenizer_name: - tokenizer_name = model_name - - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } - self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) - - cache_dir = Config().params["data_path"] - use_fast_tokenizer = True - revision = "main" - auth_token = getattr( - getattr(Config(), "parameters", None), "huggingface_token", None - ) - - tokenizer_loader: Any = ( - LlamaTokenizer if "llama" in tokenizer_name else AutoTokenizer - ) - tokenizer_kwargs: dict[str, Any] = { - "config": self.config, - "cache_dir": cache_dir, - "use_fast": use_fast_tokenizer, - "revision": revision, - } - if isinstance(auth_token, str) and auth_token: - tokenizer_kwargs["use_auth_token"] = auth_token - self.tokenizer: Any = tokenizer_loader.from_pretrained( - tokenizer_name, - **tokenizer_kwargs, - ) - - tokenizer = cast(Any, self.tokenizer) - if getattr(tokenizer, "pad_token_id", None) is None: - eos_token = getattr(tokenizer, "eos_token", None) - if eos_token is not None: - tokenizer.pad_token = eos_token + model_type = getattr(Config().trainer, "model_type", "") grad_accum_steps = getattr(Config().trainer, "gradient_accumulation_steps", 1) try: @@ -566,7 +608,59 @@ def __init__(self, model=None, callbacks=None): except (TypeError, ValueError): grad_accum_steps = 1 self._gradient_accumulation_steps = max(grad_accum_steps, 1) - self._collate_wrapper = HuggingFaceCollateWrapper(tokenizer) + + if is_timeseries_model(model_name=model_name, model_type=model_type): + # Time-series models have no tokenizer. Use a simple tensor-stacking + # collator and return raw MSE from the testing strategy. 
+ self.tokenizer = None + self.config = None + ts_collate = TimeSeriesCollateWrapper() + self._collate_wrapper = ts_collate + testing_strategy: TestingStrategy = TimeSeriesTestingStrategy(ts_collate) + else: + tokenizer_name = getattr(Config().trainer, "tokenizer_name", model_name) + if not isinstance(tokenizer_name, str) or not tokenizer_name: + tokenizer_name = model_name + + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } + self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + + cache_dir = Config().params["data_path"] + use_fast_tokenizer = True + revision = "main" + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) + + tokenizer_loader: Any = ( + LlamaTokenizer if "llama" in tokenizer_name else AutoTokenizer + ) + tokenizer_kwargs: dict[str, Any] = { + "config": self.config, + "cache_dir": cache_dir, + "use_fast": use_fast_tokenizer, + "revision": revision, + } + if isinstance(auth_token, str) and auth_token: + tokenizer_kwargs["use_auth_token"] = auth_token + self.tokenizer: Any = tokenizer_loader.from_pretrained( + tokenizer_name, + **tokenizer_kwargs, + ) + + tokenizer = cast(Any, self.tokenizer) + if getattr(tokenizer, "pad_token_id", None) is None: + eos_token = getattr(tokenizer, "eos_token", None) + if eos_token is not None: + tokenizer.pad_token = eos_token + + self._collate_wrapper = HuggingFaceCollateWrapper(tokenizer) + testing_strategy = HuggingFaceTestingStrategy(self._collate_wrapper) + self.training_args.gradient_accumulation_steps = ( self._gradient_accumulation_steps ) @@ -593,7 +687,7 @@ def __init__(self, model=None, callbacks=None): num_workers=0, pin_memory=True, ), - testing_strategy=HuggingFaceTestingStrategy(self._collate_wrapper), + testing_strategy=testing_strategy, ) if hf_callbacks: @@ -603,23 +697,27 @@ def __init__(self, model=None, callbacks=None): if hasattr(model_instance, "loss_type"): setattr(model_instance, "loss_type", "ForCausalLM") - tokenizer_vocab_size = None - if hasattr(self.tokenizer, "__len__"): - try: - tokenizer_vocab_size = len(self.tokenizer) - except TypeError: - tokenizer_vocab_size = None - embedding_getter = getattr(model_instance, "get_input_embeddings", None) - embedding_resizer = getattr(model_instance, "resize_token_embeddings", None) - if ( - tokenizer_vocab_size is not None - and callable(embedding_getter) - and callable(embedding_resizer) - ): - embeddings = embedding_getter() - embedding_size = getattr(embeddings, "num_embeddings", None) - if embedding_size is not None and embedding_size != tokenizer_vocab_size: - embedding_resizer(tokenizer_vocab_size) + if self.tokenizer is not None: + tokenizer_vocab_size = None + if hasattr(self.tokenizer, "__len__"): + try: + tokenizer_vocab_size = len(self.tokenizer) + except TypeError: + tokenizer_vocab_size = None + embedding_getter = getattr(model_instance, "get_input_embeddings", None) + embedding_resizer = getattr(model_instance, "resize_token_embeddings", None) + if ( + tokenizer_vocab_size is not None + and callable(embedding_getter) + and callable(embedding_resizer) + ): + embeddings = embedding_getter() + embedding_size = getattr(embeddings, "num_embeddings", None) + if ( + embedding_size is not None + and embedding_size != tokenizer_vocab_size + ): + embedding_resizer(tokenizer_vocab_size) if self.training_args.gradient_checkpointing: model_config = getattr(model_instance, "config", None) diff --git a/plato/trainers/lerobot.py b/plato/trainers/lerobot.py index 
051733e47..6c607d3b8 100644 --- a/plato/trainers/lerobot.py +++ b/plato/trainers/lerobot.py @@ -344,9 +344,7 @@ def _resolve_runtime_device(device_value: Any, fallback_device: Any) -> torch.de try: gpu_index = int(normalized.split(":", 1)[1]) except (IndexError, ValueError) as exc: - raise ValueError( - f"Invalid CUDA device value: '{device_value}'." - ) from exc + raise ValueError(f"Invalid CUDA device value: '{device_value}'.") from exc if gpu_index < 0 or gpu_index >= torch.cuda.device_count(): raise RuntimeError( f"`parameters.policy.device` requested CUDA device {gpu_index}, " @@ -466,9 +464,7 @@ def training_step( ) if not torch.is_tensor(loss): - raise TypeError( - "LeRobot policy forward did not return a tensor loss." - ) + raise TypeError("LeRobot policy forward did not return a tensor loss.") loss.backward() optimizer.step() diff --git a/plato/trainers/lr_schedulers.py b/plato/trainers/lr_schedulers.py index edf87bd9b..fe2b77fef 100644 --- a/plato/trainers/lr_schedulers.py +++ b/plato/trainers/lr_schedulers.py @@ -104,8 +104,9 @@ def get(optimizer: optim.Optimizer, iterations_per_epoch: int, **kwargs: str | d for x in lr_params["milestone_steps"].split(",") ] lambdas.append( - lambda it, milestones=milestones: lr_params["gamma"] - ** bisect.bisect(milestones, it) + lambda it, milestones=milestones: ( + lr_params["gamma"] ** bisect.bisect(milestones, it) + ) ) # Add a linear learning rate warmup if specified diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 91e5a0482..c934e48b4 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -14,12 +14,26 @@ import torch import torch.utils.data +from plato.config import Config from plato.trainers.strategies.base import DataLoaderStrategy, TrainingContext CollateFn = Callable[[list[Any]], Any] AdjustFn = Callable[[TrainingContext], int] +class _FixedOrderSampler(torch.utils.data.Sampler): + """Sampler that yields precomputed dataset indices in order.""" + + def __init__(self, indices: list[int]): + self._indices = indices + + def __iter__(self): + return iter(self._indices) + + def __len__(self): + return len(self._indices) + + def _context_uses_cuda(context: TrainingContext) -> bool: """Return True if the training context targets a CUDA device.""" device = getattr(context, "device", None) @@ -40,6 +54,77 @@ def _resolve_pin_memory(setting: bool | None, context: TrainingContext) -> bool: return _context_uses_cuda(context) +def _local_step_stream_start( + context: TrainingContext, samples_per_round: int, stream_length: int +) -> int: + """Return the deterministic stream offset for this local-step round.""" + current_round = int(getattr(context, "current_round", 0) or 0) + if current_round > 0: + return ((current_round - 1) * samples_per_round) % stream_length + + offset = int(context.state.get("_local_step_sampler_stream_offset", 0)) + context.state["_local_step_sampler_stream_offset"] = offset + samples_per_round + return offset % stream_length + + +def _enforce_diloco_full_participation_for_local_steps() -> None: + """Require DiLoCo workers to train once per outer synchronization.""" + server_type = getattr(Config().server, "type", None) + if server_type != "diloco": + return + + total_clients = int(Config().clients.total_clients) + clients_per_round = int(Config().clients.per_round) + if clients_per_round == total_clients: + return + + raise ValueError( + "DiLoCo local-step data loading requires clients.per_round to equal " + 
"clients.total_clients so every worker advances its local data stream " + "once per outer round." + ) + + +def _apply_local_step_sampling_stream( + sampler_obj, batch_size: int, context: TrainingContext +): + """Advance deterministic samplers across short local-step rounds.""" + local_steps_per_round = context.state.get("local_steps_per_round") + if local_steps_per_round is None: + return sampler_obj + + _enforce_diloco_full_participation_for_local_steps() + + if sampler_obj is None: + return sampler_obj + + samples_per_round = int(local_steps_per_round) * int(batch_size) + if samples_per_round <= 0: + return sampler_obj + + try: + indices = list(iter(sampler_obj)) + except (TypeError, NotImplementedError): + logging.warning( + "Sampler %s cannot be materialized for round-aware local-step " + "sampling; using it unchanged. Consecutive short local rounds may " + "replay the same sampler prefix.", + type(sampler_obj), + ) + return sampler_obj + + if len(indices) == 0: + return sampler_obj + + start = _local_step_stream_start(context, samples_per_round, len(indices)) + if start == 0: + ordered_indices = indices + else: + ordered_indices = indices[start:] + indices[:start] + + return _FixedOrderSampler(ordered_indices) + + class DefaultDataLoaderStrategy(DataLoaderStrategy): """ Default data loader strategy. @@ -100,6 +185,10 @@ def create_train_loader( sampler_obj = None shuffle = self.shuffle + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + if sampler is None and not shuffle: logging.warning( "Data loader strategy received no sampler; falling back to SequentialSampler." @@ -174,6 +263,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -239,6 +332,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -320,6 +417,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, actual_batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=actual_batch_size, @@ -383,6 +484,10 @@ def create_train_loader( sampler_obj = None shuffle = True + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, diff --git a/plato/trainers/strategies/training_step.py b/plato/trainers/strategies/training_step.py index b4aba6d9d..5afa9702c 100644 --- a/plato/trainers/strategies/training_step.py +++ b/plato/trainers/strategies/training_step.py @@ -128,6 +128,9 @@ def training_step( if self.current_step % self.accumulation_steps == 0: optimizer.step() optimizer.zero_grad() + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False # Return unscaled loss for logging return loss diff --git a/plato/utils/timeseries_utils.py b/plato/utils/timeseries_utils.py new file mode 100644 index 000000000..0593695ff --- /dev/null +++ b/plato/utils/timeseries_utils.py @@ -0,0 +1,48 @@ +""" +Utility functions for time series model detection and handling. +""" + +from typing import Optional + +# Single source of truth: all known HuggingFace time series model types. 
+# When adding a new time series model, register it here AND add a loader to +# plato/models/huggingface.py (_TIMESERIES_LOADERS). +TIMESERIES_MODEL_TYPES: frozenset[str] = frozenset({"timesfm", "patchtsmixer"}) + + +def is_timeseries_model( + model_name: Optional[str] = None, + model_type: Optional[str] = None, + dataset_type: Optional[str] = None, +) -> bool: + """ + Check if a model/dataset is for time series. + + Args: + model_name: Name of the model + model_type: Type of model from config + dataset_type: Type of dataset from config + + Returns: + True if this is a time series model, False otherwise + """ + model_name_lower = (model_name or "").lower() + model_type_lower = (model_type or "").lower() + + # Check explicit model type + if model_type_lower in TIMESERIES_MODEL_TYPES: + return True + + # Check if any known time series type appears in the model name + if any(ts_type in model_name_lower for ts_type in TIMESERIES_MODEL_TYPES): + return True + + # Generic "timeseries" keyword in name + if "timeseries" in model_name_lower: + return True + + # Check dataset type + if dataset_type and dataset_type.lower() == "timeseries": + return True + + return False diff --git a/plato/utils/tree.py b/plato/utils/tree.py index ed7c2b647..0f1a13c54 100644 --- a/plato/utils/tree.py +++ b/plato/utils/tree.py @@ -66,7 +66,10 @@ def _ensure_numpy(value: Any) -> np.ndarray: if callable(cpu_fn): tensor = cpu_fn() torch_bfloat16 = getattr(torch, "bfloat16", None) if torch is not None else None - if torch_bfloat16 is not None and getattr(tensor, "dtype", None) == torch_bfloat16: + if ( + torch_bfloat16 is not None + and getattr(tensor, "dtype", None) == torch_bfloat16 + ): tensor = tensor.to(torch.float32) numpy_fn = getattr(tensor, "numpy", None) if callable(numpy_fn): diff --git a/tests/algorithms/test_fedavg_algorithm.py b/tests/algorithms/test_fedavg_algorithm.py index 1ccdd8399..69657036e 100644 --- a/tests/algorithms/test_fedavg_algorithm.py +++ b/tests/algorithms/test_fedavg_algorithm.py @@ -39,9 +39,7 @@ class BFloat16ToyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.weight = torch.nn.Parameter( - torch.ones((2, 2), dtype=torch.bfloat16) - ) + self.weight = torch.nn.Parameter(torch.ones((2, 2), dtype=torch.bfloat16)) def _algorithm_for(model: torch.nn.Module) -> FedAvgAlgorithm: @@ -114,9 +112,7 @@ def test_extract_weights_casts_bfloat16_payloads_for_transport(): payload = algorithm.extract_weights() assert payload["weight"].dtype == torch.float32 - inbound = OrderedDict( - {"weight": torch.full((2, 2), 3.5, dtype=torch.float32)} - ) + inbound = OrderedDict({"weight": torch.full((2, 2), 3.5, dtype=torch.float32)}) algorithm.load_weights(inbound) state = model.state_dict() diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py index 6676eda37..518735d6e 100644 --- a/tests/clients/test_feddf_strategy.py +++ b/tests/clients/test_feddf_strategy.py @@ -15,9 +15,7 @@ from tests.test_utils.fakes import FakeModel _TESTS_ROOT = Path(__file__).resolve().parent -_FEDDF_DIR = ( - _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" -) +_FEDDF_DIR = _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" if str(_FEDDF_DIR) not in sys.path: sys.path.insert(0, str(_FEDDF_DIR)) @@ -44,7 +42,9 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): context = SimpleNamespace( client_id=1, current_round=1, - algorithm=SimpleNamespace(load_weights=lambda weights: loaded_weights.append(weights)), + 
algorithm=SimpleNamespace( + load_weights=lambda weights: loaded_weights.append(weights) + ), trainer=SimpleNamespace(model=FakeModel(), device="cpu"), state={}, ) @@ -58,14 +58,17 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): mock_report = SimpleNamespace(num_samples=8) async_mock = AsyncMock(return_value=(mock_report, {"weights": torch.ones(1)})) - with patch.object( - feddf_client.DefaultTrainingStrategy, - "train", - new=async_mock, - ) as mock_train, patch.object( - feddf_client.time, - "perf_counter", - side_effect=[10.0, 10.25], + with ( + patch.object( + feddf_client.DefaultTrainingStrategy, + "train", + new=async_mock, + ) as mock_train, + patch.object( + feddf_client.time, + "perf_counter", + side_effect=[10.0, 10.25], + ), ): report, payload = asyncio.run(strategy.train(context)) diff --git a/tests/clients/test_simple_client.py b/tests/clients/test_simple_client.py index 43ffbb907..db3d7c838 100644 --- a/tests/clients/test_simple_client.py +++ b/tests/clients/test_simple_client.py @@ -1,7 +1,10 @@ """End-to-end smoke tests for the strategy-based client runtime.""" import asyncio +import pickle +import sys from dataclasses import dataclass +from pathlib import Path import torch from torch.utils.data import Dataset @@ -10,6 +13,20 @@ from plato.clients import simple from plato.config import Config from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import AdamWOptimizerStrategy, StepLRSchedulerStrategy +from tests.test_utils.fakes import NoOpCommunicationStrategy + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} class ToyDataset(Dataset): @@ -48,37 +65,154 @@ def get_test_set(self): return self._test -def _build_client(): +def _build_client(trainer=ComposableTrainer): """Instantiate a client wired with custom model, datasource, and trainer.""" return simple.Client( model=torch.nn.Linear(4, 2), datasource=ToyDatasource, - trainer=ComposableTrainer, + trainer=trainer, algorithm=lambda trainer: fedavg.Algorithm(trainer), ) -def test_simple_client_trains_with_default_strategies(temp_config): - """A simple client should complete one training round using the strategy stack.""" - Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) +def _build_stateful_trainer(model=None, callbacks=None): + """Build a trainer whose local optimizer and scheduler state is non-empty.""" + return ComposableTrainer( + model=model, + callbacks=callbacks, + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) - client = _build_client() - # Assign identifiers expected by the client runtime. +def _configure_one_round_client(client): + """Prepare a client for a deterministic single training round.""" client.client_id = 1 client._context.client_id = 1 client.current_round = 1 client._context.current_round = 1 - # Prepare data and runtime components. 
client._load_data() client.configure() client._allocate_data() + +def _disable_payload_processors(client): + """Keep the test focused on decoded client-server model payload contents.""" + client.inbound_processor = None + client.outbound_processor = None + client._context.inbound_processor = None + client._context.outbound_processor = None + + +def _assert_model_weight_payload(payload, model): + """Assert that an outbound payload contains exactly model state tensors.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + for name, expected in model_state.items(): + assert torch.equal(payload[name], expected) + + +def _assert_preserved_state_is_local(trainer, client_id): + """Assert optimizer and scheduler persistence exists only in trainer state.""" + state = trainer._preserved_optimizer_states[client_id] + + assert state["optimizer_state"]["state"] + assert state["scheduler_state"] is not None + assert state["scheduler_state"]["last_epoch"] >= 1 + assert state["scheduler_state"]["_step_count"] >= 2 + + +def test_simple_client_trains_with_default_strategies(temp_config): + """A simple client should complete one training round using the strategy stack.""" + Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) + + client = _build_client() + + _configure_one_round_client(client) + report, payload = asyncio.run(client._train()) assert report.client_id == 1 # With partition_size=4 each client receives four samples. assert report.num_samples == 4 - assert isinstance(payload, dict) - assert all(isinstance(value, torch.Tensor) for value in payload.values()) + _assert_model_weight_payload(payload, client.trainer.model) + + +def test_simple_client_payload_excludes_local_state_when_persistence_enabled( + temp_config, +): + """FedAvg/DiLoCo client payloads stay model-only with local persistence.""" + Config.params["run_id"] = "client-payload-in-process" + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + _assert_preserved_state_is_local(client.trainer, client.client_id) + _assert_model_weight_payload(sent_payload, client.trainer.model) + + +def test_simple_client_subprocess_payload_excludes_local_state_sidecar( + temp_config, monkeypatch, tmp_path +): + """Subprocess persistence uses a sidecar without changing server payloads.""" + model_path = Path(tmp_path) / "models" / "pretrained" + checkpoint_path = Path(tmp_path) / "checkpoints" + model_path.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(checkpoint_path) + Config.params["run_id"] = "client-payload-subprocess" + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + 
Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + max_concurrency=1, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + state_path = Path( + Config.params["model_path"] + ) / client.trainer._optimizer_state_filename(Config.params["run_id"]) + with state_path.open("rb") as state_file: + sidecar_state = pickle.load(state_file) + + _assert_preserved_state_is_local(client.trainer, client.client_id) + assert sidecar_state["optimizer_state"]["state"] + assert sidecar_state["scheduler_state"] is not None + _assert_model_weight_payload(sent_payload, client.trainer.model) diff --git a/tests/datasources/test_huggingface_datasource.py b/tests/datasources/test_huggingface_datasource.py index 1937c34b6..532ee126e 100644 --- a/tests/datasources/test_huggingface_datasource.py +++ b/tests/datasources/test_huggingface_datasource.py @@ -30,8 +30,7 @@ def apply_chat_template( ): if not tokenize: return "".join( - f"<{message['role']}>{message['content']}|" - for message in messages + f"<{message['role']}>{message['content']}|" for message in messages ) tokens = [] @@ -125,8 +124,12 @@ def test_huggingface_datasource_keeps_validation_split_for_corpus_mode( } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, @@ -172,8 +175,12 @@ def test_huggingface_datasource_falls_back_to_test_split(temp_config, monkeypatc } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, @@ -218,9 +225,7 @@ def test_huggingface_datasource_loads_legacy_cache_path_when_present( } ) - legacy_path = ( - f"{Config().params['data_path']}/{cfg.data.dataset_name}_{cfg.data.dataset_config}" - ) + legacy_path = f"{Config().params['data_path']}/{cfg.data.dataset_name}_{cfg.data.dataset_config}" loaded_paths: list[str] = [] monkeypatch.setattr( @@ -260,7 +265,6 @@ class LargeContextDummyTokenizer(DummyTokenizer): model_max_length = 4096 - def test_huggingface_corpus_mode_keeps_legacy_default_block_size( temp_config, monkeypatch ): @@ -283,8 +287,12 @@ def 
test_huggingface_corpus_mode_keeps_legacy_default_block_size( } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, diff --git a/tests/evaluators/test_lighteval.py b/tests/evaluators/test_lighteval.py index 3045d711f..702600520 100644 --- a/tests/evaluators/test_lighteval.py +++ b/tests/evaluators/test_lighteval.py @@ -42,7 +42,13 @@ def test_lighteval_fast_preset_contains_expected_tasks(temp_config): preset = _resolve_preset("smollm_round_fast") - assert preset["tasks"] == ["ifeval", "hellaswag", "arc_easy", "arc_challenge", "piqa"] + assert preset["tasks"] == [ + "ifeval", + "hellaswag", + "arc_easy", + "arc_challenge", + "piqa", + ] assert preset["primary_metric"] == "ifeval_avg" @@ -72,7 +78,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): calls["launcher_type"] = launcher_type calls["custom_tasks_directory"] = custom_tasks_directory calls["max_samples"] = max_samples @@ -228,7 +236,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory, max_samples class FakeEvaluationTracker: @@ -335,7 +345,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory, max_samples class FakeEvaluationTracker: @@ -425,7 +437,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory captured["max_samples"] = max_samples @@ -605,9 +619,7 @@ def _mock_pipeline(**kwargs): result = LightevalEvaluator( {"type": "lighteval", "preset": "smollm_round_fast"} - ).evaluate( - EvaluationInput(model=SaveableArtifact(), tokenizer=SaveableArtifact()) - ) + ).evaluate(EvaluationInput(model=SaveableArtifact(), tokenizer=SaveableArtifact())) assert result.metrics["ifeval_avg"] == pytest.approx(0.40) assert captured["model_name"] diff --git a/tests/evaluators/test_registry_runner.py b/tests/evaluators/test_registry_runner.py index 193c0ee6d..b4be4e9f9 100644 --- a/tests/evaluators/test_registry_runner.py +++ b/tests/evaluators/test_registry_runner.py @@ -109,7 +109,9 @@ def test_composable_trainer_runs_registered_evaluator_and_stores_results(temp_co testing_strategy=ConstantTestingStrategy(0.5), ) - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = 
trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert trainer.accuracy == 0.5 @@ -257,7 +259,9 @@ def test_composable_trainer_tolerates_evaluator_runtime_failure_by_default( testing_strategy=ConstantTestingStrategy(0.5), ) - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert trainer.accuracy == 0.5 @@ -306,7 +310,9 @@ def test_composable_trainer_restores_grad_mode_after_evaluator_side_effect( ) assert torch.is_grad_enabled() is True - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert torch.is_grad_enabled() is True diff --git a/tests/integration/test_huggingface_smollm_smoke.py b/tests/integration/test_huggingface_smollm_smoke.py index 4ac9320c3..b1de91f89 100644 --- a/tests/integration/test_huggingface_smollm_smoke.py +++ b/tests/integration/test_huggingface_smollm_smoke.py @@ -93,7 +93,9 @@ def gradient_checkpointing_enable(self): def test_smollm_smoltalk_config_smoke(monkeypatch, tmp_path): """Smoke test the SmolLM2 + smol-smoltalk config with mocked HF/Lighteval hooks.""" repo_root = Path(__file__).resolve().parents[2] - config_path = repo_root / "configs/HuggingFace/fedavg_smol_smoltalk_smollm2_135m.toml" + config_path = ( + repo_root / "configs/HuggingFace/fedavg_smol_smoltalk_smollm2_135m.toml" + ) assert config_path.exists() dataset = DatasetDict( diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 6dbc1fa08..ea99bb401 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -24,13 +24,14 @@ class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" def __init__(self, train_size: int = 4, test_size: int = 2): + generator = torch.Generator().manual_seed(13) self._train = TensorDataset( - torch.randn(train_size, 1, 28, 28), - torch.randint(0, 10, (train_size,)), + torch.randn(train_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (train_size,), generator=generator), ) self._test = TensorDataset( - torch.randn(test_size, 1, 28, 28), - torch.randint(0, 10, (test_size,)), + torch.randn(test_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (test_size,), generator=generator), ) def num_train_examples(self): diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3cb4a907b..4ff610756 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -107,6 +107,36 @@ def configure_environment(config_dict: dict): Config._instance = None +@contextlib.contextmanager +def configure_environment_from_path(config_path: Path): + """ + Context manager that initialises Config singleton from an existing config. 
+ """ + with tempfile.TemporaryDirectory() as tmp_dir: + Config._instance = None # reset singleton + Config.params = {} + + previous_env = os.environ.get("config_file") + previous_argv = sys.argv[:] + os.environ["config_file"] = str(config_path) + sys.argv = [ + previous_argv[0] if previous_argv else "pytest", + "--base", + tmp_dir, + ] + + try: + config = Config() + yield config + finally: + if previous_env is None: + os.environ.pop("config_file", None) + else: + os.environ["config_file"] = previous_env + sys.argv = previous_argv + Config._instance = None + + def async_run(coro): """Utility to execute the coroutine using asyncio.run (Python 3.7+).""" return asyncio.run(coro) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py new file mode 100644 index 000000000..35b685b63 --- /dev/null +++ b/tests/servers/test_diloco_strategy.py @@ -0,0 +1,820 @@ +"""Tests for DiLoCo server-side outer aggregation.""" + +import asyncio +import logging +from types import SimpleNamespace + +import pytest +import torch + +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy +from plato.servers.strategies.base import ServerContext + + +class DummyAlgorithm: + """Minimal algorithm stub for zero-delta construction.""" + + def __init__(self, baseline): + self.baseline = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.baseline.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + return [ + { + name: weights[name] - baseline_weights[name] + for name in baseline_weights.keys() + } + for weights in weights_list + ] + + +class ServerAlgorithm(DummyAlgorithm): + """Algorithm stub for exercising FedAvg-compatible server dispatch.""" + + def __init__(self, baseline): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + self.delta_payloads = None + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.current.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + self.delta_payloads = weights_list + return super().compute_weight_deltas(baseline_weights, weights_list) + + def update_weights(self, deltas): + self.current = { + name: self.current[name] + deltas[name] for name in self.current + } + return self.extract_weights() + + def load_weights(self, weights): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in weights.items() + } + + +class RecordingDiLoCoStrategy(DiLoCoAggregationStrategy): + """DiLoCo strategy recording server dispatch calls.""" + + def __init__(self): + super().__init__( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + self.delta_calls = 0 + self.last_updates = None + self.last_deltas = None + + async def aggregate_deltas(self, updates, deltas_received, context): + self.delta_calls += 1 + self.last_updates = updates + self.last_deltas = deltas_received + return await super().aggregate_deltas(updates, deltas_received, context) + + +class MixedStateModel(torch.nn.Module): + """Model exposing trainable, frozen, floating-buffer, and integer state.""" + + def __init__(self): + super().__init__() + self.trainable = 
torch.nn.Parameter(torch.tensor([1.0])) + self.frozen = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=False) + self.register_buffer("floating_buffer", torch.tensor([1.0])) + self.register_buffer("integer_buffer", torch.tensor([1], dtype=torch.int64)) + self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) + + +class PeftLikeAdapterModel(torch.nn.Module): + """Model whose adapter payload keys omit PEFT's default adapter segment.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.base_model = torch.nn.Module() + self.base_model.model = torch.nn.Module() + self.base_model.model.linear = torch.nn.Module() + self.base_model.model.linear.lora_A = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + +class AdapterAliasCollisionModel(torch.nn.Module): + """Model with a trainable parameter and separate payload key collision.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.foo = torch.nn.ModuleDict({"default": torch.nn.Linear(1, 1, bias=False)}) + + +def _context(baseline=None, model=None): + context = ServerContext() + if baseline is not None: + context.algorithm = DummyAlgorithm(baseline) + if model is not None: + context.trainer = SimpleNamespace(model=model) + return context + + +def _update(num_samples, report_type="weights"): + return SimpleNamespace( + report=SimpleNamespace(num_samples=num_samples, type=report_type) + ) + + +def _server_update(payload, num_samples=1, report_type="weights"): + update = _update(num_samples, report_type) + update.client_id = len(str(payload)) + update.report.accuracy = 0.5 + update.report.processing_time = 0.1 + update.report.comm_time = 0.1 + update.report.training_time = 0.1 + update.payload = payload + return update + + +def _aggregate(strategy, updates, deltas, baseline=None, model=None): + return asyncio.run( + strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) + ) + + +def test_diloco_server_type_uses_fedavg_algorithm_and_strategy(temp_config): + """server.type=diloco should select a FedAvg-compatible DiLoCo server.""" + from plato.algorithms import registry as algorithms_registry + from plato.config import Config + from plato.servers import diloco as diloco_server + from plato.servers import fedavg + from plato.servers import registry as servers_registry + + Config().server.type = "diloco" + Config().algorithm.type = "fedavg" + Config().server.diloco = SimpleNamespace( + outer_optimizer="sgd", + outer_learning_rate=0.25, + outer_momentum=0.1, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server = servers_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server, fedavg.Server) + assert isinstance(server.aggregation_strategy, DiLoCoAggregationStrategy) + assert server.aggregation_strategy.outer_optimizer == "sgd" + assert server.aggregation_strategy.outer_learning_rate == 0.25 + assert server.aggregation_strategy.outer_momentum == 0.1 + assert server.aggregation_strategy.aggregation_weighting == "num_samples" + assert server.aggregation_strategy.apply_outer_optimizer_to == "all_floating" + assert Config().algorithm.type == "fedavg" + assert "diloco" not in algorithms_registry.registered_algorithms + + +def test_diloco_server_process_reports_uses_delta_aggregation(temp_config): + """DiLoCo server processing should reach the delta aggregation path.""" + from plato.config import Config + from plato.servers 
import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [ + _server_update({"w": torch.tensor([2.0])}), + _server_update({"w": torch.tensor([4.0])}), + ] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert strategy.last_updates == server.updates + assert len(strategy.last_deltas) == 2 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([3.0])) + + +def test_diloco_server_does_not_use_inherited_weight_aggregation(temp_config): + """DiLoCo must not bypass delta aggregation via inherited FedAvg weights.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + + async def fail_if_called(*_args, **_kwargs): + raise AssertionError("Inherited aggregate_weights() must not be called.") + + strategy.aggregate_weights = fail_if_called + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [_server_update({"w": torch.tensor([2.0])})] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_diloco_server_filters_non_weight_reports_before_delta_computation( + temp_config, +): + """Non-weight payloads should not reach compute_weight_deltas().""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + weight_payload = {"w": torch.tensor([2.0])} + server.updates = [ + _server_update("feature payload", report_type="features"), + _server_update({"metrics": 1.0}, report_type="metrics"), + _server_update(weight_payload), + ] + + asyncio.run(server._process_reports()) + + assert server.algorithm.delta_payloads == [weight_payload] + assert strategy.last_updates == [server.updates[2]] + assert len(strategy.last_deltas) == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): + """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(99)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([15.0])) + + +def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): + """Outer SGD with lr=1 should match sample-weighted FedAvg.""" + strategy = 
DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(3)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([6.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([16.5])) + + +def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): + """A lower outer SGD lr should partially move toward the averaged model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(5), _update(5)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([2.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([12.5])) + + +def test_sgd_uses_diloco_outer_gradient_sign(temp_config): + """The strategy should negate Plato deltas before applying outer SGD.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.25, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([1.0])) + + +def test_outer_optimizer_application_is_logged(temp_config, caplog): + """A DiLoCo aggregation should report the server-side outer optimizer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=0.7, + outer_momentum=0.9, + aggregation_weighting="uniform", + apply_outer_optimizer_to="parameters", + ) + model = torch.nn.Linear(1, 1, bias=False) + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + with caplog.at_level(logging.INFO): + _aggregate( + strategy, + [_update(1), _update(1)], + [ + {"weight": torch.tensor([[2.0]])}, + {"weight": torch.tensor([[4.0]])}, + ], + baseline, + model, + ) + + message = caplog.text + assert "DiLoCo outer optimizer applied" in message + assert "optimizer=nesterov" in message + assert "outer_lr=0.7" in message + assert "outer_momentum=0.9" in message + assert "weighting=uniform" in message + assert "apply_to=parameters" in message + assert "eligible_updates=2" in message + assert "optimized_tensors=1" in message + + +def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): + """Uniform mode should weight eligible clients equally.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(1), _update(1000)], + [{"w": torch.tensor([0.0])}, {"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([5.0])) + + +def test_nonpositive_sample_reports_are_ineligible(temp_config): + """Reports with zero or negative sample counts should not affect averages.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, 
+ aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(0), _update(-5), _update(10)], + [ + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([4.0])}, + ], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0])) + + +def test_empty_eligible_updates_return_zero_delta(temp_config): + """An empty eligible set should produce a zero delta matching the baseline.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([3.0, 4.0])} + server_delta = _aggregate( + strategy, + [_update(0), _update(5, report_type="features")], + [{"w": torch.tensor([10.0, 10.0])}, {"w": torch.tensor([10.0, 10.0])}], + baseline, + ) + + assert torch.allclose(server_delta["w"], torch.zeros_like(baseline["w"])) + + +def test_empty_eligible_updates_remove_stale_momentum(temp_config): + """A round with no eligible keys should clear stale momentum buffers.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + server_delta = _aggregate( + strategy, + [_update(0)], + [{"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([0.0])) + assert strategy.momentum_state == {} + + +def test_sgdm_persists_momentum_across_rounds(temp_config): + """Momentum SGD should reuse server-side outer momentum across rounds.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([2.0])) + assert torch.allclose(second_delta["w"], torch.tensor([5.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): + """Nesterov should use g + beta * m after updating the momentum buffer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([3.0])) + assert torch.allclose(second_delta["w"], torch.tensor([6.5])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( + temp_config, +): + """Momentum state should reset incompatible keys and prune missing keys.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, 
+ outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0]), "b": torch.tensor([1.0])}], + {"w": torch.tensor([0.0]), "b": torch.tensor([0.0])}, + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0, 6.0])}], + {"w": torch.tensor([0.0, 0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0, 6.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-4.0, -6.0])) + assert "b" not in strategy.momentum_state + + +def test_momentum_state_resets_on_dtype_mismatch(temp_config): + """Momentum state should reset when the tensor dtype changes.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0], dtype=torch.float32)}], + {"w": torch.tensor([0.0], dtype=torch.float32)}, + ) + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0], dtype=torch.float64)}], + {"w": torch.tensor([0.0], dtype=torch.float64)}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0], dtype=torch.float64)) + assert strategy.momentum_state["w"].dtype == torch.float64 + + +def test_parameters_policy_optimizes_only_trainable_floating_parameters( + temp_config, +): + """Default policy should leave frozen parameters and buffers on FedAvg deltas.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + first_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + model, + ) + + assert torch.allclose(first_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(first_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(first_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(first_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(first_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-4.0])) + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + model, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([6.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([6.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) 
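+ # After the second round only the trainable parameter should still hold an outer-momentum buffer.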
+ assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-8.0])) + + +def test_all_floating_policy_optimizes_every_floating_state_tensor(temp_config): + """All-floating mode should not require model context for eligibility.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + server_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + ) + + assert torch.allclose(server_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(server_delta["frozen"], torch.tensor([2.0])) + assert torch.allclose(server_delta["floating_buffer"], torch.tensor([2.0])) + assert torch.equal(server_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(server_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + +def test_parameters_policy_requires_trainer_model_context(temp_config): + """Default parameter eligibility should fail clearly without a model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + with pytest.raises(AttributeError, match="context.trainer.model"): + _aggregate( + strategy, + [_update(1)], + [{"trainable": torch.tensor([2.0])}], + {"trainable": torch.tensor([0.0])}, + ) + + +def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): + """PEFT payloads can omit adapter-name segments from trainable param names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = PeftLikeAdapterModel() + payload_name = "base_model.model.linear.lora_A.weight" + baseline = {payload_name: torch.zeros((1, 1))} + + server_delta = _aggregate( + strategy, + [_update(1)], + [{payload_name: torch.full((1, 1), 4.0)}], + baseline, + model, + ) + + assert torch.allclose(server_delta[payload_name], torch.full((1, 1), 2.0)) + assert set(strategy.momentum_state) == 
{payload_name} + assert torch.allclose( + strategy.momentum_state[payload_name], torch.full((1, 1), -4.0) + ) + + +def test_parameters_policy_does_not_overmatch_adapter_alias_collisions(temp_config): + """Alias support should not optimize unrelated colliding payload names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = AdapterAliasCollisionModel() + trainable_name = "foo.default.weight" + colliding_name = "foo.weight" + baseline = { + trainable_name: torch.zeros((1, 1)), + colliding_name: torch.zeros((1, 1)), + } + + server_delta = _aggregate( + strategy, + [_update(1)], + [ + { + trainable_name: torch.full((1, 1), 4.0), + colliding_name: torch.full((1, 1), 4.0), + } + ], + baseline, + model, + ) + + assert torch.allclose(server_delta[trainable_name], torch.full((1, 1), 2.0)) + assert torch.allclose(server_delta[colliding_name], torch.full((1, 1), 4.0)) + assert set(strategy.momentum_state) == {trainable_name} + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"outer_optimizer": "adam"}, "outer_optimizer"), + ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"apply_outer_optimizer_to": "buffers"}, "apply_outer_optimizer_to"), + ({"outer_learning_rate": -0.1}, "outer_learning_rate"), + ({"outer_momentum": -0.1}, "outer_momentum"), + ({"outer_momentum": 1.0}, "outer_momentum"), + ], +) +def test_invalid_config_values_fail_clearly(temp_config, kwargs, match): + """Invalid DiLoCo aggregation configuration should raise clear errors.""" + with pytest.raises(ValueError, match=match): + DiLoCoAggregationStrategy(**kwargs) diff --git a/tests/servers/test_fedavg_strategy.py b/tests/servers/test_fedavg_strategy.py index 2f788a9f1..41f0c89a5 100644 --- a/tests/servers/test_fedavg_strategy.py +++ b/tests/servers/test_fedavg_strategy.py @@ -242,9 +242,7 @@ def test_fedavg_server_prefers_custom_delta_strategy_over_inherited_weights( assert torch.allclose(server.algorithm.current["bias"], torch.ones(1)) -def test_fedavg_server_logged_items_flatten_evaluator_metrics( - temp_config, tmp_path -): +def test_fedavg_server_logged_items_flatten_evaluator_metrics(temp_config, tmp_path): """FedAvg should keep accuracy while surfacing evaluator summary metrics.""" from plato.config import Config from plato.servers import fedavg @@ -309,9 +307,7 @@ def test_fedavg_server_logged_items_include_detailed_lighteval_metrics( assert logged_items["evaluation_arc_challenge_acc_stderr"] == 0.0701 -def test_fedavg_server_does_not_persist_evaluator_jsonl_sidecar( - temp_config, tmp_path -): +def test_fedavg_server_does_not_persist_evaluator_jsonl_sidecar(temp_config, tmp_path): """FedAvg should rely on CSV logging instead of a JSONL sidecar.""" from plato.config import Config from plato.servers import fedavg diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index edf382b1d..39b584372 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -16,9 +16,7 @@ from plato.config import Config _TESTS_ROOT = Path(__file__).resolve().parent -_FEDDF_DIR = ( - _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" -) +_FEDDF_DIR = _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" if str(_FEDDF_DIR) not in sys.path: sys.path.insert(0, str(_FEDDF_DIR)) @@ -71,7 +69,9 @@ def __init__( self._unlabeled = TensorDataset(proxy_inputs, 
torch.zeros(len(proxy_inputs))) self._test = TensorDataset( test_inputs if test_inputs is not None else proxy_inputs, - torch.zeros(len(test_inputs) if test_inputs is not None else len(proxy_inputs)), + torch.zeros( + len(test_inputs) if test_inputs is not None else len(proxy_inputs) + ), ) def get_unlabeled_set(self): diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 41da1e1d1..19e43cee4 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -212,8 +212,6 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): Config._cli_overrides = {} - - def test_config_loads_smolvla_lerobot_parameter_contract(tmp_path: Path, monkeypatch): """SmolVLA/LeRobot config keys should round-trip through Config().""" config_base = tmp_path / "runtime" diff --git a/tests/trainers/strategies/test_loss_criterion.py b/tests/trainers/strategies/test_loss_criterion.py index f414cb031..7c798d813 100644 --- a/tests/trainers/strategies/test_loss_criterion.py +++ b/tests/trainers/strategies/test_loss_criterion.py @@ -44,7 +44,6 @@ def import_without_lightly(name, package=None): assert isinstance(criterion, nn.CrossEntropyLoss) - def test_loss_registry_ssl_loss_requires_optional_lightly(temp_config, monkeypatch): from plato.trainers import loss_criterion as loss_criterion_registry diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py new file mode 100644 index 000000000..2ddb53f86 --- /dev/null +++ b/tests/trainers/test_composable_optimizer_state.py @@ -0,0 +1,624 @@ +"""Tests for in-process optimizer state preservation in ComposableTrainer.""" + +import copy +import os +import pickle +import sys +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from plato.config import Config +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import ( + AdamWOptimizerStrategy, + CrossEntropyLossStrategy, + DefaultTrainingStepStrategy, + SGDOptimizerStrategy, + StepLRSchedulerStrategy, +) +from plato.trainers.strategies.base import OptimizerStrategy, TrainingContext + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + + +@pytest.fixture +def tiny_dataset(): + features = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [-1.0, 0.5], + ], + dtype=torch.float32, + ) + labels = torch.tensor([0, 1, 0, 1], dtype=torch.long) + return TensorDataset(features, labels) + + +@pytest.fixture +def one_step_config(): + return { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "optimizer-state-test", + } + + +class CapturingTrainingStep(DefaultTrainingStepStrategy): + """Record optimizer state before each local optimizer step.""" + + def __init__(self): + super().__init__() + self.pre_step_states = [] + self.pre_step_lrs = [] + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + optimizer_state = optimizer.state_dict() + self.pre_step_states.append(copy.deepcopy(optimizer_state["state"])) + self.pre_step_lrs.append( + [group["lr"] for group in optimizer_state["param_groups"]] + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + 
loss_criterion=loss_criterion, + context=context, + ) + + +def _linear_model(): + return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) + + +class DeviceTrackingModel(nn.Module): + """Model that records whether it has been moved to a trainer device.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + self.moved_to_trainer_device = False + + def forward(self, features): + return self.linear(features) + + def to(self, *args, **kwargs): + self.moved_to_trainer_device = True + return super().to(*args, **kwargs) + + +class RestoreOrderOptimizer(torch.optim.SGD): + """Optimizer that records whether state restore happens after model.to().""" + + def __init__(self, params, model: DeviceTrackingModel): + self.model = model + self.loaded_after_model_to = None + super().__init__(params, lr=0.01, momentum=0.9) + + def load_state_dict(self, state_dict): + self.loaded_after_model_to = self.model.moved_to_trainer_device + if not self.loaded_after_model_to: + raise AssertionError("optimizer state restored before model.to()") + return super().load_state_dict(state_dict) + + +class RestoreOrderOptimizerStrategy(OptimizerStrategy): + """Create restore-order-aware optimizers for regression tests.""" + + def __init__(self): + self.optimizers = [] + + def create_optimizer( + self, model: DeviceTrackingModel, context: TrainingContext + ) -> torch.optim.Optimizer: + optimizer = RestoreOrderOptimizer(model.parameters(), model) + self.optimizers.append(optimizer) + return optimizer + + +def _two_layer_model(first_name="first", second_name="second"): + return nn.Sequential( + OrderedDict( + [ + (first_name, nn.Linear(2, 2, bias=False)), + (second_name, nn.Linear(2, 2, bias=False)), + ] + ) + ) + + +def _first_param_state(optimizer_state): + return next(iter(optimizer_state.values())) + + +def _state_step(param_state): + step = param_state["step"] + if isinstance(step, torch.Tensor): + return int(step.item()) + return int(step) + + +def _configure_subprocess_training( + monkeypatch, + tmp_path, + *, + preserve_optimizer_state, +): + """Configure parent and spawned child processes to share local artifacts.""" + model_path = Path(tmp_path) / "models" / "pretrained" + model_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(Path(tmp_path) / "checkpoints") + Config.params["run_id"] = "subprocess-optimizer-state" + os.makedirs(Config.params["checkpoint_path"], exist_ok=True) + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + max_concurrency=1, + preserve_optimizer_state=preserve_optimizer_state, + batch_size=4, + epochs=1, + ) + + +def _cached_optimizer_step(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return _state_step(_first_param_state(payload["optimizer_state"]["state"])) + + +def _cached_scheduler_last_epoch(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return payload["scheduler_state"]["last_epoch"] + + +def _assert_model_update_contains_only_model_weights(update, model): + model_state = model.state_dict() + + assert set(update) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(update) + assert all(torch.is_tensor(value) for value in update.values()) + + +def test_adamw_moment_buffers_persist_between_rounds_for_same_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + 
model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + trainer.set_client_id(7) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + round1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[1]) + saved_state = _first_param_state(round1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + assert torch.allclose(restored_state["exp_avg_sq"], saved_state["exp_avg_sq"]) + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 2 + + +def test_preserved_optimizer_state_restores_after_model_moves_to_device( + temp_config, tiny_dataset, one_step_config +): + config = {**one_step_config, "preserve_optimizer_state": True} + source_trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=RestoreOrderOptimizerStrategy(), + ) + source_trainer.set_client_id(11) + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + restore_strategy = RestoreOrderOptimizerStrategy() + trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=restore_strategy, + ) + trainer.set_client_id(11) + trainer._preserved_optimizer_states[11] = copy.deepcopy( + source_trainer._preserved_optimizer_states[11] + ) + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert restore_strategy.optimizers[0].loaded_after_model_to is True + restored_state = _first_param_state( + trainer._preserved_optimizer_states[11]["optimizer_state"]["state"] + ) + assert "momentum_buffer" in restored_state + + +def test_scheduler_state_and_lr_progress_persist_between_rounds( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + training_step_strategy=step_strategy, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_lrs == [[0.2], [0.1]] + assert trainer.lr_scheduler.last_epoch == 2 + assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_optimizer_state_parent_reloads_after_child( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer.client_id in trainer._preserved_optimizer_states + assert _cached_optimizer_step(trainer) == 1 + state_path = Path(Config.params["model_path"]) / 
trainer._optimizer_state_filename( + Config.params["run_id"] + ) + assert state_path.exists() + assert "optimizer_state" not in trainer.obtain_model_update( + { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "payload-check", + "preserve_optimizer_state": True, + }, + tiny_dataset, + list(range(len(tiny_dataset))), + ) + + +def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 2 + + +def test_subprocess_scheduler_state_persists_across_rounds( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert _cached_scheduler_last_epoch(trainer) == 2 + assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_missing_sidecar_clears_inherited_parent_cache( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + source_trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + source_trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + assert _cached_optimizer_step(source_trainer) == 1 + + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + trainer._preserved_optimizer_states[7] = copy.deepcopy( + source_trainer._preserved_optimizer_states[7] + ) + + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] + ) + state_path.unlink(missing_ok=True) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 1 + + +def test_missing_subprocess_output_removes_stale_input_sidecar( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + 
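+ # Leave a stale input sidecar in place, then finish a round whose child process never wrote its output file.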
input_filename = trainer._optimizer_state_filename(Config.params["run_id"]) + missing_output_filename = trainer._optimizer_state_output_filename( + Config.params["run_id"] + ) + assert trainer._save_preserved_optimizer_state_file(input_filename) + input_path = Path(Config.params["model_path"]) / input_filename + assert input_path.exists() + + trainer._finish_subprocess_optimizer_state(input_filename, missing_output_filename) + + assert trainer.client_id not in trainer._preserved_optimizer_states + assert not input_path.exists() + + +def test_subprocess_invalid_optimizer_state_resets_safely( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] + ) + with open(state_path, "wb") as state_file: + pickle.dump({"optimizer_type": torch.optim.SGD}, state_file) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert payload["optimizer_type"] is torch.optim.AdamW + assert _cached_optimizer_step(trainer) == 1 + + +def test_subprocess_optimizer_state_is_not_persisted_when_disabled( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=False + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer._preserved_optimizer_states == {} + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] + ) + assert not state_path.exists() + + +def test_preserved_optimizer_state_is_local_to_logical_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + client1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + + trainer.set_client_id(2) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + assert step_strategy.pre_step_states[1] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[2]) + saved_state = _first_param_state(client1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + + +def test_preserved_state_stays_out_of_model_update_payload( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, 
gamma=0.5), + ) + trainer.set_client_id(5) + config = {**one_step_config, "preserve_optimizer_state": True} + + update = trainer.obtain_model_update( + config, tiny_dataset, list(range(len(tiny_dataset))) + ) + preserved_state = trainer._preserved_optimizer_states[trainer.client_id] + + assert preserved_state["optimizer_state"]["state"] + assert preserved_state["scheduler_state"]["last_epoch"] >= 1 + _assert_model_update_contains_only_model_weights(update, trainer.model) + + +def test_preserved_state_invalidates_when_parameter_order_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=lambda: _two_layer_model("first", "second"), + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.model = _two_layer_model("second", "first") + trainer.context.model = trainer.model + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + + +def test_preserved_state_invalidates_when_optimizer_type_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.optimizer_strategy = SGDOptimizerStrategy(lr=0.1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + assert isinstance(trainer.optimizer, torch.optim.SGD) + + +def test_preserved_state_compatibility_rejects_shape_dtype_and_scheduler_changes( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(4) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + payload = copy.deepcopy(trainer._preserved_optimizer_states[4]) + + current_model = trainer.model + current_optimizer = trainer.optimizer_strategy.create_optimizer( + current_model, trainer.context + ) + changed_scheduler = StepLRSchedulerStrategy( + step_size=1, gamma=0.5 + ).create_scheduler(current_optimizer, trainer.context) + assert not trainer._preserved_state_is_compatible( + payload, current_model, current_optimizer, changed_scheduler + ) + + changed_shape_model = nn.Sequential(OrderedDict([("linear", nn.Linear(2, 3))])) + changed_shape_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_shape_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_shape_model, changed_shape_optimizer, None + ) + + changed_dtype_model = _linear_model().to(torch.float64) + changed_dtype_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_dtype_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_dtype_model, changed_dtype_optimizer, None + ) + + +@pytest.mark.parametrize("preserve_value", [None, 
False]) +def test_optimizer_state_is_not_restored_when_disabled_or_unset( + temp_config, tiny_dataset, one_step_config, preserve_value +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = dict(one_step_config) + if preserve_value is not None: + config["preserve_optimizer_state"] = preserve_value + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states == [{}, {}] + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 1 diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ff35f76ec..e9ca21090 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -5,11 +5,14 @@ it works correctly in end-to-end training scenarios. """ +import logging + import pytest import torch import torch.nn as nn -from torch.utils.data import TensorDataset +from torch.utils.data import SubsetRandomSampler, TensorDataset +from plato.callbacks.trainer import TrainerCallback from plato.config import Config from plato.evaluators.runner import EVALUATION_PRIMARY_KEY, EVALUATION_RESULTS_KEY from plato.trainers.composable import ComposableTrainer @@ -18,13 +21,16 @@ CrossEntropyLossStrategy, DefaultDataLoaderStrategy, DefaultTrainingStepStrategy, + GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, + StepLRSchedulerStrategy, ) from plato.trainers.strategies.base import ( LossCriterionStrategy, ModelUpdateStrategy, TrainingContext, + TrainingStepStrategy, ) @@ -178,6 +184,453 @@ def test_multiple_epochs(self, simple_model, simple_dataset): assert len(trainer.run_history.get_metric_values("train_loss")) == 5 +class TestComposableTrainerLocalSteps: + """Test local optimizer-step limits for DiLoCo-style local work.""" + + class DeterministicPlatoSampler: + def __init__(self, indices, seed=47): + self.indices = list(indices) + self.seed = seed + + def get(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + return SubsetRandomSampler(self.indices, generator=generator) + + def num_samples(self): + return len(self.indices) + + class NonMaterializableSampler(torch.utils.data.Sampler): + def __iter__(self): + raise NotImplementedError("This sampler cannot be materialized.") + + def __len__(self): + return 10 + + class CountingCallback(TrainerCallback): + def __init__(self): + self.train_run_end_called = False + self.train_step_end_count = 0 + + def on_train_step_end(self, trainer, config, batch, loss, **kwargs): + self.train_step_end_count += 1 + + def on_train_run_end(self, trainer, config, **kwargs): + self.train_run_end_called = True + + class CountingUpdateStrategy(ModelUpdateStrategy): + def __init__(self): + self.after_step_count = 0 + self.on_train_end_called = False + + def after_step(self, context): + self.after_step_count += 1 + + def on_train_end(self, context): + self.on_train_end_called = True + + class CountingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.batch_count += 1 + return 
super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class RecordingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.samples_by_round = {} + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + sample_ids = examples[:, 0].detach().cpu().int().tolist() + self.samples_by_round.setdefault(context.current_round, []).extend( + sample_ids + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class DelayedOptimizerStepStrategy(TrainingStepStrategy): + def __init__(self, accumulation_steps=2, finalize_steps=True): + self.accumulation_steps = accumulation_steps + self.finalize_steps = finalize_steps + self.raw_batch_count = 0 + self.optimizer_step_count = 0 + self.finalize_calls = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + outputs = model(examples) + loss = loss_criterion(outputs, labels) + (loss / self.accumulation_steps).backward() + + self.raw_batch_count += 1 + if self.raw_batch_count % self.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False + + return loss + + def finalize(self, model, optimizer, context): + self.finalize_calls += 1 + if not self.finalize_steps: + return None + + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + return torch.tensor(0.0) + + class CountingGradientAccumulationStepStrategy(GradientAccumulationStepStrategy): + def __init__(self, accumulation_steps): + super().__init__(accumulation_steps=accumulation_steps) + self.raw_batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.raw_batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + def test_local_steps_stop_mid_epoch_and_run_cleanup( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + callback = self.CountingCallback() + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + callbacks=[callback], + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 3 + assert update_strategy.after_step_count == 3 + assert callback.train_step_end_count == 3 + assert trainer.current_epoch == 1 + assert trainer.context.state["local_optimizer_steps"] == 3 + assert update_strategy.on_train_end_called + assert callback.train_run_end_called + + def test_local_steps_count_optimizer_steps_not_raw_batches( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = 
self.DelayedOptimizerStepStrategy(accumulation_steps=3) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert step_strategy.optimizer_step_count == 2 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_respect_builtin_gradient_accumulation( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingGradientAccumulationStepStrategy( + accumulation_steps=3 + ) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_skip_finalize_after_limit_is_reached( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 2, + "local_steps_per_round": 1, + } + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=2) + update_strategy = self.CountingUpdateStrategy() + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 2 + assert step_strategy.optimizer_step_count == 1 + assert step_strategy.finalize_calls == 0 + assert update_strategy.after_step_count == 1 + assert trainer.context.state["local_optimizer_steps"] == 1 + + def test_epoch_behavior_is_unchanged_when_local_steps_unset( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 10, + "epochs": 2, + } + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 20 + assert trainer.current_epoch == 2 + assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + + def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( + self, simple_model, simple_config + ): + dataset_size = 10 + features = torch.arange(dataset_size, dtype=torch.float32).view(-1, 1) + features = features.repeat(1, 10) + labels = torch.arange(dataset_size) % 2 + dataset = TensorDataset(features, labels) + config = { + **simple_config, + "batch_size": 1, + "epochs": 1, + "local_steps_per_round": 3, + } + sampler = self.DeterministicPlatoSampler(range(dataset_size)) + step_strategy = self.RecordingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + trainer.set_client_id(2) + + for round_number in (1, 2): + trainer.current_round = round_number + trainer.train_model(config, dataset, sampler) + + round_one_samples = step_strategy.samples_by_round[1] + round_two_samples = step_strategy.samples_by_round[2] + + assert len(round_one_samples) == 
config["local_steps_per_round"] + assert len(round_two_samples) == config["local_steps_per_round"] + assert round_one_samples != round_two_samples + + repeat_step_strategy = self.RecordingStepStrategy() + repeat_trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=repeat_step_strategy, + ) + repeat_trainer.set_client_id(2) + + for round_number in (1, 2): + repeat_trainer.current_round = round_number + repeat_trainer.train_model(config, dataset, sampler) + + assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + + def test_local_step_sampling_warns_for_non_materializable_sampler( + self, simple_dataset, caplog + ): + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + sampler = self.NonMaterializableSampler() + + with caplog.at_level(logging.WARNING): + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + sampler, + batch_size=1, + context=context, + ) + + assert loader.sampler is sampler + assert ( + "cannot be materialized for round-aware local-step sampling" in caplog.text + ) + + def test_diloco_local_steps_require_full_client_participation( + self, simple_dataset, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + with pytest.raises( + ValueError, match="clients\\.per_round.*clients\\.total_clients" + ): + DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + def test_partial_participation_still_allowed_without_diloco_local_steps( + self, simple_dataset, temp_config + ): + Config().server.type = "fedavg" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + assert len(loader.sampler) == len(simple_dataset) + + def test_diloco_local_steps_advance_lr_scheduler_per_optimizer_step( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 4 + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 3 + + def test_non_diloco_local_steps_keep_epoch_based_lr_scheduler( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "fedavg" + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 1 + + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) + def test_invalid_local_steps_fail_clearly( + self, simple_model, simple_dataset, 
simple_config, local_steps_per_round + ): + config = { + **simple_config, + "local_steps_per_round": local_steps_per_round, + } + trainer = ComposableTrainer(model=simple_model) + + with pytest.raises(ValueError, match="local_steps_per_round"): + trainer.train_model( + config, simple_dataset, list(range(len(simple_dataset))) + ) + + class TestComposableTrainerStrategies: """Test strategy integration.""" @@ -468,9 +921,7 @@ def test_test_state_roundtrip_persists_evaluation_metadata(self, temp_config): "metric": "ifeval_avg", "value": 0.31, } - assert trainer.context.state["nanochat_core_results"] == { - "core_metric": 0.9 - } + assert trainer.context.state["nanochat_core_results"] == {"core_metric": 0.9} def test_test_state_restore_clears_stale_evaluation_metadata(self, temp_config): trainer = ComposableTrainer(model=nn.Linear(2, 1)) diff --git a/tests/trainers/test_huggingface_trainer.py b/tests/trainers/test_huggingface_trainer.py index 597c8084c..144c37ea8 100644 --- a/tests/trainers/test_huggingface_trainer.py +++ b/tests/trainers/test_huggingface_trainer.py @@ -30,7 +30,9 @@ def pad(self, features, padding=True, return_tensors=None): for feature in features: pad_width = max_len - len(feature["input_ids"]) - batch["input_ids"].append(feature["input_ids"] + [self.pad_token_id] * pad_width) + batch["input_ids"].append( + feature["input_ids"] + [self.pad_token_id] * pad_width + ) batch["attention_mask"].append( feature.get("attention_mask", [1] * len(feature["input_ids"])) + [0] * pad_width diff --git a/tests/trainers/test_lerobot_trainer.py b/tests/trainers/test_lerobot_trainer.py index 1d1316aff..9b849bcaa 100644 --- a/tests/trainers/test_lerobot_trainer.py +++ b/tests/trainers/test_lerobot_trainer.py @@ -166,7 +166,7 @@ def test_lerobot_trainer_consumes_policy_precision_and_device( monkeypatch.setattr( lerobot_trainer, "_import_make_pre_post_processors", - lambda: (lambda *_args, **_kwargs: (lambda batch: batch, lambda out: out)), + lambda: lambda *_args, **_kwargs: (lambda batch: batch, lambda out: out), ) trainer = lerobot_trainer.Trainer(model=_TinyLeRobotPolicy())
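The trainer-level options exercised by the local-step and optimizer-state tests above can be combined in a single local run. The following is a minimal, illustrative sketch rather than code from this patch: it assumes Plato's global Config has already been initialized (as the tests arrange through their temp_config fixture), and the model, dataset, and hyperparameter values are placeholders chosen only to show the local_steps_per_round cap and the preserve_optimizer_state flag together.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from plato.trainers.composable import ComposableTrainer

# Placeholder data: 64 samples with 10 features and binary labels.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

# Default strategies handle the loss, optimizer, scheduler, and training
# step, mirroring how the local-step tests construct the trainer.
trainer = ComposableTrainer(model=lambda: nn.Linear(10, 2))
trainer.set_client_id(1)

config = {
    "batch_size": 8,
    "epochs": 5,
    # Stop local training after this many optimizer steps in the round,
    # even if that point falls in the middle of an epoch.
    "local_steps_per_round": 20,
    # Carry the inner optimizer and scheduler state across rounds; the
    # compatibility tests above show it is invalidated when the parameter
    # order, optimizer type, tensor shapes, or dtypes change.
    "preserve_optimizer_state": True,
}

trainer.train_model(config, dataset, list(range(len(dataset))))

# 5 epochs of 8 batches would be 40 optimizer steps; the cap stops at 20.
assert trainer.context.state["local_optimizer_steps"] == 20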