forked from octadion/fedlearn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path client.py
More file actions
94 lines (68 loc) · 3.21 KB
/
client.py
File metadata and controls
94 lines (68 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import flwr as fl
import torch
from collections import OrderedDict
from model import SimpleCNN, train, test
from data_utils import load_datasets, create_non_iid_distribution, get_client_loaders
import argparse
class FederatedClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local ``SimpleCNN``.

    Model weights are exchanged with the server as a list of NumPy arrays
    ordered exactly like ``self.model.state_dict()``.

    Args:
        client_id: Identifier of this client; used only in log messages.
        trainloader: Loader over this client's training partition. The length
            of its ``.dataset`` is reported to the server as the fit sample
            count.
        testloader: Loader over the local evaluation data. The length of its
            ``.dataset`` is reported as the evaluation sample count.
        device: ``torch.device`` the model is moved to.
    """

    def __init__(self, client_id, trainloader, testloader, device):
        self.client_id = client_id
        self.trainloader = trainloader
        self.testloader = testloader
        self.device = device
        self.model = SimpleCNN().to(device)

    def get_parameters(self, config):
        """Return every ``state_dict`` tensor as a CPU NumPy array.

        ``config`` is part of the Flower interface and is ignored here. The
        order of the returned list matches ``state_dict()``, which
        ``set_parameters`` depends on.
        """
        return [tensor.cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (ordered like ``state_dict()``) into the model.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries. Without this check ``zip`` would
                silently drop any surplus arrays.
            RuntimeError: raised by ``load_state_dict`` when a tensor's shape
                does not match the model.
        """
        keys = list(self.model.state_dict().keys())
        parameters = list(parameters)
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (key, torch.tensor(array)) for key, array in zip(keys, parameters)
        )
        # strict=True: every model key must be supplied and no extras allowed.
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally, starting from the server's global weights.

        Args:
            parameters: Global weights, ordered like ``state_dict()``.
            config: Server-provided dict. ``local_epochs`` (default 2) is the
                epoch count passed to ``train``.

        Returns:
            A tuple ``(updated weights, number of training examples,
            metrics)``. The ``loss`` and ``accuracy`` metrics are measured on
            the local *test* loader after training; they are not the
            training loss.
        """
        print(f"\nClient {self.client_id} - Starting local training...")
        self.set_parameters(parameters)
        # Server config values are generic scalars (e.g. 2.0); coerce so that
        # train() always receives an integer epoch count.
        epochs = int(config.get("local_epochs", 2))
        # train()'s return value is not used; metrics come from test() below.
        train(self.model, self.trainloader, epochs, self.device)
        test_loss, accuracy = test(self.model, self.testloader, self.device)
        print(f"Client {self.client_id} - Training complete | Loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
        return (
            self.get_parameters(config={}),
            len(self.trainloader.dataset),
            {"loss": float(test_loss), "accuracy": float(accuracy)},
        )

    def evaluate(self, parameters, config):
        """Evaluate the server's weights on the local test loader.

        Returns:
            A tuple ``(loss, number of test examples, metrics)``, where the
            metrics contain ``accuracy`` and ``loss``.
        """
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.testloader, self.device)
        print(f"Client {self.client_id} - Evaluation | Loss={loss:.4f}, Acc={accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
def start_client(client_id, server_address="localhost:8080", *, num_clients=3, batch_size=64):
    """Load this client's data partition and connect it to a Flower server.

    Args:
        client_id: Index of this client's partition. Must satisfy
            ``0 <= client_id < num_clients``.
        server_address: ``host:port`` of the Flower server.
        num_clients: Number of partitions passed to
            ``create_non_iid_distribution`` (default 3). Each client process
            builds the partition itself, so all clients should use the same
            value. This presumably also requires the partitioning to be
            deterministic; confirm in ``data_utils``.
        batch_size: Batch size used for both data loaders (default 64).

    Raises:
        ValueError: if ``client_id`` is outside ``[0, num_clients)``.
    """
    # Validate before any heavy work. An out-of-range id would otherwise fail
    # later inside get_client_loaders. A negative id might even silently
    # select another client's partition if client_indices is a sequence.
    if not 0 <= client_id < num_clients:
        raise ValueError(
            f"client_id must be in [0, {num_clients}), got {client_id}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"FEDERATED CLIENT {client_id}")
    print(rule)
    print(f"Device: {device}")
    print("Loading data...")

    trainset, testset = load_datasets()
    client_indices = create_non_iid_distribution(trainset, num_clients=num_clients)
    trainloader, testloader = get_client_loaders(
        client_id,
        client_indices,
        trainset,
        testset,
        batch_size=batch_size,
    )
    print(f"Training samples: {len(trainloader.dataset)}")
    print(f"Connecting to server: {server_address}")
    print(f"{rule}\n")

    client = FederatedClient(client_id, trainloader, testloader, device)
    fl.client.start_client(
        server_address=server_address,
        client=client.to_client(),
    )
if __name__ == "__main__":
    # Command-line entry point: read which client this process plays and
    # where the server lives, then hand control to start_client.
    cli = argparse.ArgumentParser(description='Federated Learning Client')
    cli.add_argument(
        '--client_id',
        type=int,
        required=True,
        help='Client ID (0, 1, or 2)',
    )
    cli.add_argument(
        '--server',
        type=str,
        default='localhost:8080',
        help='Server address',
    )
    cli_args = cli.parse_args()
    start_client(cli_args.client_id, server_address=cli_args.server)