Commit c25e191
add support for distributed data parallel training (#116)
* make code changes in `train_cifar10.py` to allow DDP (distributed data parallel)
* add instructions to README on how to run cifar10 image generation code on multiple GPUs
* fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting
* fix: load checkpoint on right device
* fix runner ci requirements (#125)
* change pytorch lightning version
* fix pip version
* fix pip in code cov
* change variable name `world_size` to `total_num_gpus`
* change: do not overwrite batch size flag
* add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps
* fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode
* rename file, update README
* add original CIFAR10 training file

---------

Co-authored-by: Alexander Tong <alexandertongdev@gmail.com>
1 parent b4525b5 commit c25e191

File tree

examples/images/cifar10/README.md
examples/images/cifar10/compute_fid.py
examples/images/cifar10/train_cifar10_ddp.py
examples/images/cifar10/utils_cifar.py

4 files changed: +247, -5 lines

examples/images/cifar10/README.md

Lines changed: 2 additions & 4 deletions
@@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
 python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
 ```

-Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example:
+Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example:

 ```bash
-python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
+torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
 ```

-*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory).
-
 To compute the FID from the OT-CFM model at end of training, run:

 ```bash
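For context on the launch command above: `torchrun --nproc_per_node` starts one process per GPU and exports `WORLD_SIZE` and `RANK` to each of them, and the training script divides the global `--batch_size` across those processes. A minimal sketch of that arithmetic (variable names here are illustrative, not from the repository):

```python
# Illustrative sketch (not part of the commit): per-GPU batch size under torchrun.
import os

global_batch_size = 128  # the value passed via --batch_size
world_size = int(os.getenv("WORLD_SIZE", "1"))  # set by torchrun (--nproc_per_node)
batch_size_per_gpu = global_batch_size // world_size
print(f"{world_size} process(es) -> {batch_size_per_gpu} samples per GPU per step")
```

With the README's `--batch_size 128` on 4 GPUs, each process would therefore draw 32 samples per optimizer step, keeping the global batch at roughly 128.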

examples/images/cifar10/compute_fid.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
 # Load the model
 PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
 print("path: ", PATH)
-checkpoint = torch.load(PATH)
+checkpoint = torch.load(PATH, map_location=device)
 state_dict = checkpoint["ema_model"]
 try:
     new_net.load_state_dict(state_dict)
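The `map_location=device` fix matters because `torch.load` otherwise tries to restore each tensor onto the device it was saved from (typically `cuda:0` on the training machine), which fails on a CPU-only host. A minimal sketch of the intended usage (the checkpoint path below is hypothetical):

```python
# Minimal sketch (hypothetical checkpoint path): load a GPU-saved checkpoint anywhere.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the saved tensors onto the locally available device instead of
# the device they were serialized from (e.g. "cuda:0" on the training node).
checkpoint = torch.load(
    "results/otcfm/otcfm_cifar10_weights_step_400000.pt", map_location=device
)
state_dict = checkpoint["ema_model"]
```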
examples/images/cifar10/train_cifar10_ddp.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong
# Imahn Shekhzadeh

import copy
import math
import os

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS

flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results/", help="output_directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate")  # TRY 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
    "total_steps", 400001, help="total training steps"
)  # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string(
    "master_addr", "localhost", help="master address for Distributed Data Parallel"
)
flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")

# Evaluation
flags.DEFINE_integer(
    "save_step",
    20000,
    help="frequency of saving checkpoints, 0 to disable during training",
)


def warmup_lr(step):
    return min(step, FLAGS.warmup) / FLAGS.warmup


def train(rank, total_num_gpus, argv):
    print(
        "lr, total_steps, ema decay, save_step:",
        FLAGS.lr,
        FLAGS.total_steps,
        FLAGS.ema_decay,
        FLAGS.save_step,
    )

    if FLAGS.parallel and total_num_gpus > 1:
        # When using `DistributedDataParallel`, we need to divide the batch
        # size ourselves based on the total number of GPUs of the current node.
        batch_size_per_gpu = FLAGS.batch_size // total_num_gpus
        setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
    else:
        batch_size_per_gpu = FLAGS.batch_size

    # DATASETS/DATALOADER
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    sampler = DistributedSampler(dataset) if FLAGS.parallel else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        sampler=sampler,
        shuffle=False if FLAGS.parallel else True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )

    datalooper = infiniteloop(dataloader)

    # Calculate number of epochs
    steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
    num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(
        rank
    )  # new dropout + bs of 128

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if FLAGS.parallel:
        net_model = DistributedDataParallel(net_model, device_ids=[rank])
        ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

    # show model size
    model_size = 0
    for param in net_model.parameters():
        model_size += param.data.nelement()
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    #################################
    # OT-CFM
    #################################

    sigma = 0.0
    if FLAGS.model == "otcfm":
        FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "icfm":
        FM = ConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "fm":
        FM = TargetConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "si":
        FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
    else:
        raise NotImplementedError(
            f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
        )

    savedir = FLAGS.output_dir + FLAGS.model + "/"
    os.makedirs(savedir, exist_ok=True)

    global_step = 0  # to keep track of the global step in training loop

    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            if sampler is not None:
                sampler.set_epoch(epoch)

            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
                for step in step_pbar:
                    global_step += step

                    optim.zero_grad()
                    x1 = next(datalooper).to(rank)
                    x0 = torch.randn_like(x1)
                    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
                    vt = net_model(t, xt)
                    loss = torch.mean((vt - ut) ** 2)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
                    optim.step()
                    sched.step()
                    ema(net_model, ema_model, FLAGS.ema_decay)  # new

                    # sample and Saving the weights
                    if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
                        generate_samples(
                            net_model, FLAGS.parallel, savedir, global_step, net_="normal"
                        )
                        generate_samples(
                            ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
                        )
                        torch.save(
                            {
                                "net_model": net_model.state_dict(),
                                "ema_model": ema_model.state_dict(),
                                "sched": sched.state_dict(),
                                "optim": optim.state_dict(),
                                "step": global_step,
                            },
                            savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
                        )


def main(argv):
    # get world size (number of GPUs)
    total_num_gpus = int(os.getenv("WORLD_SIZE", 1))

    if FLAGS.parallel and total_num_gpus > 1:
        train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv)
    else:
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        train(rank=device, total_num_gpus=total_num_gpus, argv=argv)


if __name__ == "__main__":
    app.run(main)
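Two details in the training loop above are easy to miss: `DistributedSampler` partitions the dataset across ranks, and its shuffle seed is derived from the epoch recorded by `sampler.set_epoch(epoch)`, so omitting that call would replay the same ordering every epoch. A small standalone illustration (the toy dataset and rank values are made up; this is not code from the repository):

```python
# Standalone sketch: how set_epoch changes the per-rank ordering across epochs.
from torch.utils.data import DistributedSampler

dataset = list(range(8))  # stand-in for CIFAR-10
# num_replicas/rank are passed explicitly so no process group is needed for this demo.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    print(f"epoch {epoch}: rank 0 reads indices {list(sampler)}")
```

Running this prints a different 4-index subset for rank 0 each epoch; dropping the `set_epoch` call makes all three epochs identical.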

examples/images/cifar10/utils_cifar.py

Lines changed: 30 additions & 0 deletions
@@ -1,6 +1,8 @@
 import copy
+import os

 import torch
+from torch import distributed as dist
 from torchdyn.core import NeuralODE

 # from torchvision.transforms import ToPILImage
@@ -10,6 +12,34 @@
 device = torch.device("cuda" if use_cuda else "cpu")


+def setup(
+    rank: int,
+    total_num_gpus: int,
+    master_addr: str = "localhost",
+    master_port: str = "12355",
+    backend: str = "nccl",
+):
+    """Initialize the distributed environment.
+
+    Args:
+        rank: Rank of the current process.
+        total_num_gpus: Number of GPUs used in the job.
+        master_addr: IP address of the master node.
+        master_port: Port number of the master node.
+        backend: Backend to use.
+    """
+
+    os.environ["MASTER_ADDR"] = master_addr
+    os.environ["MASTER_PORT"] = master_port
+
+    # initialize the process group
+    dist.init_process_group(
+        backend=backend,
+        rank=rank,
+        world_size=total_num_gpus,
+    )
+
+
 def generate_samples(model, parallel, savedir, step, net_="normal"):
     """Save 64 generated images (8 x 8) for sanity check along training.

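The new `setup` helper above wraps `torch.distributed.init_process_group`, exporting the master address and port before joining the process group, and defaults to the NCCL backend, so one CUDA device per rank is assumed. A minimal usage sketch under the assumption of a `torchrun` launch (this driver is illustrative, not part of the commit):

```python
# Illustrative driver (not in the repository): pairing setup() with the usual teardown.
import os

import torch
from torch import distributed as dist

from utils_cifar import setup  # the helper added in this commit


def demo():
    rank = int(os.environ["RANK"])              # set by torchrun
    world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
    setup(rank, world_size)                     # exports MASTER_ADDR/PORT, joins the group
    torch.cuda.set_device(rank)                 # pin this process to its GPU (single node assumed)
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...
    dist.destroy_process_group()                # release the process group when done


if __name__ == "__main__":
    demo()
```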