Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions examples/AltDiffusion/dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path

import torch
from torch.cuda.amp import autocast as autocast
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
Expand All @@ -34,7 +35,7 @@
train_text_encoder = False
train_only_unet = True

num_train_epochs = 500
num_train_epochs = 10
batch_size = 4
learning_rate = 5e-6
adam_beta1 = 0.9
Expand Down Expand Up @@ -197,20 +198,23 @@ def collate_fn(examples):
if with_prior_preservation:
x, x_prior = torch.chunk(x, 2, dim=0)
c, c_prior = torch.chunk(c, 2, dim=0)
loss, _ = model(x, c)
with autocast():
loss, _ = model(x, c)
prior_loss, _ = model(x_prior, c_prior)
loss = loss + prior_loss_weight * prior_loss
else:
loss, _ = model(x, c)
with autocast():
loss, _ = model(x, c)

print('*'*20, "loss=", str(loss.detach().item()))

loss.backward()
optimizer.step()
optimizer.zero_grad()
with autocast():
loss.backward()
optimizer.step()
optimizer.zero_grad()

## mkdir ./checkpoints/DreamBooth and copy ./checkpoints/AltDiffusion to ./checkpoints/DreamBooth/AltDiffusion
## overwrite model.ckpt for latter usage
chekpoint_path = './checkpoints/DreamBooth/AltDiffusion/model.ckpt'
chekpoint_path = './checkpoints/AltDiffusion/dreambooth_model.ckpt'
torch.save(model.state_dict(), chekpoint_path)

Binary file added examples/AltDiffusion/instance_images/0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/10.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/11.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/12.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/13.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/14.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/15.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/3.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/4.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/5.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/6.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/7.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/AltDiffusion/instance_images/9.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion flagai/model/mm/AltDiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ def p_losses(self, x_start, cond, t, noise=None):

loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

t = t.cpu()
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
Expand Down Expand Up @@ -1932,4 +1933,4 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
]

return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))
((mean1 - mean2)**2) * torch.exp(-logvar2))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ PyYAML==5.4.1
deepspeed==0.6.5
flash-attn==1.0.2
bminf
torch
Loading