
Commit 4cdb732

code supporting pytorch 0.4.1
1 parent 87e4218 commit 4cdb732

File tree

7 files changed: +687 -2 lines changed


README.md

Lines changed: 4 additions & 2 deletions
@@ -10,17 +10,19 @@ in **lib/network.py**.
 
 ## Environment
 - python 3.6
-- pytorch 0.3.0
+- pytorch 0.4.1
 
 ## Update Records
 1. Figure out how to implement the **concatenation** type, and add the code to **lib/**.
 2. Fix the bug in **lib/non_local.py** (old version) when using multi-gpu. Someone shares the
 reason with me, and you can find it in [here](https://github.com/pytorch/pytorch/issues/8637).
-3. Fix the bug of 3D pooling in **lib/non_local.py** (old version). Appreciate
+3. Fix the error of 3D pooling in **lib/non_local.py** (old version). Appreciate
 [**protein27**](https://github.com/AlexHex7/Non-local_pytorch/issues/17) for pointing it out.
 4. For convenience, I split the **lib/non_local.py** into four python files, and move the
 old versions (**lib/non_loca.py** and **lib/non_local_simple_version.py**) into
 **lib/backup/**.
+5. modify the code to support pytorch 0.4.1, and move the code supporting pytorch 0.3.1 \
+to **Non-Local_pytorch_0.3.1/**.
 
 
 ## Running Steps
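
Update record 5 above summarises the port to pytorch 0.4.1. A rough, non-authoritative sketch of the kind of API changes this involves is shown below; the 0.3.x lines are hypothetical "before" code, while the 0.4 calls (torch.set_grad_enabled, Tensor.item, nn.init.constant_) are the ones actually used in demo_MNIST.py and lib/ in this commit.

import torch
from torch import nn

# pytorch 0.3.x style (hypothetical "before" code, not taken from this repo):
#   img = Variable(img_batch, volatile=True)   # no-grad evaluation
#   running_loss += loss.data[0]               # read a Python scalar
#   nn.init.constant(m.weight, 0)              # init without trailing underscore

# pytorch 0.4.1 style, as used in the files added by this commit:
layer = nn.Linear(10, 2)              # stand-in module, for illustration only
nn.init.constant_(layer.weight, 0)    # in-place init functions now end with "_"

x = torch.randn(4, 10)                # Tensor and Variable are merged in 0.4
torch.set_grad_enabled(False)         # replaces volatile=True for evaluation
out = layer(x)
loss_value = out.sum().item()         # .item() replaces loss.data[0]
torch.set_grad_enabled(True)          # re-enable autograd for training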

demo_MNIST.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
import torch
import torch.utils.data as Data
import torchvision
from lib.network import Network
from torch import nn
import time


def calc_acc(x, y):
    # take the arg-max class and compare with the labels;
    # cast to float so the mean is not computed with integer division
    x = torch.max(x, dim=-1)[1]
    accuracy = (x == y).float().mean().item()
    return accuracy


train_data = torchvision.datasets.MNIST(root='./mnist', train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
test_data = torchvision.datasets.MNIST(root='./mnist/',
                                       transform=torchvision.transforms.ToTensor(),
                                       train=False)

train_loader = Data.DataLoader(dataset=train_data, batch_size=128, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=128, shuffle=False)

train_batch_num = len(train_loader)
test_batch_num = len(test_loader)

net = Network()
if torch.cuda.is_available():
    net = nn.DataParallel(net)
    net.cuda()

opt = torch.optim.Adam(net.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

epoch_num = 20
for epoch_index in range(epoch_num):
    st = time.time()

    # training phase (pytorch 0.4 API: no Variable/volatile wrappers needed)
    torch.set_grad_enabled(True)
    net.train()
    for train_batch_index, (img_batch, label_batch) in enumerate(train_loader):
        if torch.cuda.is_available():
            img_batch = img_batch.cuda()
            label_batch = label_batch.cuda()

        predict = net(img_batch)
        acc = calc_acc(predict.cpu().data, label_batch.cpu().data)
        loss = loss_func(predict, label_batch)

        net.zero_grad()
        loss.backward()
        opt.step()

    print('(LR:%f) Time of an epoch:%.4fs' % (opt.param_groups[0]['lr'], time.time() - st))

    # evaluation phase
    torch.set_grad_enabled(False)
    net.eval()
    total_loss = []
    total_acc = 0
    total_sample = 0

    for test_batch_index, (img_batch, label_batch) in enumerate(test_loader):
        if torch.cuda.is_available():
            img_batch = img_batch.cuda()
            label_batch = label_batch.cuda()

        predict = net(img_batch)
        loss = loss_func(predict, label_batch)

        predict = predict.argmax(dim=1)
        acc = (predict == label_batch).sum().item()

        total_loss.append(loss)
        total_acc += acc
        total_sample += img_batch.size(0)

    net.train()

    mean_acc = total_acc * 1.0 / total_sample
    mean_loss = sum(total_loss) / len(total_loss)

    print('[Test] epoch[%d/%d] acc:%.4f%% loss:%.4f\n'
          % (epoch_index, epoch_num, mean_acc * 100, mean_loss.item()))

lib/network.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D


class Network(nn.Module):
    # small MNIST classifier with a non-local block inserted after the first and
    # second pooling stages; switch the import above to try a different variant
    def __init__(self):
        super(Network, self).__init__()

        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            NONLocalBlock2D(in_channels=32),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            NONLocalBlock2D(in_channels=64),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            # a 28x28 input shrinks to 3x3 after three 2x2 max-pools (28 -> 14 -> 7 -> 3)
            nn.Linear(in_features=128 * 3 * 3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(in_features=256, out_features=10)
        )

    def forward(self, x):
        batch_size = x.size(0)
        output = self.convs(x).view(batch_size, -1)
        output = self.fc(output)
        return output


if __name__ == '__main__':
    import torch

    img = torch.randn(3, 1, 28, 28)
    net = Network()
    out = net(img)
    print(out.size())
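
The commented imports at the top of lib/network.py are the intended switch between the four pairwise variants. A minimal sketch, assuming the lib/ modules added in this commit, of selecting the concatenation block (README update record 1) instead of the embedded-Gaussian default; lib/non_local_concatenation.py below exposes the same NONLocalBlock2D name with the constructor (in_channels, inter_channels=None, sub_sample=True, bn_layer=True).

# swap the variant by changing the import; Network itself stays unchanged
from lib.non_local_concatenation import NONLocalBlock2D

block = NONLocalBlock2D(in_channels=32)   # drop-in replacement inside Network.convs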

lib/non_local_concatenation.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())
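
For reference, the pairwise term built by concat_project above corresponds to the concatenation version of the non-local operation described in Wang et al., "Non-local Neural Networks" (CVPR 2018). In that notation, with [·,·] denoting channel-wise concatenation and N the number of positions (the f_div_C normalisation in the code), the block roughly computes:

f(x_i, x_j) = \mathrm{ReLU}\left( w_f^{\top} \left[ \theta(x_i), \, \phi(x_j) \right] \right), \qquad
y_i = \frac{1}{N} \sum_{j} f(x_i, x_j) \, g(x_j), \qquad
z_i = W_z \, y_i + x_i

Because W_z is zero-initialised above (nn.init.constant_), the block starts out as an identity mapping through the residual connection, so it can be inserted into a pretrained or freshly built network without disturbing its initial behaviour.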
