2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- name: Checkout
8 changes: 5 additions & 3 deletions .github/workflows/test_runner.yaml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10"]

steps:
- name: Checkout
@@ -27,7 +27,8 @@

- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Pin pip below 24.1 due to a lightning incompatibility
python -m pip install pip==23.2.1
pip install -r runner-requirements.txt
pip install pytest
pip install sh
@@ -56,7 +57,8 @@

- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Pin pip below 24.1 due to a lightning incompatibility
python -m pip install pip==23.2.1
pip install -r runner-requirements.txt
pip install pytest
pip install pytest-cov[toml]
2 changes: 2 additions & 0 deletions .gitignore
@@ -172,4 +172,6 @@ slurm*.out

notebooks/figures/

examples/single_cell/single_cell_sf2m_grn/Synthetic-I

.DS_Store
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -55,7 +55,7 @@ repos:

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.7.5
rev: master
hooks:
- id: docformatter
require_serial: true
26 changes: 10 additions & 16 deletions README.md
@@ -18,6 +18,8 @@
[![code-quality](https://github.com/atong01/conditional-flow-matching/actions/workflows/code-quality-main.yaml/badge.svg)](https://github.com/atong01/conditional-flow-matching/actions/workflows/code-quality-main.yaml)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/atong01/conditional-flow-matching#license)
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a>
[![Downloads](https://static.pepy.tech/badge/torchcfm)](https://pepy.tech/project/torchcfm)
[![Downloads](https://static.pepy.tech/badge/torchcfm/month)](https://pepy.tech/project/torchcfm)

</div>

@@ -61,11 +63,14 @@ A. Tong, N. Malkin, G. Huguet, Y. Zhang, J. Rector-Brooks, K. Fatras, G. Wolf, Y
</summary>

```bibtex
@article{tong2023improving,
title={Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport},
author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua},
year={2023},
journal={arXiv preprint 2302.00482}
@article{tong2024improving,
title={Improving and generalizing flow-based generative models with minibatch optimal transport},
author={Alexander Tong and Kilian FATRAS and Nikolay Malkin and Guillaume Huguet and Yanlei Zhang and Jarrid Rector-Brooks and Guy Wolf and Yoshua Bengio},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=CD9Snc73AW},
note={Expert Certification}
}
```

@@ -196,17 +201,6 @@ Before making an issue, please verify that:

Suggestions for improvements are always welcome!

## Sponsors

TorchCFM development and maintenance are financially supported by:

<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="assets/DF_logo_dark.png" width="300"/>
<img alt="DF logo changing depending on mode.'" src="assets/DF_logo.png" width="300"/>
</picture>
</p>

## License

Conditional-Flow-Matching is licensed under the MIT License.
2 changes: 1 addition & 1 deletion examples/2D_tutorials/Flow_matching_tutorial.ipynb
@@ -511,7 +511,7 @@
" x0, x1 = ot_sampler.sample_plan(x0, x1)\n",
"\n",
" t = torch.rand(x0.shape[0]).type_as(x0)\n",
" xt = sample_xt(x0, x1, t, sigma=0.01)\n",
" xt = sample_conditional_pt(x0, x1, t, sigma=0.01)\n",
" ut = compute_conditional_vector_field(x0, x1)\n",
"\n",
" vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
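The helper is renamed from `sample_xt` to `sample_conditional_pt`. For context, here is a self-contained sketch of the two conditional helpers this cell relies on; the bodies are written out from the I-CFM definitions (straight-line path with constant noise level sigma) and are illustrative, not copied from the notebook:

```python
import torch


def sample_conditional_pt(x0, x1, t, sigma):
    """Draw x_t ~ N(t * x1 + (1 - t) * x0, sigma^2 I) along the straight-line path."""
    mu_t = t[:, None] * x1 + (1 - t[:, None]) * x0
    return mu_t + sigma * torch.randn_like(x0)


def compute_conditional_vector_field(x0, x1):
    """Conditional target velocity u_t(x | x0, x1) = x1 - x0 used in the I-CFM loss."""
    return x1 - x0
```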
304 changes: 304 additions & 0 deletions examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb

Large diffs are not rendered by default.
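For orientation (the notebook itself is not shown in this diff), below is a minimal sketch of exact likelihood evaluation for a CNF via the instantaneous change-of-variables formula, log p1(x1) = log p0(x0) - ∫ div v(x_t, t) dt over t in [0, 1]. The helper name and the model signature (`model(torch.cat([x, t], dim=-1))`, as in the other 2D tutorials) are assumptions for illustration, not code taken from the notebook:

```python
import torch


def cnf_log_likelihood(model, x1, n_steps=100):
    """Euler-integrate the learned flow backwards from t=1 to t=0,
    accumulating the divergence to obtain an exact log-likelihood."""
    xt = x1.clone()
    delta_logp = torch.zeros(x1.shape[0], device=x1.device)
    dt = 1.0 / n_steps
    for i in reversed(range(n_steps)):
        t = torch.full((x1.shape[0], 1), (i + 1) * dt, device=x1.device)
        xt.requires_grad_(True)
        vt = model(torch.cat([xt, t], dim=-1))
        # Exact divergence (trace of the Jacobian); fine for low-dimensional toy data.
        div = torch.zeros(x1.shape[0], device=x1.device)
        for d in range(xt.shape[1]):
            div = div + torch.autograd.grad(vt[:, d].sum(), xt, retain_graph=True)[0][:, d]
        with torch.no_grad():
            xt = xt - dt * vt  # step backwards along the learned flow
            delta_logp = delta_logp - dt * div
    logp0 = torch.distributions.Normal(0.0, 1.0).log_prob(xt).sum(dim=1)  # base density at x_0
    return logp0 + delta_logp
```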

6 changes: 2 additions & 4 deletions examples/images/cifar10/README.md
@@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example:
Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. To do so, provide the number of GPUs, set the parallel flag to True, and pass the master address and port on the command line. Please refer to [the official torchrun documentation](https://pytorch.org/docs/stable/elastic/run.html#usage) for usage details. As an example:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
```
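For instance, a hypothetical single-node launch on 4 GPUs could look like this (the GPU count, address, and port are placeholders; internally the script divides `--batch_size` by the number of GPUs):

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=4 train_cifar10_ddp.py \
  --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 \
  --total_steps 400001 --save_step 20000 --parallel True \
  --master_addr "localhost" --master_port "12355"
```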

*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory).

To compute the FID from the OT-CFM model at end of training, run:

```bash
8 changes: 5 additions & 3 deletions examples/images/cifar10/compute_fid.py
@@ -27,6 +27,8 @@
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")

FLAGS(sys.argv)


@@ -49,7 +51,7 @@
# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
checkpoint = torch.load(PATH)
checkpoint = torch.load(PATH, map_location=device)
state_dict = checkpoint["ema_model"]
try:
new_net.load_state_dict(state_dict)
@@ -70,7 +72,7 @@

def gen_1_img(unused_latent):
with torch.no_grad():
x = torch.randn(500, 3, 32, 32, device=device)
x = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device)
if FLAGS.integration_method == "euler":
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
@@ -90,7 +92,7 @@ def gen_1_img(unused_latent):
score = fid.compute_fid(
gen=gen_1_img,
dataset_name="cifar10",
batch_size=500,
batch_size=FLAGS.batch_size_fid,
dataset_res=32,
num_gen=FLAGS.num_gen,
dataset_split="train",
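With the new flag, a hypothetical invocation could look like the following (every flag shown is defined or referenced in the excerpt above; adjust `--batch_size_fid` to fit GPU memory):

```bash
python3 compute_fid.py --model "otcfm" --step 400000 --num_gen 50000 --batch_size_fid 1024
```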
214 changes: 214 additions & 0 deletions examples/images/cifar10/train_cifar10_ddp.py
@@ -0,0 +1,214 @@
# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong
# Imahn Shekhzadeh

import copy
import math
import os

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup

from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher,
VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS

flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results/", help="output_directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
"total_steps", 400001, help="total training steps"
) # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string(
"master_addr", "localhost", help="master address for Distributed Data Parallel"
)
flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")

# Evaluation
flags.DEFINE_integer(
"save_step",
20000,
help="frequency of saving checkpoints, 0 to disable during training",
)


def warmup_lr(step):
return min(step, FLAGS.warmup) / FLAGS.warmup


def train(rank, total_num_gpus, argv):
print(
"lr, total_steps, ema decay, save_step:",
FLAGS.lr,
FLAGS.total_steps,
FLAGS.ema_decay,
FLAGS.save_step,
)

if FLAGS.parallel and total_num_gpus > 1:
# When using `DistributedDataParallel`, we need to divide the batch
# size ourselves based on the total number of GPUs of the current node.
batch_size_per_gpu = FLAGS.batch_size // total_num_gpus
setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
else:
batch_size_per_gpu = FLAGS.batch_size

# DATASETS/DATALOADER
dataset = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
),
)
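# When parallel, DistributedSampler shards the dataset across ranks and handles shuffling
# (re-seeded via sampler.set_epoch() below), so the DataLoader itself must not shuffle.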
sampler = DistributedSampler(dataset) if FLAGS.parallel else None
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size_per_gpu,
sampler=sampler,
shuffle=False if FLAGS.parallel else True,
num_workers=FLAGS.num_workers,
drop_last=True,
)

datalooper = infiniteloop(dataloader)

# Calculate number of epochs
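# The training loop is step-based; epochs are tracked mainly so that
# DistributedSampler.set_epoch() can reshuffle the shards on every pass over the data.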
steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)

# MODELS
net_model = UNetModelWrapper(
dim=(3, 32, 32),
num_res_blocks=2,
num_channels=FLAGS.num_channel,
channel_mult=[1, 2, 2, 2],
num_heads=4,
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
).to(
rank
) # new dropout + bs of 128

ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
if FLAGS.parallel:
net_model = DistributedDataParallel(net_model, device_ids=[rank])
ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

# show model size
model_size = 0
for param in net_model.parameters():
model_size += param.data.nelement()
print("Model params: %.2f M" % (model_size / 1024 / 1024))

#################################
# OT-CFM
#################################

sigma = 0.0
if FLAGS.model == "otcfm":
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "icfm":
FM = ConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "fm":
FM = TargetConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "si":
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
else:
raise NotImplementedError(
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
)

savedir = FLAGS.output_dir + FLAGS.model + "/"
os.makedirs(savedir, exist_ok=True)

global_step = 0 # to keep track of the global step in training loop

with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
for epoch in epoch_pbar:
epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
if sampler is not None:
sampler.set_epoch(epoch)

with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
for step in step_pbar:
global_step += 1  # one optimizer update per loop iteration

optim.zero_grad()
x1 = next(datalooper).to(rank)
x0 = torch.randn_like(x1)
t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
vt = net_model(t, xt)
loss = torch.mean((vt - ut) ** 2)
loss.backward()
torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new
optim.step()
sched.step()
ema(net_model, ema_model, FLAGS.ema_decay) # new

# sample and save the weights
if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
generate_samples(
net_model, FLAGS.parallel, savedir, global_step, net_="normal"
)
generate_samples(
ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
)
torch.save(
{
"net_model": net_model.state_dict(),
"ema_model": ema_model.state_dict(),
"sched": sched.state_dict(),
"optim": optim.state_dict(),
"step": global_step,
},
savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
)


def main(argv):
# get world size (number of GPUs)
total_num_gpus = int(os.getenv("WORLD_SIZE", 1))

if FLAGS.parallel and total_num_gpus > 1:
train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv)
else:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train(rank=device, total_num_gpus=total_num_gpus, argv=argv)


if __name__ == "__main__":
app.run(main)