-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
96 lines (76 loc) · 3.08 KB
/
training.py
File metadata and controls
96 lines (76 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim import SGD
from torch.utils.data import DataLoader
from tqdm import notebook
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for an unconditional score model on tabular data.

    Args:
        model: Score network; called with ``hstack([perturbed_x, t])`` of shape
            (batch, d + 1) and expected to return a (batch, d) score estimate.
        x: Clean data batch of shape (batch, d).
        marginal_prob_std: Callable mapping a (batch,) tensor of times t to the
            std of the perturbation kernel p_t(x(t) | x(0)).
        eps: Lower bound on t, avoids numerical instability at t = 0.

    Returns:
        Scalar loss tensor: per-sample squared error summed over features,
        averaged over the batch.
    """
    # Sample t ~ Uniform[eps, 1] on the same device as the data (matches loss_fn_MNIST).
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    std = marginal_prob_std(random_t)
    random_t = torch.reshape(random_t, (x.shape[0], 1))
    z = torch.randn_like(x)
    perturbed_x = x + z * std[:, None]
    # Time is appended as an extra input feature for the network.
    x_with_t = torch.hstack([perturbed_x, random_t])
    x_with_t = x_with_t.to(torch.float32)
    score = model(x_with_t)
    # Fix: sum over the feature dimension (dim=1) and average over the batch,
    # consistent with loss_fn_MNIST. The original dim=0 summed over the batch,
    # which only rescales the loss by d/batch_size but was inconsistent.
    loss = torch.mean(torch.sum((score * std[:, None] + z)**2, dim=1))
    return loss
def CDE_loss_fn_BOD(model, x, marginal_prob_std, eps=1e-5):
    """Conditional denoising score-matching loss for the BOD dataset.

    The batch carries data and conditioning variables in fixed columns:
    columns [0, 1] are the modelled variables, columns [2..6] are the
    conditioning features y. Only the modelled columns are perturbed; y is
    concatenated back (un-noised) before the time feature is appended.

    Args:
        model: Score network; receives (batch, 2 + 5 + 1) = (batch, 8) inputs
            and is expected to return a (batch, 2) score estimate.
        x: Batch of shape (batch, 7) holding [data(2), conditioning(5)].
        marginal_prob_std: Callable mapping (batch,) times t to perturbation std.
        eps: Lower bound on t, avoids numerical instability at t = 0.

    Returns:
        Scalar loss tensor (mean over batch of summed squared error).
    """
    # Split the fixed column layout: data first, conditioning after.
    y = x[:, [2, 3, 4, 5, 6]]
    x = x[:, [0, 1]]
    # Sample t ~ Uniform[eps, 1] on the same device as the data.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    std = marginal_prob_std(random_t)
    random_t = torch.reshape(random_t, (x.shape[0], 1))
    z = torch.randn_like(x)
    perturbed_x = x + z * std[:, None]
    # Conditioning variables are passed through unperturbed.
    perturbed_x = torch.hstack([perturbed_x, y])
    x_with_t = torch.hstack([perturbed_x, random_t])
    x_with_t = x_with_t.to(torch.float32)
    score = model(x_with_t)
    # Fix: sum over the feature dimension (dim=1), averaged over the batch,
    # consistent with loss_fn_MNIST; the original dim=0 summed over the batch.
    loss = torch.mean(torch.sum((score * std[:, None] + z)**2, dim=1))
    return loss
def CDE_loss_fn_2D(model, x, marginal_prob_std, eps=1e-5):
    """Conditional denoising score-matching loss for 2-D data (1 data + 1 cond column).

    Unlike CDE_loss_fn_BOD, the model here takes the perturbed data, the
    conditioning variable, and the time as three separate arguments rather
    than one stacked tensor.

    Args:
        model: Score network called as ``model(perturbed_x, y, t)``; expected
            to return a (batch, 1) score estimate.
        x: Batch of shape (batch, 2): column 0 is data, column 1 is conditioning.
        marginal_prob_std: Callable mapping (batch,) times t to perturbation std.
        eps: Lower bound on t, avoids numerical instability at t = 0.

    Returns:
        Scalar loss tensor (mean over batch of summed squared error).
    """
    y = x[:, [1]]
    x = x[:, [0]]
    x = x.to(torch.float32)
    y = y.to(torch.float32)
    # NOTE(review): t is created explicitly on CPU here (unlike the other
    # losses) — kept as written; confirm if GPU training is intended.
    random_t = torch.rand(x.shape[0], device='cpu') * (1. - eps) + eps
    random_t = random_t.to(torch.float32)
    std = marginal_prob_std(random_t)
    random_t = random_t[:, None]
    z = torch.randn_like(x)
    perturbed_x = x + z * std[:, None]
    score = model(perturbed_x, y, random_t)
    # Fix: sum over the feature dimension (dim=1), averaged over the batch,
    # consistent with loss_fn_MNIST; the original dim=0 summed over the batch.
    loss = torch.mean(torch.sum((score * std[:, None] + z)**2, dim=1))
    return loss
def loss_fn_MNIST(model, x, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for image batches of shape (B, C, H, W).

    Draws t ~ Uniform[eps, 1], perturbs the images with Gaussian noise scaled
    by the marginal std at t, and penalizes the squared residual between the
    rescaled model output and the injected noise, summed per image and
    averaged over the batch.
    """
    batch = x.shape[0]
    # Uniform times in [eps, 1], on the same device as the images.
    t = torch.rand(batch, device=x.device) * (1. - eps) + eps
    noise = torch.randn_like(x)
    sigma = marginal_prob_std(t)
    # Broadcast the per-sample std across channel/height/width.
    sigma_b = sigma[:, None, None, None]
    noisy = x + noise * sigma_b
    predicted = model(noisy, t)
    residual = predicted * sigma_b + noise
    per_image = torch.sum(residual ** 2, dim=(1, 2, 3))
    return torch.mean(per_image)
def train_model(score_model, data, loss_fn, marginal_prob_std_fn, file, epochs = 100, batch_size = 32, lr = 1e-4):
    """Train a score model with Adam and return the per-epoch average losses.

    Args:
        score_model: The network to optimize.
        data: Dataset accepted by ``torch.utils.data.DataLoader``.
        loss_fn: Callable ``loss_fn(model, batch, marginal_prob_std_fn)``
            returning a scalar loss tensor.
        marginal_prob_std_fn: Perturbation-std callable forwarded to loss_fn.
        file: Path where the model ``state_dict`` is checkpointed.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the DataLoader.
        lr: Adam learning rate.

    Returns:
        List of average training losses, one entry per epoch.
    """
    opt = Adam(score_model.parameters(), lr=lr)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True)
    progress = notebook.trange(epochs)
    epoch_losses = []
    for _ in progress:
        running = 0.
        seen = 0
        for batch in loader:
            batch_loss = loss_fn(score_model, batch, marginal_prob_std_fn)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per-sample.
            n = batch.shape[0]
            running += batch_loss.item() * n
            seen += n
        progress.set_description('Average Loss: {:5f}'.format(running / seen))
        # Checkpoint every epoch so progress survives an interruption.
        torch.save(score_model.state_dict(), file)
        epoch_losses.append(running / seen)
    return epoch_losses