From 80766f7b333224771aec0f20cfdc63e680b686bf Mon Sep 17 00:00:00 2001 From: liuliuliu0605 Date: Tue, 26 Sep 2023 17:56:20 +0800 Subject: [PATCH 1/3] add mpi version of hierarchical fl and examples --- .../mpi_torch_hierarchical_fl/README.md | 11 ++ .../mpi_torch_hierarchical_fl/__init__.py | 0 .../config/mnist_lr/fedml_config.yaml | 49 +++++ .../config/mnist_lr/gpu_mapping.yaml | 70 ++++++++ .../mpi_torch_hierarchical_fl/mpi_host_file | 1 + .../run_step_by_step_example.sh | 10 ++ .../torch_step_by_step_example.py | 1 + .../mpi/hierarchical_fl/HierClient.py | 55 ++++++ .../mpi/hierarchical_fl/HierFedAvgAPI.py | 168 ++++++++++++++++++ .../HierFedAvgCloudAggregator.py | 168 ++++++++++++++++++ .../hierarchical_fl/HierFedAvgCloudManager.py | 134 ++++++++++++++ .../hierarchical_fl/HierFedAvgEdgeManager.py | 72 ++++++++ .../mpi/hierarchical_fl/HierGroup.py | 73 ++++++++ .../mpi/hierarchical_fl/__init__.py | 0 .../mpi/hierarchical_fl/message_define.py | 34 ++++ .../simulation/mpi/hierarchical_fl/utils.py | 30 ++++ python/fedml/simulation/simulator.py | 13 ++ 17 files changed, 889 insertions(+) create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/README.md create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/__init__.py create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file create mode 100755 python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh create mode 100644 python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierClient.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/__init__.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/message_define.py create mode 100644 python/fedml/simulation/mpi/hierarchical_fl/utils.py diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md new file mode 100644 index 0000000000..3010fd24ad --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md @@ -0,0 +1,11 @@ +# Install FedML and Prepare the Distributed Environment +``` +pip install fedml +``` + + +# Run the example (step by step APIs) +``` +sh run_step_by_step_example.sh 4 config/mnist_lr/fedml_config.yaml +``` + diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/__init__.py b/python/examples/simulation/mpi_torch_hierarchical_fl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml new file mode 100644 index 0000000000..c5a244d1bd --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml @@ -0,0 +1,49 
@@ +common_args: + training_type: "simulation" + random_seed: 0 + +data_args: + dataset: "mnist" + data_cache_dir: ~/fedml_data + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "lr" + +train_args: + federated_optimizer: "HierarchicalFL" + client_id_list: "[]" + client_num_in_total: 1000 + client_num_per_round: 10 + comm_round: 20 + epochs: 1 + batch_size: 10 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + group_method: "random" + group_num: 2 + group_comm_round: 5 + +validation_args: + frequency_of_the_test: 5 + +device_args: + worker_num: 3 + using_gpu: true + gpu_mapping_file: config/mnist_lr/gpu_mapping.yaml + gpu_mapping_key: mapping_config1_3 + +comm_args: + backend: "MPI" + is_mobile: 0 + + +tracking_args: + # When running on MLOps platform(open.fedml.ai), the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/ + enable_wandb: true + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + run_name: mpi_hierarchical_fl_mnist_lr + wandb_only_server: true \ No newline at end of file diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml new file mode 100644 index 0000000000..8c4961681f --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml @@ -0,0 +1,70 @@ +# You can define a cluster containing multiple GPUs within multiple machines by defining `gpu_mapping.yaml` as follows: + +# config_cluster0: +# host_name_node0: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n] +# host_name_node1: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n] +# host_name_node_m: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n] + + +# this is used for 10 clients and 1 server training within a single machine which has 4 GPUs +mapping_default: + ChaoyangHe-GPU-RTX2080Tix4: [3, 3, 3, 2] + +mapping_config1_2: + host1: [1, 1] + +mapping_config1_3: + host1: [1, 1, 1] + +# this is used for 4 clients and 1 server training within a single machine which has 4 GPUs +mapping_config1_5: + host1: [2, 1, 1, 1] + +# this is used for 4 clients and 1 server training within a single machine which has 4 GPUs +mapping_config1_6: + host1: [2, 2, 1, 1] + +# this is used for 10 clients and 1 server training within a single machine which has 4 GPUs +mapping_config2_11: + host1: [3, 3, 3, 2] + +# this is used for 10 clients and 1 server training within a single machine which has 8 GPUs +mapping_config3_11: + host1: [2, 2, 2, 1, 1, 1, 1, 1] + +# this is used for 4 clients and 1 server training within a single machine which has 8 GPUs, but you hope to skip the GPU device ID. +mapping_config4_5: + host1: [1, 0, 0, 1, 1, 0, 1, 1] + +# this is used for 4 clients and 1 server training using 6 machines, each machine has 2 GPUs inside, but you hope to use the second GPU. +mapping_config5_6: + host1: [0, 1] + host2: [0, 1] + host3: [0, 1] + host4: [0, 1] + host5: [0, 1] +# this is used for 4 clients and 1 server training using 2 machines, each machine has 2 GPUs inside, but you hope to use the second GPU. 
+mapping_config5_2:
+  gpu-worker2: [1,1]
+  gpu-worker1: [2,1]
+
+# this is used for 10 clients and 1 server training using 4 machines, each machine has 2 GPUs inside, but you hope to use the second GPU.
+mapping_config5_4:
+  gpu-worker2: [1,1]
+  gpu-worker1: [2,1]
+  gpu-worker3: [3,1]
+  gpu-worker4: [1,1]
+
+# for grpc GPU mapping
+mapping_FedML_gRPC:
+  hostname_node_server: [1]
+  hostname_node_1: [1, 0, 0, 0]
+  hostname_node_2: [1, 0, 0, 0]
+
+# for torch RPC GPU mapping
+mapping_FedML_tRPC:
+  lambda-server1: [0, 0, 0, 0, 2, 2, 1, 1]
+  lambda-server2: [2, 1, 1, 1, 0, 0, 0, 0]
+
+#mapping_FedML_tRPC:
+#  lambda-server1: [0, 0, 0, 0, 3, 3, 3, 2]
\ No newline at end of file
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file b/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file
new file mode 100644
index 0000000000..ebed096720
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file
@@ -0,0 +1 @@
+liuxuezheng3
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh
new file mode 100755
index 0000000000..8dd565e6e8
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+WORKER_NUM=$1
+CONFIG_PATH=$2
+
+hostname > mpi_host_file
+
+mpirun -np $WORKER_NUM \
+-hostfile mpi_host_file \
+python torch_step_by_step_example.py --cf $CONFIG_PATH
\ No newline at end of file
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py b/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py
new file mode 100644
index 0000000000..aa52f56397
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py
@@ -0,0 +1,19 @@
+import fedml
+from fedml import FedMLRunner
+
+if __name__ == "__main__":
+    # init FedML framework
+    args = fedml.init()
+
+    # init device
+    device = fedml.device.get_device(args)
+
+    # load data
+    dataset, output_dim = fedml.data.load(args)
+
+    # load model
+    model = fedml.model.create(args, output_dim)
+
+    # start training
+    fedml_runner = FedMLRunner(args, device, dataset, model)
+    fedml_runner.run()
\ No newline at end of file
diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py
new file mode 100644
index 0000000000..8e4cf98c03
--- /dev/null
+++ b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py
@@ -0,0 +1,55 @@
+import copy
+
+import torch
+import torch.nn as nn
+
+from ...sp.fedavg.client import Client
+
+
+class HFLClient(Client):
+    def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model,
+                 model_trainer):
+
+        super().__init__(client_idx, local_training_data, local_test_data, local_sample_number, args, device,
+                         model_trainer)
+        self.client_idx = client_idx
+        self.local_training_data = local_training_data
+        self.local_test_data = local_test_data
+        self.local_sample_number = local_sample_number
+
+        self.args = args
+        self.device = device
+        self.model = model
+        self.model_trainer = model_trainer
+        self.criterion = nn.CrossEntropyLoss().to(device)
+
+    def train(self, global_round_idx, group_round_idx, w):
+        self.model.load_state_dict(w)
+        self.model.to(self.device)
+
+        if self.args.client_optimizer == "sgd":
+            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate)
+        else:
+            optimizer
= torch.optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.args.learning_rate, + weight_decay=self.args.wd, + amsgrad=True, + ) + + w_list = [] + for epoch in range(self.args.epochs): + for x, labels in self.local_training_data: + x, labels = x.to(self.device), labels.to(self.device) + self.model.zero_grad() + log_probs = self.model(x) + loss = self.criterion(log_probs, labels) # pylint: disable=E1102 + loss.backward() + optimizer.step() + client_round = ( + global_round_idx * self.args.group_comm_round + + group_round_idx + ) + # if client_round % self.args.frequency_of_the_test == 0: + w_list.append((client_round, copy.deepcopy(self.model.cpu().state_dict()))) + return w_list diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py new file mode 100644 index 0000000000..662c1cc961 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py @@ -0,0 +1,168 @@ +from .HierFedAvgCloudAggregator import HierFedAVGCloudAggregator +from .HierFedAvgCloudManager import HierFedAVGCloudManager +from .HierFedAvgEdgeManager import HierFedAVGEdgeManager +from .HierGroup import HierGroup +from ....core import ClientTrainer, ServerAggregator +from ....core.dp.fedml_differential_privacy import FedMLDifferentialPrivacy +from ....core.security.fedml_attacker import FedMLAttacker +from ....core.security.fedml_defender import FedMLDefender +from ....ml.aggregator.aggregator_creator import create_server_aggregator +from ....ml.trainer.trainer_creator import create_model_trainer + +import numpy as np + + +def FedML_HierFedAvg_distributed( + args, + process_id, + worker_number, + comm, + device, + dataset, + model, + client_trainer: ClientTrainer = None, + server_aggregator: ServerAggregator = None, +): + [ + train_data_num, + test_data_num, + train_data_global, + test_data_global, + train_data_local_num_dict, + train_data_local_dict, + test_data_local_dict, + class_num, + ] = dataset + + FedMLAttacker.get_instance().init(args) + FedMLDefender.get_instance().init(args) + FedMLDifferentialPrivacy.get_instance().init(args) + + if process_id == 0: + init_cloud_server( + args, + device, + comm, + process_id, + worker_number, + model, + train_data_num, + train_data_global, + test_data_global, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + server_aggregator + ) + else: + init_edge_server_clients( + args, + device, + comm, + process_id, + worker_number, + model, + train_data_num, + train_data_local_num_dict, + train_data_local_dict, + test_data_local_dict, + client_trainer, + server_aggregator + ) + + +def init_cloud_server( + args, + device, + comm, + rank, + size, + model, + train_data_num, + train_data_global, + test_data_global, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + server_aggregator +): + if server_aggregator is None: + server_aggregator = create_server_aggregator(model, args) + server_aggregator.set_id(-1) + + # aggregator + worker_num = size - 1 + aggregator = HierFedAVGCloudAggregator( + train_data_global, + test_data_global, + train_data_num, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + worker_num, + device, + args, + server_aggregator, + ) + + # start the distributed training + backend = args.backend + group_indexes, group_to_client_indexes = setup_clients(args) + server_manager = HierFedAVGCloudManager(args, aggregator, group_indexes, group_to_client_indexes, comm, 
rank, size, backend) + server_manager.send_init_msg() + server_manager.run() + + +def init_edge_server_clients( + args, + device, + comm, + process_id, + size, + model, + train_data_num, + train_data_local_num_dict, + train_data_local_dict, + test_data_local_dict, + group, + model_trainer=None, +): + + if model_trainer is None: + model_trainer = create_model_trainer(model, args) + + edge_index = process_id - 1 + backend = args.backend + + # Client assignment is decided on cloud server and the information will be communicated later + group = HierGroup( + edge_index, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + args, + device, + model, + model_trainer + ) + + edge_manager = HierFedAVGEdgeManager(group, args, comm, process_id, size, backend) + edge_manager.run() + + +def setup_clients( + args + ): + if args.group_method == "random": + group_indexes = np.random.randint( + 0, args.group_num, args.client_num_in_total + ) + group_to_client_indexes = {} + for client_idx, group_idx in enumerate(group_indexes): + if not group_idx in group_to_client_indexes: + group_to_client_indexes[group_idx] = [] + group_to_client_indexes[group_idx].append(client_idx) + else: + raise Exception(args.group_method) + + return group_indexes, group_to_client_indexes \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py new file mode 100644 index 0000000000..59ef1af712 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py @@ -0,0 +1,168 @@ +import logging +import random +import time +import numpy as np +import torch + +from ....core.security.fedml_attacker import FedMLAttacker +from ....core.security.fedml_defender import FedMLDefender + + +class HierFedAVGCloudAggregator(object): + def __init__( + self, + train_global, + test_global, + all_train_data_num, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + worker_num, + device, + args, + server_aggregator, + ): + self.aggregator = server_aggregator + self.args = args + self.train_global = train_global + self.test_global = test_global + self.val_global = self._generate_validation_set() + self.all_train_data_num = all_train_data_num + + self.train_data_local_dict = train_data_local_dict + self.test_data_local_dict = test_data_local_dict + self.train_data_local_num_dict = train_data_local_num_dict + + self.worker_num = worker_num + self.device = device + self.model_dict = dict() + self.sample_num_dict = dict() + self.flag_client_model_uploaded_dict = dict() + for idx in range(self.worker_num): + self.flag_client_model_uploaded_dict[idx] = False + + def get_global_model_params(self): + return self.aggregator.get_model_params() + + def set_global_model_params(self, model_parameters): + self.aggregator.set_model_params(model_parameters) + + def add_local_trained_result(self, index, model_params_list, sample_num): + logging.info("add_model. 
index = %d" % index) + self.model_dict[index] = model_params_list + self.sample_num_dict[index] = sample_num + self.flag_client_model_uploaded_dict[index] = True + + def check_whether_all_receive(self): + logging.debug("worker_num = {}".format(self.worker_num)) + for idx in range(self.worker_num): + if not self.flag_client_model_uploaded_dict[idx]: + return False + for idx in range(self.worker_num): + self.flag_client_model_uploaded_dict[idx] = False + return True + + def aggregate(self): + start_time = time.time() + + # Edge server may conduct partial aggregation multiple times, so cloud server will receive a model list + for group_round_idx in range(self.args.group_comm_round): + model_list = [] + + for idx in range(0, self.worker_num): + model_list.append((self.sample_num_dict[idx], + self.model_dict[idx][group_round_idx][1])) + + client_round = self.model_dict[0][group_round_idx][0] + averaged_params = self._fedavg_aggregation_(model_list) + self.set_global_model_params(averaged_params) + self.test_on_cloud_for_all_clients(client_round) + + if FedMLAttacker.get_instance().is_model_attack(): + model_list = FedMLAttacker.get_instance().attack_model(raw_client_grad_list=model_list, extra_auxiliary_info=None) + + if FedMLDefender.get_instance().is_defense_enabled(): + # todo: update extra_auxiliary_info according to defense type + averaged_params = FedMLDefender.get_instance().defend( + raw_client_grad_list=model_list, + base_aggregation_func=self._fedavg_aggregation_, + extra_auxiliary_info=self.get_global_model_params(), + ) + else: + averaged_params = self._fedavg_aggregation_(model_list) + + # update the global model which is cached in the cloud + self.set_global_model_params(averaged_params) + + end_time = time.time() + logging.info("aggregate time cost: %d" % (end_time - start_time)) + return averaged_params + + def _fedavg_aggregation_(self, model_list): + training_num = 0 + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + training_num += local_sample_number + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * local_sample_number / training_num + ) + else: + averaged_params[k] += ( + local_model_params[k] * local_sample_number / training_num + ) + return averaged_params + + def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + if client_num_in_total == client_num_per_round: + client_indexes = [ + client_index for client_index in range(client_num_in_total) + ] + else: + num_clients = min(client_num_per_round, client_num_in_total) + np.random.seed( + round_idx + ) # make sure for each comparison, we are selecting the same clients each round + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False + ) + logging.info("client_indexes = %s" % str(client_indexes)) + return client_indexes + + def _generate_validation_set(self, num_samples=10000): + if self.args.dataset.startswith("stackoverflow"): + test_data_num = len(self.test_global.dataset) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num) + ) + subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size + ) + return sample_testset + else: + return self.test_global + + def 
test_on_cloud_for_all_clients(self, client_round): + + if ( + client_round % self.args.frequency_of_the_test == 0 + or client_round == self.args.comm_round * self.args.group_comm_round - 1 + ): + + logging.info("################test_on_cloud_for_all_clients : {}".format(client_round)) + + # We may want to test the intermediate results of partial aggregated models, so we play a trick and let + # args.round_idx be total number of partial aggregated times + round_idx = self.args.round_idx + self.args.round_idx = client_round + train_metric_result_in_current_round = self.aggregator.test(self.train_global, self.device, self.args) + test_metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + self.args.round_idx = round_idx + + logging.info("train_metric_result_in_current_round = {}".format(train_metric_result_in_current_round)) + logging.info("test_metric_result_in_current_round = {}".format(test_metric_result_in_current_round)) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py new file mode 100644 index 0000000000..3f7813b17b --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py @@ -0,0 +1,134 @@ +import logging + +from .message_define import MyMessage +from ....core.distributed.fedml_comm_manager import FedMLCommManager +from ....core.distributed.communication.message import Message +from .utils import post_complete_message_to_sweep_process + +class HierFedAVGCloudManager(FedMLCommManager): + def __init__( + self, + args, + aggregator, + group_indexes, + group_to_client_indexes, + comm=None, + rank=0, + size=0, + backend="MPI", + # is_preprocessed=False, + # preprocessed_client_lists=None, + ): + super().__init__(args, comm, rank, size, backend) + self.args = args + self.aggregator = aggregator + self.group_indexes = group_indexes + self.group_to_client_indexes = group_to_client_indexes + self.round_num = args.comm_round + self.args.round_idx = 0 + # self.is_preprocessed = is_preprocessed + # self.preprocessed_client_lists = preprocessed_client_lists + + def run(self): + super().run() + + def send_init_msg(self): + # broadcast to edge servers + global_model_params = self.aggregator.get_global_model_params() + + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + self.args.client_num_in_total, + self.args.client_num_per_round, + ) + + sampled_group_to_client_indexes = {} + for client_idx in sampled_client_indexes: + group_idx = self.group_indexes[client_idx] + if not group_idx in sampled_group_to_client_indexes: + sampled_group_to_client_indexes[group_idx] = [] + sampled_group_to_client_indexes[group_idx].append(client_idx) + logging.info( + "client_indexes of each group = {}".format(sampled_group_to_client_indexes) + ) + + for process_id in range(1, self.size): + self.send_message_init_config( + process_id, + global_model_params, + self.group_to_client_indexes[process_id - 1], + sampled_group_to_client_indexes[process_id - 1], + process_id - 1 + ) + + def register_message_receive_handlers(self): + self.register_message_receive_handler( + MyMessage.MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD, + self.handle_message_receive_model_from_edge, + ) + + def handle_message_receive_model_from_edge(self, msg_params): + sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) + model_params_list = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_LIST) + edge_sample_number = 
msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) + + self.aggregator.add_local_trained_result( + sender_id - 1, model_params_list, edge_sample_number + ) + b_all_received = self.aggregator.check_whether_all_receive() + logging.info("b_all_received = " + str(b_all_received)) + if b_all_received: + + global_model_params = self.aggregator.aggregate() + # start the next round + self.args.round_idx += 1 + if self.args.round_idx == self.round_num: + post_complete_message_to_sweep_process(self.args) + self.finish() + return + + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + self.args.client_num_in_total, + self.args.client_num_per_round, + ) + + sampled_group_to_client_indexes = {} + for client_idx in sampled_client_indexes: + group_idx = self.group_indexes[client_idx] + if not group_idx in sampled_group_to_client_indexes: + sampled_group_to_client_indexes[group_idx] = [] + sampled_group_to_client_indexes[group_idx].append(client_idx) + logging.info( + "client_indexes of each group = {}".format(sampled_group_to_client_indexes) + ) + + for receiver_id in range(1, self.size): + self.send_message_sync_model_to_edge( + receiver_id, global_model_params, + sampled_group_to_client_indexes[receiver_id-1], receiver_id-1 + ) + + def send_message_init_config(self, receive_id, global_model_params, total_client_indexes, sampled_client_indexed, edge_index): + message = Message( + MyMessage.MSG_TYPE_C2E_INIT_CONFIG, self.get_sender_id(), receive_id + ) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS, total_client_indexes) + message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) + self.send_message(message) + + def send_message_sync_model_to_edge( + self, receive_id, global_model_params, sampled_client_indexed, edge_index + ): + logging.info("send_message_sync_model_to_edge. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE, + self.get_sender_id(), + receive_id, + ) + message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) + self.send_message(message) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py new file mode 100644 index 0000000000..f60469ed95 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py @@ -0,0 +1,72 @@ +import logging + +from .message_define import MyMessage +from ....core.distributed.fedml_comm_manager import FedMLCommManager +from ....core.distributed.communication.message import Message +from .utils import post_complete_message_to_sweep_process + + +class HierFedAVGEdgeManager(FedMLCommManager): + def __init__( + self, + group, + args, + comm=None, + rank=0, + size=0, + backend="MPI", + ): + super().__init__(args, comm, rank, size, backend) + self.num_rounds = args.comm_round + self.args.round_idx = 0 + self.group =group + + def run(self): + super().run() + + def register_message_receive_handlers(self): + self.register_message_receive_handler( + MyMessage.MSG_TYPE_C2E_INIT_CONFIG, self.handle_message_init + ) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE, + self.handle_message_receive_model_from_cloud, + ) + + def handle_message_init(self, msg_params): + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + total_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS) + sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) + + self.group.setup_clients(total_client_indexes) + self.args.round_idx = 0 + w_group_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) + edge_sample_num = self.group.get_sample_number(sampled_client_indexes) + + self.send_model_to_cloud(0, w_group_list, edge_sample_num) + + def handle_message_receive_model_from_cloud(self, msg_params): + logging.info("handle_message_receive_model_from_cloud.") + sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) + + self.args.round_idx += 1 + w_group_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) + edge_sample_num = self.group.get_sample_number(sampled_client_indexes) + self.send_model_to_cloud(0, w_group_list, edge_sample_num) + + if self.args.round_idx == self.num_rounds: + post_complete_message_to_sweep_process(self.args) + self.finish() + + def send_model_to_cloud(self, receive_id, w_group_list, edge_sample_num): + message = Message( + MyMessage.MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD, + self.get_sender_id(), + receive_id, + ) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_LIST, w_group_list) + message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, edge_sample_num) + self.send_message(message) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py new file mode 100644 index 0000000000..de06379721 --- /dev/null +++ 
b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py @@ -0,0 +1,73 @@ +import logging + +from .HierClient import HFLClient +from ...sp.fedavg.fedavg_api import FedAvgAPI + + +class HierGroup(FedAvgAPI): + def __init__( + self, + idx, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + args, + device, + model, + model_trainer, + ): + self.idx = idx + self.args = args + self.device = device + self.client_dict = {} + self.train_data_local_num_dict = train_data_local_num_dict + self.train_data_local_dict = train_data_local_dict + self.test_data_local_dict = test_data_local_dict + self.model = model + self.model_trainer = model_trainer + self.args = args + + def setup_clients(self, total_client_indexes): + self.client_dict = {} + for client_idx in total_client_indexes: + self.client_dict[client_idx] = HFLClient( + client_idx, + self.train_data_local_dict[client_idx], + self.test_data_local_dict[client_idx], + self.train_data_local_num_dict[client_idx], + self.args, + self.device, + self.model, + self.model_trainer, + ) + + def get_sample_number(self, sampled_client_indexes): + self.group_sample_number = 0 + for client_idx in sampled_client_indexes: + self.group_sample_number += self.train_data_local_num_dict[client_idx] + return self.group_sample_number + + def train(self, global_round_idx, w, sampled_client_indexes): + sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] + w_group = w + w_group_list = [] + for group_round_idx in range(self.args.group_comm_round): + logging.info("Group ID : {} / Group Communication Round : {}".format(self.idx, group_round_idx)) + w_locals_dict = {} + + # train each client + for client in sampled_client_list: + w_local_list = client.train(global_round_idx, group_round_idx, w_group) + for client_round, w in w_local_list: + if not client_round in w_locals_dict: + w_locals_dict[client_round] = [] + w_locals_dict[client_round].append((client.get_sample_number(), w)) + + # aggregate local weights + for client_round in sorted(w_locals_dict.keys()): + w_locals = w_locals_dict[client_round] + w_group_list.append((client_round, self._aggregate(w_locals))) + + # update the group weight + w_group = w_group_list[-1][1] + return w_group_list diff --git a/python/fedml/simulation/mpi/hierarchical_fl/__init__.py b/python/fedml/simulation/mpi/hierarchical_fl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fedml/simulation/mpi/hierarchical_fl/message_define.py b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py new file mode 100644 index 0000000000..ecb0982b89 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py @@ -0,0 +1,34 @@ +class MyMessage(object): + """ + message type definition + """ + + # cloud to edge + MSG_TYPE_C2E_INIT_CONFIG = 1 + MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE = 2 + + # edge to cloud + MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD = 3 + MSG_TYPE_E2C_SEND_STATS_TO_CLOUD = 4 + + MSG_ARG_KEY_TYPE = "msg_type" + MSG_ARG_KEY_SENDER = "sender" + MSG_ARG_KEY_RECEIVER = "receiver" + + """ + message payload keywords definition + """ + MSG_ARG_KEY_NUM_SAMPLES = "num_samples" + MSG_ARG_KEY_MODEL_PARAMS = "model_params" + MSG_ARG_KEY_MODEL_PARAMS_LIST = "model_params_list" + MSG_ARG_KEY_EDGE_INDEX = "edge_idx" + MSG_ARG_KEY_TOTAL_EDGE_CLIENTS = "total_edge_clients" + MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS = "sampled_edge_clients" + + MSG_ARG_KEY_TRAIN_CORRECT = "train_correct" + MSG_ARG_KEY_TRAIN_ERROR = "train_error" + MSG_ARG_KEY_TRAIN_NUM = 
"train_num_sample" + + MSG_ARG_KEY_TEST_CORRECT = "test_correct" + MSG_ARG_KEY_TEST_ERROR = "test_error" + MSG_ARG_KEY_TEST_NUM = "test_num_sample" diff --git a/python/fedml/simulation/mpi/hierarchical_fl/utils.py b/python/fedml/simulation/mpi/hierarchical_fl/utils.py new file mode 100644 index 0000000000..9143a374d4 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/utils.py @@ -0,0 +1,30 @@ +import os +import time + +import numpy as np +import torch + + +def transform_list_to_tensor(model_params_list): + for k in model_params_list.keys(): + model_params_list[k] = torch.from_numpy( + np.asarray(model_params_list[k]) + ).float() + return model_params_list + + +def transform_tensor_to_list(model_params): + for k in model_params.keys(): + model_params[k] = model_params[k].detach().numpy().tolist() + return model_params + + +def post_complete_message_to_sweep_process(args): + pipe_path = "./tmp/fedml" + if not os.path.exists(pipe_path): + os.mkfifo(pipe_path) + pipe_fd = os.open(pipe_path, os.O_WRONLY) + + with os.fdopen(pipe_fd, "w") as pipe: + pipe.write("training is finished! \n%s\n" % (str(args))) + time.sleep(3) diff --git a/python/fedml/simulation/simulator.py b/python/fedml/simulation/simulator.py index abf0394869..558d0f6aa9 100644 --- a/python/fedml/simulation/simulator.py +++ b/python/fedml/simulation/simulator.py @@ -90,6 +90,7 @@ def __init__( from .mpi.fedavg_seq.FedAvgSeqAPI import FedML_FedAvgSeq_distributed from .mpi.async_fedavg.AsyncFedAvgSeqAPI import FedML_Async_distributed from .mpi.fednova.FedNovaAPI import FedML_FedNova_distributed + from .mpi.hierarchical_fl.HierFedAvgAPI import FedML_HierFedAvg_distributed if args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_FEDAVG: FedML_FedAvg_distributed( @@ -159,6 +160,18 @@ def __init__( SplitNN_distributed( args.process_id, args.worker_num, device, args.comm, model, dataset=dataset, args=args, ) + elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_HIERACHICAL_FL: + FedML_HierFedAvg_distributed( + args, + args.process_id, + args.worker_num, + args.comm, + device, + dataset, + model, + client_trainer=client_trainer, + server_aggregator=server_aggregator + ) elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_DECENTRALIZED_FL: FedML_Decentralized_Demo_distributed(args, args.process_id, args.worker_num, args.comm) elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_FEDGAN: From 36d74a6328fda4828b810cb136afbd4977ab8056 Mon Sep 17 00:00:00 2001 From: liuliuliu0605 Date: Wed, 27 Sep 2023 19:53:09 +0800 Subject: [PATCH 2/3] add mixing operation between edge servers according to some topology --- .../mpi_torch_hierarchical_fl/README.md | 4 +- .../mpi_torch_hierarchical_fl/batch_lauch.sh | 10 ++ .../config/mnist_lr/fedml_config.yaml | 8 +- .../config/mnist_lr/fedml_config_topo.yaml | 51 +++++++++ .../topology/symmetric_topology_manager.py | 27 ++++- .../core/distributed/topology/topo_utils.py | 94 ++++++++++++++++ .../mpi/hierarchical_fl/HierClient.py | 12 +-- .../mpi/hierarchical_fl/HierFedAvgAPI.py | 45 ++++++-- .../HierFedAvgCloudAggregator.py | 100 +++++++++++++++--- .../hierarchical_fl/HierFedAvgCloudManager.py | 81 +++++++++----- .../hierarchical_fl/HierFedAvgEdgeManager.py | 10 +- .../mpi/hierarchical_fl/HierGroup.py | 23 ++-- .../simulation/mpi/hierarchical_fl/utils.py | 99 +++++++++++++++++ 13 files changed, 479 insertions(+), 85 deletions(-) create mode 100755 python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh create mode 100644 
python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml create mode 100644 python/fedml/core/distributed/topology/topo_utils.py diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md index 3010fd24ad..bd5da27569 100644 --- a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md @@ -6,6 +6,8 @@ pip install fedml # Run the example (step by step APIs) ``` -sh run_step_by_step_example.sh 4 config/mnist_lr/fedml_config.yaml +sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config.yaml + +sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config_topo.yaml ``` diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh new file mode 100755 index 0000000000..8dd565e6e8 --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +WORKER_NUM=$1 +CONFIG_PATH=$2 + +hostname > mpi_host_file + +mpirun -np $WORKER_NUM \ +-hostfile mpi_host_file \ +python torch_step_by_step_example.py --cf $CONFIG_PATH \ No newline at end of file diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml index c5a244d1bd..6543f7e033 100644 --- a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml @@ -15,7 +15,7 @@ train_args: federated_optimizer: "HierarchicalFL" client_id_list: "[]" client_num_in_total: 1000 - client_num_per_round: 10 + client_num_per_round: 20 comm_round: 20 epochs: 1 batch_size: 10 @@ -23,17 +23,17 @@ train_args: learning_rate: 0.03 weight_decay: 0.001 group_method: "random" - group_num: 2 + group_num: 4 group_comm_round: 5 validation_args: frequency_of_the_test: 5 device_args: - worker_num: 3 + worker_num: 5 using_gpu: true gpu_mapping_file: config/mnist_lr/gpu_mapping.yaml - gpu_mapping_key: mapping_config1_3 + gpu_mapping_key: mapping_config1_5 comm_args: backend: "MPI" diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml new file mode 100644 index 0000000000..d9880072ed --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml @@ -0,0 +1,51 @@ +common_args: + training_type: "simulation" + random_seed: 0 + +data_args: + dataset: "mnist" + data_cache_dir: ~/fedml_data + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "lr" + +train_args: + federated_optimizer: "HierarchicalFL" + client_id_list: "[]" + client_num_in_total: 1000 + client_num_per_round: 20 + comm_round: 20 + epochs: 1 + batch_size: 10 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + group_method: "random" + group_num: 4 + group_comm_round: 5 + topo_name: "complete" + topo_edge_probability: 0.5 + +validation_args: + frequency_of_the_test: 5 + +device_args: + worker_num: 5 + using_gpu: true + gpu_mapping_file: config/mnist_lr/gpu_mapping.yaml + gpu_mapping_key: mapping_config1_5 + +comm_args: + backend: "MPI" + is_mobile: 0 + + +tracking_args: + # When running on MLOps 
platform(open.fedml.ai), the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/ + enable_wandb: true + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + run_name: mpi_hierarchical_fl_mnist_lr + wandb_only_server: true \ No newline at end of file diff --git a/python/fedml/core/distributed/topology/symmetric_topology_manager.py b/python/fedml/core/distributed/topology/symmetric_topology_manager.py index 07d90525e4..4ce736f2f7 100644 --- a/python/fedml/core/distributed/topology/symmetric_topology_manager.py +++ b/python/fedml/core/distributed/topology/symmetric_topology_manager.py @@ -2,6 +2,7 @@ import numpy as np from .base_topology_manager import BaseTopologyManager +from .topo_utils import * class SymmetricTopologyManager(BaseTopologyManager): @@ -18,6 +19,27 @@ def __init__(self, n, neighbor_num=2): self.neighbor_num = neighbor_num self.topology = [] + def generate_custom_topology(self, args): + topo_name = args.topo_name + if topo_name == 'ring': + self.neighbor_num = 2 + self.generate_topology() + elif topo_name == '2d_torus': + self.topology = get_2d_torus_overlay(self.n) + elif topo_name == 'star': + self.topology = get_star_overlay(self.n) + elif topo_name == 'complete': + self.topology = get_complete_overlay(self.n) + elif topo_name == 'isolated': + self.topology = get_isolated_overlay(self.n) + elif topo_name == 'balanced_tree': + self.topology = get_balanced_tree_overlay(self.n, self.neighbor_num) + elif topo_name == 'random': + probability = args.topo_edge_probability # Probability for edge creation + self.topology = get_random_overlay(self.n, probability) + else: + raise Exception(topo_name) + def generate_topology(self): # first generate a ring topology topology_ring = np.array( @@ -84,8 +106,9 @@ def get_out_neighbor_idx_list(self, node_index): if __name__ == "__main__": # generate a ring topology - tpmgr = SymmetricTopologyManager(6, 2) - tpmgr.generate_topology() + tpmgr = SymmetricTopologyManager(9, 2, 0.3) + # tpmgr.generate_topology() + tpmgr.generate_custom_topology('random') print("tpmgr.topology = " + str(tpmgr.topology)) # get the OUT neighbor weights for node 1 diff --git a/python/fedml/core/distributed/topology/topo_utils.py b/python/fedml/core/distributed/topology/topo_utils.py new file mode 100644 index 0000000000..bba541dbe2 --- /dev/null +++ b/python/fedml/core/distributed/topology/topo_utils.py @@ -0,0 +1,94 @@ +import math +import numpy as np +import networkx as nx + + +def get_2d_torus_overlay(node_num): + side_len = node_num ** 0.5 + assert math.ceil(side_len) == math.floor(side_len) + side_len = int(side_len) + + torus = np.zeros((node_num, node_num), dtype=np.float32) + + for i in range(side_len): + for j in range(side_len): + idx = i * side_len + j + torus[i, i] = 1 / 5 + torus[idx, (((i + 1) % side_len) * side_len + j)] = 1 / 5 + torus[idx, (((i - 1) % side_len) * side_len + j)] = 1 / 5 + torus[idx, (i * side_len + (j + 1) % side_len)] = 1 / 5 + torus[idx, (i * side_len + (j - 1) % side_len)] = 1 / 5 + + return torus + + +def get_star_overlay(node_num): + + star = np.zeros((node_num, node_num), dtype=np.float32) + for i in range(node_num): + if i == 0: + star[i, i] = 1 / node_num + else: + star[0, i] = star[i, 0] = 1 / node_num + star[i, i] = 1 - 1 / node_num + + return star + + +def get_complete_overlay(node_num): + + complete = np.ones((node_num, node_num), dtype=np.float32) + complete /= node_num + + return complete + + +def get_isolated_overlay(node_num): + + isolated = np.zeros((node_num, 
node_num), dtype=np.float32) + + for i in range(node_num): + isolated[i, i] = 1 + + return isolated + + +def get_balanced_tree_overlay(node_num, degree=2): + + tree = np.zeros((node_num, node_num), dtype=np.float32) + + for i in range(node_num): + for j in range(1, degree+1): + k = i * 2 + j + if k >= node_num: + break + tree[i, k] = 1 / (degree+1) + + for i in range(node_num): + tree[i, i] = 1 - tree[i, :].sum() + + return tree + + +def get_barbell_overlay(node_num, m1=1, m2=0): + + barbell = None + + return barbell + + +def get_random_overlay(node_num, probability=0.5): + + random = np.array( + nx.to_numpy_matrix(nx.fast_gnp_random_graph(node_num, probability)), dtype=np.float32 + ) + + matrix_sum = random.sum(1) + + for i in range(node_num): + for j in range(node_num): + if i != j and random[i, j] > 0: + random[i, j] = 1 / (1 + max(matrix_sum[i], matrix_sum[j])) + random[i, i] = 1 - random[i].sum() + + return random \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py index 8e4cf98c03..5c32a08f95 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py @@ -23,7 +23,7 @@ def __init__(self, client_idx, local_training_data, local_test_data, local_sampl self.model_trainer = model_trainer self.criterion = nn.CrossEntropyLoss().to(device) - def train(self, global_round_idx, group_round_idx, w): + def train(self, w): self.model.load_state_dict(w) self.model.to(self.device) @@ -37,7 +37,6 @@ def train(self, global_round_idx, group_round_idx, w): amsgrad=True, ) - w_list = [] for epoch in range(self.args.epochs): for x, labels in self.local_training_data: x, labels = x.to(self.device), labels.to(self.device) @@ -46,10 +45,5 @@ def train(self, global_round_idx, group_round_idx, w): loss = self.criterion(log_probs, labels) # pylint: disable=E1102 loss.backward() optimizer.step() - client_round = ( - global_round_idx * self.args.group_comm_round - + group_round_idx - ) - # if client_round % self.args.frequency_of_the_test == 0: - w_list.append((client_round, copy.deepcopy(self.model.cpu().state_dict()))) - return w_list + + return copy.deepcopy(self.model.cpu().state_dict()) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py index 662c1cc961..0ce4f92e3f 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py @@ -2,15 +2,18 @@ from .HierFedAvgCloudManager import HierFedAVGCloudManager from .HierFedAvgEdgeManager import HierFedAVGEdgeManager from .HierGroup import HierGroup +from .utils import analyze_clients_type, hetero_partition_groups, visualize_group_detail from ....core import ClientTrainer, ServerAggregator from ....core.dp.fedml_differential_privacy import FedMLDifferentialPrivacy from ....core.security.fedml_attacker import FedMLAttacker from ....core.security.fedml_defender import FedMLDefender from ....ml.aggregator.aggregator_creator import create_server_aggregator from ....ml.trainer.trainer_creator import create_model_trainer +from ....core.distributed.topology.symmetric_topology_manager import SymmetricTopologyManager -import numpy as np +import numpy as np +import wandb def FedML_HierFedAvg_distributed( args, @@ -52,7 +55,8 @@ def FedML_HierFedAvg_distributed( train_data_local_dict, test_data_local_dict, train_data_local_num_dict, - 
server_aggregator + server_aggregator, + class_num ) else: init_edge_server_clients( @@ -84,14 +88,22 @@ def init_cloud_server( train_data_local_dict, test_data_local_dict, train_data_local_num_dict, - server_aggregator + server_aggregator, + class_num ): if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) - # aggregator worker_num = size - 1 + + # set up topology + topology_manager = None + if hasattr(args, "topo_name"): + topology_manager = SymmetricTopologyManager(worker_num, args) + topology_manager.generate_custom_topology(args) + + # aggregator aggregator = HierFedAVGCloudAggregator( train_data_global, test_data_global, @@ -102,13 +114,20 @@ def init_cloud_server( worker_num, device, args, - server_aggregator, + server_aggregator ) # start the distributed training backend = args.backend - group_indexes, group_to_client_indexes = setup_clients(args) - server_manager = HierFedAVGCloudManager(args, aggregator, group_indexes, group_to_client_indexes, comm, rank, size, backend) + group_indexes, group_to_client_indexes = setup_clients(args, train_data_local_dict, class_num) + + # visualize group detail + if args.enable_wandb: + visualize_group_detail(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num) + + + server_manager = HierFedAVGCloudManager(args, aggregator, group_indexes, group_to_client_indexes, + comm, rank, size, backend, topology_manager) server_manager.send_init_msg() server_manager.run() @@ -151,8 +170,11 @@ def init_edge_server_clients( def setup_clients( - args + args, + train_data_local_dict, + class_num ): + if args.group_method == "random": group_indexes = np.random.randint( 0, args.group_num, args.client_num_in_total @@ -162,7 +184,10 @@ def setup_clients( if not group_idx in group_to_client_indexes: group_to_client_indexes[group_idx] = [] group_to_client_indexes[group_idx].append(client_idx) - else: - raise Exception(args.group_method) + elif args.group_method == "hetero": + clients_type_list = analyze_clients_type(train_data_local_dict, class_num, num_type=args.group_num) + group_indexes, group_to_client_indexes = hetero_partition_groups(clients_type_list, + args.group_num, + alpha=args.group_alpha) return group_indexes, group_to_client_indexes \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py index 59ef1af712..7cf441e0f5 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py @@ -1,3 +1,4 @@ +import copy import logging import random import time @@ -6,6 +7,7 @@ from ....core.security.fedml_attacker import FedMLAttacker from ....core.security.fedml_defender import FedMLDefender +from .utils import cal_mixing_consensus_speed class HierFedAVGCloudAggregator(object): @@ -66,17 +68,19 @@ def aggregate(self): start_time = time.time() # Edge server may conduct partial aggregation multiple times, so cloud server will receive a model list - for group_round_idx in range(self.args.group_comm_round): + group_comm_round = len(self.sample_num_dict[0]) + + for group_round_idx in range(group_comm_round): model_list = [] + global_round_idx = self.model_dict[0][group_round_idx][0] for idx in range(0, self.worker_num): - model_list.append((self.sample_num_dict[idx], + model_list.append((self.sample_num_dict[idx][group_round_idx], 
self.model_dict[idx][group_round_idx][1])) - client_round = self.model_dict[0][group_round_idx][0] averaged_params = self._fedavg_aggregation_(model_list) self.set_global_model_params(averaged_params) - self.test_on_cloud_for_all_clients(client_round) + self.test_on_cloud_for_all_clients(global_round_idx) if FedMLAttacker.get_instance().is_model_attack(): model_list = FedMLAttacker.get_instance().attack_model(raw_client_grad_list=model_list, extra_auxiliary_info=None) @@ -98,6 +102,41 @@ def aggregate(self): logging.info("aggregate time cost: %d" % (end_time - start_time)) return averaged_params + def mix(self, topology_manager): + start_time = time.time() + + # Edge server may conduct partial aggregation multiple times, so cloud server will receive a model list + group_comm_round = len(self.sample_num_dict[0]) + edge_model_list = [None for _ in range(self.worker_num)] + + p = cal_mixing_consensus_speed(topology_manager.topology, self.model_dict[0][0][0]) + + for group_round_idx in range(group_comm_round): + model_list = [] + global_round_idx = self.model_dict[0][group_round_idx][0] + + for idx in range(self.worker_num): + model_list.append((self.sample_num_dict[idx][group_round_idx], + self.model_dict[idx][group_round_idx][1])) + + # mixing between neighbors + for idx in range(self.worker_num): + edge_model_list[idx] = (self.sample_num_dict[idx][group_round_idx], + self._pfedavg_mixing_(model_list, + topology_manager.get_in_neighbor_weights(idx)) + ) + # average for testing + averaged_params = self._fedavg_aggregation_(edge_model_list) + self.set_global_model_params(averaged_params) + self.test_on_cloud_for_all_clients(global_round_idx) + + # update the global model which is cached in the cloud + self.set_global_model_params(averaged_params) + + end_time = time.time() + logging.info("mix time cost: %d" % (end_time - start_time)) + return [edge_model for _, edge_model in edge_model_list] + def _fedavg_aggregation_(self, model_list): training_num = 0 for i in range(0, len(model_list)): @@ -117,6 +156,29 @@ def _fedavg_aggregation_(self, model_list): ) return averaged_params + def _pfedavg_mixing_(self, model_list, neighbor_topo_weight_list): + training_num = 0 + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + training_num += local_sample_number + + (num0, averaged_params) = model_list[0] + averaged_params = copy.deepcopy(averaged_params) + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + topo_weight = neighbor_topo_weight_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * topo_weight + ) + else: + averaged_params[k] += ( + local_model_params[k] * topo_weight + ) + + return averaged_params + def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): if client_num_in_total == client_num_per_round: client_indexes = [ @@ -147,22 +209,34 @@ def _generate_validation_set(self, num_samples=10000): else: return self.test_global - def test_on_cloud_for_all_clients(self, client_round): + def test_on_cloud_for_all_clients(self, global_round_idx): + if self.aggregator.test_all( + self.train_data_local_dict, + self.test_data_local_dict, + self.device, + self.args, + ): + return if ( - client_round % self.args.frequency_of_the_test == 0 - or client_round == self.args.comm_round * self.args.group_comm_round - 1 + global_round_idx % self.args.frequency_of_the_test == 0 + or global_round_idx == self.args.comm_round * 
self.args.group_comm_round - 1 ): - logging.info("################test_on_cloud_for_all_clients : {}".format(client_round)) + logging.info("################test_on_cloud_for_all_clients : {}".format(global_round_idx)) # We may want to test the intermediate results of partial aggregated models, so we play a trick and let # args.round_idx be total number of partial aggregated times + round_idx = self.args.round_idx - self.args.round_idx = client_round - train_metric_result_in_current_round = self.aggregator.test(self.train_global, self.device, self.args) - test_metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + self.args.round_idx = global_round_idx + + if global_round_idx == self.args.comm_round - 1: + # we allow to return four metrics, such as accuracy, AUC, loss, etc. + metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + else: + metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) + self.args.round_idx = round_idx - logging.info("train_metric_result_in_current_round = {}".format(train_metric_result_in_current_round)) - logging.info("test_metric_result_in_current_round = {}".format(test_metric_result_in_current_round)) \ No newline at end of file + logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py index 3f7813b17b..e40afad048 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py @@ -16,6 +16,7 @@ def __init__( rank=0, size=0, backend="MPI", + topology_manager=None # is_preprocessed=False, # preprocessed_client_lists=None, ): @@ -26,6 +27,19 @@ def __init__( self.group_to_client_indexes = group_to_client_indexes self.round_num = args.comm_round self.args.round_idx = 0 + self.topology_manager = topology_manager + + total_clients = len(self.group_indexes) + self.group_to_client_num_per_round = [ + args.client_num_per_round * len(self.group_to_client_indexes[i]) // total_clients + for i in range(args.group_num) + ] + + remain_client_num_list_per_round = args.client_num_per_round - sum(self.group_to_client_num_per_round) + while remain_client_num_list_per_round > 0: + self.group_to_client_num_per_round[remain_client_num_list_per_round-1] += 1 + remain_client_num_list_per_round -= 1 + # self.is_preprocessed = is_preprocessed # self.preprocessed_client_lists = preprocessed_client_lists @@ -36,18 +50,18 @@ def send_init_msg(self): # broadcast to edge servers global_model_params = self.aggregator.get_global_model_params() - sampled_client_indexes = self.aggregator.client_sampling( - self.args.round_idx, - self.args.client_num_in_total, - self.args.client_num_per_round, - ) - sampled_group_to_client_indexes = {} - for client_idx in sampled_client_indexes: - group_idx = self.group_indexes[client_idx] - if not group_idx in sampled_group_to_client_indexes: - sampled_group_to_client_indexes[group_idx] = [] - sampled_group_to_client_indexes[group_idx].append(client_idx) + for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): + client_num_in_total = len(self.group_to_client_indexes[group_idx]) + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + client_num_in_total, + 
client_num_per_round, + ) + + sampled_group_to_client_indexes[group_idx] = [self.group_to_client_indexes[group_idx][index] + for index in sampled_client_indexes] + logging.info( "client_indexes of each group = {}".format(sampled_group_to_client_indexes) ) @@ -70,16 +84,20 @@ def register_message_receive_handlers(self): def handle_message_receive_model_from_edge(self, msg_params): sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params_list = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_LIST) - edge_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) + sample_num_list = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - sender_id - 1, model_params_list, edge_sample_number + sender_id - 1, model_params_list, sample_num_list ) b_all_received = self.aggregator.check_whether_all_receive() logging.info("b_all_received = " + str(b_all_received)) if b_all_received: + # If topology_manage is None, it is simple average. Otherwise, it is mixing between neighbours. + if self.topology_manager is None: + global_model_params = self.aggregator.aggregate() + else: + global_model_params_list = self.aggregator.mix(self.topology_manager) - global_model_params = self.aggregator.aggregate() # start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: @@ -87,27 +105,32 @@ def handle_message_receive_model_from_edge(self, msg_params): self.finish() return - sampled_client_indexes = self.aggregator.client_sampling( - self.args.round_idx, - self.args.client_num_in_total, - self.args.client_num_per_round, - ) - sampled_group_to_client_indexes = {} - for client_idx in sampled_client_indexes: - group_idx = self.group_indexes[client_idx] - if not group_idx in sampled_group_to_client_indexes: - sampled_group_to_client_indexes[group_idx] = [] - sampled_group_to_client_indexes[group_idx].append(client_idx) + for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): + client_num_in_total = len(self.group_to_client_indexes[group_idx]) + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + client_num_in_total, + client_num_per_round, + ) + + sampled_group_to_client_indexes[group_idx] = [self.group_to_client_indexes[group_idx][index] + for index in sampled_client_indexes] logging.info( "client_indexes of each group = {}".format(sampled_group_to_client_indexes) ) for receiver_id in range(1, self.size): - self.send_message_sync_model_to_edge( - receiver_id, global_model_params, - sampled_group_to_client_indexes[receiver_id-1], receiver_id-1 - ) + if self.topology_manager is None: + self.send_message_sync_model_to_edge( + receiver_id, global_model_params, + sampled_group_to_client_indexes[receiver_id-1], receiver_id-1 + ) + else: + self.send_message_sync_model_to_edge( + receiver_id, global_model_params_list[receiver_id - 1], + sampled_group_to_client_indexes[receiver_id - 1], receiver_id - 1 + ) def send_message_init_config(self, receive_id, global_model_params, total_client_indexes, sampled_client_indexed, edge_index): message = Message( diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py index f60469ed95..3a46dbedb8 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py @@ -41,10 +41,9 @@ def handle_message_init(self, msg_params): 
self.group.setup_clients(total_client_indexes) self.args.round_idx = 0 - w_group_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) - edge_sample_num = self.group.get_sample_number(sampled_client_indexes) + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) - self.send_model_to_cloud(0, w_group_list, edge_sample_num) + self.send_model_to_cloud(0, w_group_list, sample_num_list) def handle_message_receive_model_from_cloud(self, msg_params): logging.info("handle_message_receive_model_from_cloud.") @@ -53,9 +52,8 @@ def handle_message_receive_model_from_cloud(self, msg_params): edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) self.args.round_idx += 1 - w_group_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) - edge_sample_num = self.group.get_sample_number(sampled_client_indexes) - self.send_model_to_cloud(0, w_group_list, edge_sample_num) + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) + self.send_model_to_cloud(0, w_group_list, sample_num_list) if self.args.round_idx == self.num_rounds: post_complete_message_to_sweep_process(self.args) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py index de06379721..9df21e6bb4 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py @@ -47,27 +47,28 @@ def get_sample_number(self, sampled_client_indexes): self.group_sample_number += self.train_data_local_num_dict[client_idx] return self.group_sample_number - def train(self, global_round_idx, w, sampled_client_indexes): + def train(self, round_idx, w, sampled_client_indexes): sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] w_group = w w_group_list = [] + sample_num_list = [] for group_round_idx in range(self.args.group_comm_round): logging.info("Group ID : {} / Group Communication Round : {}".format(self.idx, group_round_idx)) - w_locals_dict = {} + w_locals = [] + global_round_idx = ( + round_idx * self.args.group_comm_round + + group_round_idx + ) # train each client for client in sampled_client_list: - w_local_list = client.train(global_round_idx, group_round_idx, w_group) - for client_round, w in w_local_list: - if not client_round in w_locals_dict: - w_locals_dict[client_round] = [] - w_locals_dict[client_round].append((client.get_sample_number(), w)) + w_local = client.train(w_group) + w_locals.append((client.get_sample_number(), w_local)) # aggregate local weights - for client_round in sorted(w_locals_dict.keys()): - w_locals = w_locals_dict[client_round] - w_group_list.append((client_round, self._aggregate(w_locals))) + w_group_list.append((global_round_idx, self._aggregate(w_locals))) + sample_num_list.append(self.get_sample_number(sampled_client_indexes)) # update the group weight w_group = w_group_list[-1][1] - return w_group_list + return w_group_list, sample_num_list diff --git a/python/fedml/simulation/mpi/hierarchical_fl/utils.py b/python/fedml/simulation/mpi/hierarchical_fl/utils.py index 9143a374d4..a68ae6ccac 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/utils.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/utils.py @@ -3,6 +3,104 @@ import numpy as np import torch +import wandb + +from sklearn.cluster import KMeans + + +def 
cal_mixing_consensus_speed(topo_weight_matrix, global_round_idx): + n_rows, n_cols = np.shape(topo_weight_matrix) + assert n_rows == n_cols + A = np.array(topo_weight_matrix) - 1 / n_rows + p = 1 - np.linalg.norm(A, ord=2) ** 2 + wandb.log({"Groups/p": p, "comm_round": global_round_idx}) + return p + + +def visualize_group_detail(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num): + + xs = [i for i in range(class_num)] + ys = [] + keys = [] + for group_idx in range(len(group_to_client_indexes)): + data_size = 0 + group_y_train = [] + for client_id in group_to_client_indexes[group_idx]: + data_size += train_data_local_num_dict[client_id] + y_train = torch.concat([y for _, y in train_data_local_dict[client_id]]).tolist() + group_y_train.extend(y_train) + + labels, counts = np.unique(group_y_train, return_counts=True) + + count_vector = np.zeros(class_num) + count_vector[labels] = counts + ys.append(count_vector/count_vector.sum()) + keys.append("Group {}".format(group_idx)) + + wandb.log({"Groups/Client_num": len(group_to_client_indexes[group_idx]), "group_id": group_idx}) + wandb.log({"Groups/Data_size": data_size, "group_id": group_idx}) + + wandb.log({"Groups/Data_distribution": + wandb.plot.line_series(xs=xs, ys=ys, keys=keys, title="Data distribution", xname="Label")} + ) + + +def hetero_partition_groups(clients_type_list, num_groups, alpha=0.5): + min_size = 0 + num_type = np.unique(clients_type_list).size + N = len(clients_type_list) + group_to_client_indexes = {} + while min_size < 10: + idx_batch = [[] for _ in range(num_groups)] + # for each type in clients + for k in range(num_type): + idx_k = np.where(np.array(clients_type_list) == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, num_groups)) + ## Balance + proportions = np.array([p * (len(idx_j) < N / num_groups) for p, idx_j in zip(proportions, idx_batch)]) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + group_indexes = [0 for _ in range(N)] + for j in range(num_groups): + np.random.shuffle(idx_batch[j]) + group_to_client_indexes[j] = idx_batch[j] + for client_id in group_to_client_indexes[j]: + group_indexes[client_id] = j + + return group_indexes, group_to_client_indexes + + +def analyze_clients_type(train_data_local_dict, class_num, num_type=5): + client_feature_list = [] + for i in range(len(train_data_local_dict)): + y_train = torch.concat([y for _, y in train_data_local_dict[i]]) + labels, counts = torch.unique(y_train, return_counts=True) + data_feature = np.zeros(class_num) + total = 0 + for label, count in zip(labels, counts): + data_feature[label.item()] = count.item() + total += count.item() + data_feature /= total + client_feature_list.append(data_feature) + + kmeans = KMeans(n_clusters=num_type, random_state=0, n_init="auto").fit(client_feature_list) + + + + # for k in range(num_type): + # tmp = [] + # for i, j in enumerate(kmeans.labels_): + # if j == k: + # indexes = np.where(np.array(client_feature_list[i]) > 0) + # tmp.extend(indexes[0].tolist()) + # print(np.unique(tmp)) + # + # exit(0) + return kmeans.labels_ def transform_list_to_tensor(model_params_list): @@ -21,6 +119,7 @@ def transform_tensor_to_list(model_params): def post_complete_message_to_sweep_process(args): pipe_path = "./tmp/fedml" + os.system("mkdir 
-p ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): os.mkfifo(pipe_path) pipe_fd = os.open(pipe_path, os.O_WRONLY) From aad1f47c6f0cfe4f4507b80f5ae5bbc85a3a2986 Mon Sep 17 00:00:00 2001 From: liuliuliu0605 Date: Sat, 30 Sep 2023 19:04:16 +0800 Subject: [PATCH 3/3] redefine local loss function to support weighted mixing --- .../mpi_torch_hierarchical_fl/README.md | 9 +++- .../mpi_torch_hierarchical_fl/batch_lauch.sh | 10 ---- .../mpi_torch_hierarchical_fl/batch_run.sh | 47 +++++++++++++++++++ .../config/mnist_lr/fedml_config_topo.yaml | 5 +- .../mpi/hierarchical_fl/HierClient.py | 9 ++-- .../mpi/hierarchical_fl/HierFedAvgAPI.py | 8 ++-- .../HierFedAvgCloudAggregator.py | 19 +++++++- .../hierarchical_fl/HierFedAvgCloudManager.py | 45 +++++++++++------- .../hierarchical_fl/HierFedAvgEdgeManager.py | 8 +++- .../mpi/hierarchical_fl/HierGroup.py | 11 ++++- .../mpi/hierarchical_fl/message_define.py | 1 + .../simulation/mpi/hierarchical_fl/utils.py | 26 ++++++---- 12 files changed, 145 insertions(+), 53 deletions(-) delete mode 100755 python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh create mode 100755 python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md index bd5da27569..6138709194 100644 --- a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md @@ -4,10 +4,15 @@ pip install fedml ``` -# Run the example (step by step APIs) +# Run the example + +## mpi hierarchical fl ``` sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config.yaml +``` -sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config_topo.yaml +## mpi hierarchical fl based on some topology (e.g., 2d_torus, star, complete, isolated, balanced_tree and random) +``` +sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config.yaml ``` diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh deleted file mode 100755 index 8dd565e6e8..0000000000 --- a/python/examples/simulation/mpi_torch_hierarchical_fl/batch_lauch.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash - -WORKER_NUM=$1 -CONFIG_PATH=$2 - -hostname > mpi_host_file - -mpirun -np $WORKER_NUM \ --hostfile mpi_host_file \ -python torch_step_by_step_example.py --cf $CONFIG_PATH \ No newline at end of file diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh new file mode 100755 index 0000000000..fcbe10bde3 --- /dev/null +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +GROUP_NUM=5 +GROUP_METHOD="hetero" +COMM_ROUND=62 #250 +GROUP_COMM_ROUND=4 # 1 +TOPO_NAME="star" +CONFIG_PATH=config/mnist_lr/fedml_config_topo.yaml + +group_alpha_list=(0.01 0.1 1.0) + +WORKER_NUM=$(($GROUP_NUM+1)) +hostname > mpi_host_file +mkdir -p batch_log +# we need to install yq (https://github.com/mikefarah/yq) +# wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq && chmod +x /usr/bin/yq + +yq -i ".device_args.worker_num = ${WORKER_NUM}" $CONFIG_PATH +yq -i ".device_args.gpu_mapping_key = \"mapping_config1_${WORKER_NUM}\"" $CONFIG_PATH +yq -i ".train_args.group_num = ${GROUP_NUM}" $CONFIG_PATH +yq -i ".train_args.comm_round = ${COMM_ROUND}" 
$CONFIG_PATH +yq -i ".train_args.group_comm_round = ${GROUP_COMM_ROUND}" $CONFIG_PATH +yq -i ".train_args.group_method = \"${GROUP_METHOD}\"" $CONFIG_PATH +yq -i ".train_args.topo_name = \"${TOPO_NAME}\"" $CONFIG_PATH + +if [ "${GROUP_METHOD}" = "random" ]; then + yq -i ".train_args.group_alpha = 0" $CONFIG_PATH +fi + +if [ "${TOPO_NAME}" != "random" ]; then + yq -i ".train_args.topo_edge_probability = 1.0" $CONFIG_PATH +fi + + +for group_alpha in ${group_alpha_list[@]}; +do + echo "group_alpha=$group_alpha" + yq -i ".train_args.group_alpha = ${group_alpha}" $CONFIG_PATH + + nohup mpirun -np $WORKER_NUM \ + -hostfile mpi_host_file \ + python torch_step_by_step_example.py --cf $CONFIG_PATH \ + > batch_log/"group_alpha=$group_alpha.log" 2>&1 & echo $! >> batch_log/group_alpha.pid + sleep 30 +done + +echo "Finished!" \ No newline at end of file diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml index d9880072ed..43b680d37d 100644 --- a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml +++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml @@ -22,10 +22,11 @@ train_args: client_optimizer: sgd learning_rate: 0.03 weight_decay: 0.001 - group_method: "random" + group_method: "hetero" + group_alpha: 0.5 group_num: 4 group_comm_round: 5 - topo_name: "complete" + topo_name: "ring" topo_edge_probability: 0.5 validation_args: diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py index 5c32a08f95..62c25ba683 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py @@ -23,17 +23,18 @@ def __init__(self, client_idx, local_training_data, local_test_data, local_sampl self.model_trainer = model_trainer self.criterion = nn.CrossEntropyLoss().to(device) - def train(self, w): + def train(self, w, scaled_loss_factor=1.0): self.model.load_state_dict(w) self.model.to(self.device) + scaled_loss_factor = min(scaled_loss_factor, 1.0) if self.args.client_optimizer == "sgd": - optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate) + optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate * scaled_loss_factor) else: optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.model.parameters()), - lr=self.args.learning_rate, - weight_decay=self.args.wd, + lr=self.args.learning_rate * scaled_loss_factor, + weight_decay=self.args.weight_decay, amsgrad=True, ) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py index 0ce4f92e3f..d3095d4607 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py @@ -2,7 +2,7 @@ from .HierFedAvgCloudManager import HierFedAVGCloudManager from .HierFedAvgEdgeManager import HierFedAVGEdgeManager from .HierGroup import HierGroup -from .utils import analyze_clients_type, hetero_partition_groups, visualize_group_detail +from .utils import analyze_clients_type, hetero_partition_groups, stats_group from ....core import ClientTrainer, ServerAggregator from ....core.dp.fedml_differential_privacy import FedMLDifferentialPrivacy from ....core.security.fedml_attacker import FedMLAttacker @@ -121,10 
+121,8 @@ def init_cloud_server( backend = args.backend group_indexes, group_to_client_indexes = setup_clients(args, train_data_local_dict, class_num) - # visualize group detail - if args.enable_wandb: - visualize_group_detail(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num) - + # print group detail + stats_group(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num, args) server_manager = HierFedAVGCloudManager(args, aggregator, group_indexes, group_to_client_indexes, comm, rank, size, backend, topology_manager) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py index 7cf441e0f5..ee23e14bff 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py @@ -109,7 +109,7 @@ def mix(self, topology_manager): group_comm_round = len(self.sample_num_dict[0]) edge_model_list = [None for _ in range(self.worker_num)] - p = cal_mixing_consensus_speed(topology_manager.topology, self.model_dict[0][0][0]) + p = cal_mixing_consensus_speed(topology_manager.topology, self.model_dict[0][0][0], self.args) for group_round_idx in range(group_comm_round): model_list = [] @@ -126,7 +126,7 @@ def mix(self, topology_manager): topology_manager.get_in_neighbor_weights(idx)) ) # average for testing - averaged_params = self._fedavg_aggregation_(edge_model_list) + averaged_params = self._pfedavg_aggregation_(edge_model_list) self.set_global_model_params(averaged_params) self.test_on_cloud_for_all_clients(global_round_idx) @@ -156,6 +156,21 @@ def _fedavg_aggregation_(self, model_list): ) return averaged_params + def _pfedavg_aggregation_(self, model_list): + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + _, local_model_params = model_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * 1 / len(model_list) + ) + else: + averaged_params[k] += ( + local_model_params[k] * 1 / len(model_list) + ) + return averaged_params + def _pfedavg_mixing_(self, model_list, neighbor_topo_weight_list): training_num = 0 for i in range(0, len(model_list)): diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py index e40afad048..6111f44c4e 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py @@ -51,6 +51,7 @@ def send_init_msg(self): global_model_params = self.aggregator.get_global_model_params() sampled_group_to_client_indexes = {} + total_sampled_data_size = 0 for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): client_num_in_total = len(self.group_to_client_indexes[group_idx]) sampled_client_indexes = self.aggregator.client_sampling( @@ -58,20 +59,24 @@ def send_init_msg(self): client_num_in_total, client_num_per_round, ) - - sampled_group_to_client_indexes[group_idx] = [self.group_to_client_indexes[group_idx][index] - for index in sampled_client_indexes] + sampled_group_to_client_indexes[group_idx] = [] + for index in sampled_client_indexes: + client_idx = self.group_to_client_indexes[group_idx][index] + sampled_group_to_client_indexes[group_idx].append(client_idx) + total_sampled_data_size += 
self.aggregator.train_data_local_num_dict[client_idx] logging.info( "client_indexes of each group = {}".format(sampled_group_to_client_indexes) ) for process_id in range(1, self.size): + total_sampled_data_size = 0 if self.topology_manager is None else total_sampled_data_size self.send_message_init_config( process_id, global_model_params, self.group_to_client_indexes[process_id - 1], sampled_group_to_client_indexes[process_id - 1], + total_sampled_data_size, process_id - 1 ) @@ -106,6 +111,7 @@ def handle_message_receive_model_from_edge(self, msg_params): return sampled_group_to_client_indexes = {} + total_sampled_data_size = 0 for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): client_num_in_total = len(self.group_to_client_indexes[group_idx]) sampled_client_indexes = self.aggregator.client_sampling( @@ -113,37 +119,43 @@ def handle_message_receive_model_from_edge(self, msg_params): client_num_in_total, client_num_per_round, ) + sampled_group_to_client_indexes[group_idx] = [] + for index in sampled_client_indexes: + client_idx = self.group_to_client_indexes[group_idx][index] + sampled_group_to_client_indexes[group_idx].append(client_idx) + total_sampled_data_size += self.aggregator.train_data_local_num_dict[client_idx] - sampled_group_to_client_indexes[group_idx] = [self.group_to_client_indexes[group_idx][index] - for index in sampled_client_indexes] logging.info( "client_indexes of each group = {}".format(sampled_group_to_client_indexes) ) for receiver_id in range(1, self.size): - if self.topology_manager is None: - self.send_message_sync_model_to_edge( - receiver_id, global_model_params, - sampled_group_to_client_indexes[receiver_id-1], receiver_id-1 - ) + if self.topology_manager is not None: + global_model_params = global_model_params_list[receiver_id - 1] else: - self.send_message_sync_model_to_edge( - receiver_id, global_model_params_list[receiver_id - 1], - sampled_group_to_client_indexes[receiver_id - 1], receiver_id - 1 - ) + total_sampled_data_size = 0 + self.send_message_sync_model_to_edge( + receiver_id, + global_model_params, + sampled_group_to_client_indexes[receiver_id - 1], + total_sampled_data_size, + receiver_id - 1 + ) - def send_message_init_config(self, receive_id, global_model_params, total_client_indexes, sampled_client_indexed, edge_index): + def send_message_init_config(self, receive_id, global_model_params, total_client_indexes, + sampled_client_indexed, total_sampled_data_size, edge_index): message = Message( MyMessage.MSG_TYPE_C2E_INIT_CONFIG, self.get_sender_id(), receive_id ) message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS, total_client_indexes) message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE, total_sampled_data_size) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) self.send_message(message) def send_message_sync_model_to_edge( - self, receive_id, global_model_params, sampled_client_indexed, edge_index + self, receive_id, global_model_params, sampled_client_indexed, total_sampled_data_size, edge_index ): logging.info("send_message_sync_model_to_edge. 
receive_id = %d" % receive_id) message = Message( @@ -152,6 +164,7 @@ def send_message_sync_model_to_edge( receive_id, ) message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE, total_sampled_data_size) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) self.send_message(message) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py index 3a46dbedb8..62a4efb106 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py @@ -37,22 +37,26 @@ def handle_message_init(self, msg_params): global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) total_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS) sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + total_sampled_data_size = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE) edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) self.group.setup_clients(total_client_indexes) self.args.round_idx = 0 - w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, + sampled_client_indexes, total_sampled_data_size) self.send_model_to_cloud(0, w_group_list, sample_num_list) def handle_message_receive_model_from_cloud(self, msg_params): logging.info("handle_message_receive_model_from_cloud.") sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + total_sampled_data_size = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE) global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) self.args.round_idx += 1 - w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, sampled_client_indexes) + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, + sampled_client_indexes, total_sampled_data_size) self.send_model_to_cloud(0, w_group_list, sample_num_list) if self.args.round_idx == self.num_rounds: diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py index 9df21e6bb4..ef71d5e320 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py @@ -47,7 +47,7 @@ def get_sample_number(self, sampled_client_indexes): self.group_sample_number += self.train_data_local_num_dict[client_idx] return self.group_sample_number - def train(self, round_idx, w, sampled_client_indexes): + def train(self, round_idx, w, sampled_client_indexes, total_sampled_data_size=0): sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] w_group = w w_group_list = [] @@ -62,7 +62,14 @@ def train(self, round_idx, w, sampled_client_indexes): ) # train each client for client in sampled_client_list: - w_local = client.train(w_group) + if total_sampled_data_size > 0: + scaled_loss_factor = ( + self.args.group_num * len(sampled_client_list) + * client.local_sample_number / total_sampled_data_size + ) + w_local = 
client.train(w_group, scaled_loss_factor) + else: + w_local = client.train(w_group) w_locals.append((client.get_sample_number(), w_local)) # aggregate local weights diff --git a/python/fedml/simulation/mpi/hierarchical_fl/message_define.py b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py index ecb0982b89..f5d6cd1c73 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/message_define.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py @@ -24,6 +24,7 @@ class MyMessage(object): MSG_ARG_KEY_EDGE_INDEX = "edge_idx" MSG_ARG_KEY_TOTAL_EDGE_CLIENTS = "total_edge_clients" MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS = "sampled_edge_clients" + MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE = "total_sampled_data_size" MSG_ARG_KEY_TRAIN_CORRECT = "train_correct" MSG_ARG_KEY_TRAIN_ERROR = "train_error" diff --git a/python/fedml/simulation/mpi/hierarchical_fl/utils.py b/python/fedml/simulation/mpi/hierarchical_fl/utils.py index a68ae6ccac..4f7987ec35 100644 --- a/python/fedml/simulation/mpi/hierarchical_fl/utils.py +++ b/python/fedml/simulation/mpi/hierarchical_fl/utils.py @@ -4,20 +4,22 @@ import numpy as np import torch import wandb +import logging from sklearn.cluster import KMeans -def cal_mixing_consensus_speed(topo_weight_matrix, global_round_idx): +def cal_mixing_consensus_speed(topo_weight_matrix, global_round_idx, args): n_rows, n_cols = np.shape(topo_weight_matrix) assert n_rows == n_cols A = np.array(topo_weight_matrix) - 1 / n_rows p = 1 - np.linalg.norm(A, ord=2) ** 2 - wandb.log({"Groups/p": p, "comm_round": global_round_idx}) + if args.enable_wandb: + wandb.log({"Groups/p": p, "comm_round": global_round_idx}) return p -def visualize_group_detail(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num): +def stats_group(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num, args): xs = [i for i in range(class_num)] ys = [] @@ -37,12 +39,20 @@ def visualize_group_detail(group_to_client_indexes, train_data_local_dict, train ys.append(count_vector/count_vector.sum()) keys.append("Group {}".format(group_idx)) - wandb.log({"Groups/Client_num": len(group_to_client_indexes[group_idx]), "group_id": group_idx}) - wandb.log({"Groups/Data_size": data_size, "group_id": group_idx}) + if args.enable_wandb: + wandb.log({"Groups/Client_num": len(group_to_client_indexes[group_idx]), "group_id": group_idx}) + wandb.log({"Groups/Data_size": data_size, "group_id": group_idx}) - wandb.log({"Groups/Data_distribution": - wandb.plot.line_series(xs=xs, ys=ys, keys=keys, title="Data distribution", xname="Label")} - ) + logging.info("Group {}: client num={}, data size={} ".format( + group_idx, + len(group_to_client_indexes[group_idx]), + data_size + )) + + if args.enable_wandb: + wandb.log({"Groups/Data_distribution": + wandb.plot.line_series(xs=xs, ys=ys, keys=keys, title="Data distribution", xname="Label")} + ) def hetero_partition_groups(clients_type_list, num_groups, alpha=0.5):
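The two new pieces of machinery in `utils.py` and the cloud aggregator can be hard to follow inside the hunks: `cal_mixing_consensus_speed` scores a mixing topology by p = 1 - ||W - (1/n)J||_2^2, the squared spectral gap of the mixing matrix W (J being the all-ones matrix), and `_pfedavg_mixing_` replaces the plain FedAvg average on the cloud with a weighted average of each edge server's in-neighbor models. The standalone sketch below reproduces both on a toy 4-node ring; the ring weights, tensor shapes, and helper names are illustrative assumptions, not values taken from the patch.

```
# Standalone sketch (not part of the patch): reproduces the consensus-speed metric
# from cal_mixing_consensus_speed and one neighbor-mixing step in the spirit of
# _pfedavg_mixing_. The 4-node ring, its uniform 1/3 weights and the 2x2 tensors
# are illustrative assumptions only.
import numpy as np
import torch


def consensus_speed(topo_weight_matrix):
    # p = 1 - ||W - (1/n) J||_2^2, where J is the all-ones matrix;
    # larger p means the topology mixes information between edges faster.
    n = len(topo_weight_matrix)
    A = np.array(topo_weight_matrix) - 1.0 / n
    return 1.0 - np.linalg.norm(A, ord=2) ** 2


def mix_with_neighbors(model_list, neighbor_topo_weight_list):
    # Weighted average of (sample_num, state_dict) pairs; entries whose
    # topology weight is 0 simply do not contribute, as in the aggregator.
    mixed = {k: torch.zeros_like(v) for k, v in model_list[0][1].items()}
    for (_, params), w in zip(model_list, neighbor_topo_weight_list):
        for k in mixed:
            mixed[k] += params[k] * w
    return mixed


if __name__ == "__main__":
    # Symmetric ring: every edge server averages itself with its two neighbors.
    ring = [
        [1 / 3, 1 / 3, 0.0, 1 / 3],
        [1 / 3, 1 / 3, 1 / 3, 0.0],
        [0.0, 1 / 3, 1 / 3, 1 / 3],
        [1 / 3, 0.0, 1 / 3, 1 / 3],
    ]
    print("consensus speed p =", consensus_speed(ring))  # 1 - (1/3)^2 = 8/9

    # Each "edge model" here is a dict with one 2x2 parameter filled with its index.
    models = [(100, {"w": torch.full((2, 2), float(i))}) for i in range(4)]
    mixed = mix_with_neighbors(models, ring[0])  # mixing for edge 0 uses row 0
    print("mixed parameter for edge 0:\n", mixed["w"])  # every entry is (0+1+3)/3
```

With uniform weights a complete graph gives p = 1 and an isolated topology (W = I) gives p = 0, so larger p corresponds to denser mixing between edge servers between cloud rounds.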