Commit 1440a98

Merge pull request #69 from atong01/forest_flow
Adding Forest-Flow: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees
2 parents 81fcb8d + cac49f1 commit 1440a98

9 files changed

+825
-17
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
## Description
2525

26-
Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference. CFM's performance closes the gap between CNFs and diffusion models. To spread its use within the machine learning community, we have built a library focused on Flow Matching methods: TorchCFM. TorchCFM is a library showing how Flow Matching methods can be trained and used to deal with image generation, single-cell dynamics and (soon) SO(3) data and tabular data.
26+
Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference. CFM's performance closes the gap between CNFs and diffusion models. To spread its use within the machine learning community, we have built a library focused on Flow Matching methods: TorchCFM. TorchCFM is a library showing how Flow Matching methods can be trained and used to deal with image generation, single-cell dynamics, tabular data and soon SO(3) data.
2727

2828
<p align="center">
2929
<img src="assets/169_generated_samples_otcfm.png" width="600"/>
@@ -107,8 +107,8 @@ List of implemented papers:
107107
- Building Normalizing Flows with Stochastic Interpolants (Albergo et al. 2023a) [Paper](https://openreview.net/forum?id=li7qeBbCR1t)
108108
- Action Matching: Learning Stochastic Dynamics From Samples (Neklyudov et al. 2022) [Paper](https://arxiv.org/abs/2210.06662) [Code](https://github.com/necludov/jam)
109109
- Concurrent work to our OT-CFM method: Multisample Flow Matching: Straightening Flows with Minibatch Couplings (Pooladian et al. 2023) [Paper](https://arxiv.org/abs/2304.14772)
110-
- Soon: SE(3)-Stochastic Flow Matching for Protein Backbone Generation (Bose et al.) [paper](https://arxiv.org/abs/2310.02391)
111-
- Soon: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (Jolicoeur-Martineau et al.) [paper](https://arxiv.org/abs/2309.09968)
110+
- Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (Jolicoeur-Martineau et al.) [Paper](https://arxiv.org/abs/2309.09968) [Code](https://github.com/SamsungSAILMontreal/ForestDiffusion)
111+
- Soon: SE(3)-Stochastic Flow Matching for Protein Backbone Generation (Bose et al.) [Paper](https://arxiv.org/abs/2310.02391)
112112

113113
## How to run
114114

@@ -155,7 +155,7 @@ python -m ipykernel install --user --name=torchcfm
155155

156156
## Project Structure
157157

158-
The directory structure of a new project looks like this:
158+
The directory structure looks like this:
159159

160160
```
161161

examples/notebooks/mnist_example.ipynb

Lines changed: 2 additions & 0 deletions
@@ -253,6 +253,8 @@
253253
"outputs": [],
254254
"source": [
255255
"# follows example from https://github.com/google-research/torchsde/blob/master/examples/cont_ddpm.py\n",
256+
"\n",
257+
"\n",
256258
"class SDE(torch.nn.Module):\n",
257259
" noise_type = \"diagonal\"\n",
258260
" sde_type = \"ito\"\n",

examples/notebooks/training-8gaussians-to-moons.ipynb

Lines changed: 2 additions & 0 deletions
@@ -843,6 +843,8 @@
843843
],
844844
"source": [
845845
"# %%time\n",
846+
"\n",
847+
"\n",
846848
"class MLP2(torch.nn.Module):\n",
847849
" def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
848850
" super().__init__()\n",

examples/tabular/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Forest-Flow experiment on the Iris dataset using TorchCFM
2+
3+
This notebook is a self-contained example showing how to train the novel Forest-Flow method to generate tabular data [(Jolicoeur-Martineau et al. 2023)](https://arxiv.org/abs/2309.09968). The idea behind Forest-Flow is to **learn the vector field of Independent Conditional Flow Matching (I-CFM) with XGBoost models** instead of neural networks. The motivation is that tree-based models currently tend to work better than neural networks on tabular data tasks. This idea comes with some difficulties, for instance how to approximate the Flow Matching loss, and this notebook shows how to do it on a minimal example. The method, its training procedure and the experiments are described in [(Jolicoeur-Martineau et al. 2023)](https://arxiv.org/abs/2309.09968). The full code can be found [here](https://github.com/SamsungSAILMontreal/ForestDiffusion).
4+
5+
To run our Jupyter notebooks, install our package:
6+
7+
```bash
8+
cd ../../
9+
10+
# install torchcfm
11+
pip install -e '.[forest-flow]'
12+
13+
# install ipykernel
14+
conda install -c anaconda ipykernel
15+
16+
# install conda env in jupyter notebook
17+
python -m ipykernel install --user --name=torchcfm
18+
19+
# launch our notebooks with the torchcfm kernel
20+
```
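As a rough illustration of the idea described above (fitting tree models to the I-CFM vector field), the sketch below builds a regression dataset for one fixed time level and fits one regressor per output dimension. It is a minimal sketch under stated assumptions, not the notebook's or the ForestDiffusion repository's code: the function names `icfm_regression_data` and `fit_vector_field`, and the parameters `n_noise`, `n_t` and `sigma`, are hypothetical. The Flow Matching loss is approximated here by pairing each data row with several Gaussian noise samples and regressing the straight-line velocity `x1 - x0`.

```python
import numpy as np


def icfm_regression_data(x1, t, n_noise=10, sigma=0.0, seed=0):
    """Build inputs x_t and targets u_t for one fixed time t in [0, 1].

    x1 has shape (n, d). Each data row is repeated n_noise times and paired
    with its own Gaussian noise sample x0 (an assumption of this sketch).
    """
    x1 = np.asarray(x1, dtype=float)
    rng = np.random.default_rng(seed)
    x1_rep = np.repeat(x1, n_noise, axis=0)   # shape (n * n_noise, d)
    x0 = rng.standard_normal(x1_rep.shape)    # samples from the source distribution
    x_t = (1.0 - t) * x0 + t * x1_rep         # point on the straight path x0 -> x1
    if sigma > 0:
        x_t = x_t + sigma * rng.standard_normal(x_t.shape)
    u_t = x1_rep - x0                         # I-CFM target velocity
    return x_t, u_t


def fit_vector_field(x1, make_regressor, n_t=5, n_noise=10):
    """Fit one regressor per (time level, output dimension).

    make_regressor is any zero-argument factory returning an estimator with
    scikit-learn style fit(X, y) and predict(X) methods, where fit returns
    the estimator itself.
    """
    x1 = np.asarray(x1, dtype=float)
    ts = np.linspace(0.0, 1.0, n_t)
    models = []
    for i, t in enumerate(ts):
        x_t, u_t = icfm_regression_data(x1, t, n_noise=n_noise, seed=i)
        models.append(
            [make_regressor().fit(x_t, u_t[:, j]) for j in range(x1.shape[1])]
        )
    return ts, models


# Example usage (assumes xgboost is installed and X is an (n, d) array):
#   from xgboost import XGBRegressor
#   ts, models = fit_vector_field(X, lambda: XGBRegressor(n_estimators=100))
```

Tree ensembles are trained on a fixed dataset rather than on freshly sampled minibatches, which is why this sketch materialises the noisy pairs up front and discretises time into a small grid. The grid size and the number of noise samples per row trade memory against how well the expectation in the loss is approximated.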
