Skip to content

Commit 1bb2f2a

Browse files
committed
fix 3dpooling bug and simplify code
1 parent 589dde8 commit 1bb2f2a

18 files changed

+637
-93
lines changed

README.md

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
11
# non-local_pytorch
2-
- Implementation of [**Non-local Neural Networks**](https://arxiv.org/abs/1711.07971).
2+
- Implementation of [**Non-local Neural Block**](https://arxiv.org/abs/1711.07971).
33

44
## Statement
5-
- Only do the experiments on MNIST dataset so far.
6-
- You can find the non-local block in **lib/**.
7-
- The code can support **multi-gpu** now.
5+
- You can find different kinds of non-local block in **lib/**.
6+
- The code is tested on MNIST dataset. You can select the type of non-local block
7+
in **lib/network.py**.
88
- If there is something wrong in my code, please contact me, thanks!
99

10-
There are two version **non-local.py** and **non-local-simple-version.py**.
11-
12-
- **non-local.py** contains the implementation of Gaussian, embedded Gaussian and dot product, which is mainly for learning.
13-
- **non-local-simple-version.py** only contains the implementation of embedded Gaussian.
1410

1511
## Environment
1612
- python 3.6
1713
- pytorch 0.3.0
1814

15+
## Update Records
16+
1. Figure out how to implement the **concatenation** type, and add the code to **lib/**.
17+
2. Fix the bug in **lib/non_local.py** (old version) when using multi-gpu. Someone shares the
18+
reason with me, and you can find it in [here](https://github.com/pytorch/pytorch/issues/8637).
19+
3. Fix the bug of 3D pooling in **lib/non_local.py** (old version). Appreciate
20+
[**protein27**](https://github.com/AlexHex7/Non-local_pytorch/issues/17) for pointing it out.
21+
4. For convenience, I split the **lib/non_local.py** into four python files, and move the
22+
old versions (**lib/non_loca.py** and **lib/non_local_simple_version.py**) into
23+
**lib/backup/**.
24+
25+
26+
## Running Steps
27+
1. Select the type of non-local block in **lib/network.py**.
28+
```
29+
from lib.non_local_concatenation import NONLocalBlock2D
30+
from lib.non_local_gaussian import NONLocalBlock2D
31+
from lib.non_local_embedded_gaussian import NONLocalBlock2D
32+
from lib.non_local_dot_product import NONLocalBlock2D
33+
2. Run **demo_MNIST.py** with one GPU or multi GPU.
34+
```
35+
CUDA_VISIBLE_DEVICES=0,1 python demo_MNIST.py
36+
1937
## Todo
2038
- Experiments on Kinetics dataset.
2139
- Experiments on Charades dataset.
2240
- Experiments on COCO dataset.
23-
- [x] Make sure how to do the Implementation of concatenation.
24-
- [x] Support multi-gpu.
25-
- [x] Fix the bug in **lib/non_local.py** when using multi-gpu (thanks for the person who shares the reason, you can find it in [here](https://github.com/pytorch/pytorch/issues/8637)).

config.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

demo_MNIST.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,27 @@
1-
import logging
21
import torch
32
import torch.utils.data as Data
43
import torchvision
54
from lib.network import Network
65
from torch.autograd import Variable
76
from torch import nn
8-
import config as cfg
9-
from lib.utils import create_architecture
7+
import time
108

119

1210
def calc_acc(x, y):
1311
x = torch.max(x, dim=-1)[1]
1412
accuracy = sum(x == y) / x.size(0)
1513
return accuracy
1614

17-
logging.getLogger().setLevel(logging.INFO)
18-
19-
create_architecture()
2015

2116
train_data = torchvision.datasets.MNIST(root='./mnist', train=True,
2217
transform=torchvision.transforms.ToTensor(),
2318
download=True)
2419
test_data = torchvision.datasets.MNIST(root='./mnist/',
2520
transform=torchvision.transforms.ToTensor(),
2621
train=False)
27-
train_loader = Data.DataLoader(dataset=train_data, batch_size=cfg.batch_size, shuffle=True)
28-
test_loader = Data.DataLoader(dataset=test_data, batch_size=cfg.batch_size, shuffle=False)
22+
23+
train_loader = Data.DataLoader(dataset=train_data, batch_size=128, shuffle=True)
24+
test_loader = Data.DataLoader(dataset=test_data, batch_size=128, shuffle=False)
2925

3026
train_batch_num = len(train_loader)
3127
test_batch_num = len(test_loader)
@@ -35,20 +31,19 @@ def calc_acc(x, y):
3531
net = nn.DataParallel(net)
3632
net.cuda()
3733

38-
opt = torch.optim.Adam(net.parameters(), lr=cfg.LR, weight_decay=cfg.weight_decay)
34+
opt = torch.optim.Adam(net.parameters(), lr=0.001)
3935
loss_func = nn.CrossEntropyLoss()
4036

41-
if cfg.load_model:
42-
net.load_state_dict(torch.load(cfg.model_path))
4337

44-
for epoch_index in range(cfg.epoch):
38+
for epoch_index in range(20):
39+
st = time.time()
4540
for train_batch_index, (img_batch, label_batch) in enumerate(train_loader):
4641
img_batch = Variable(img_batch)
4742
label_batch = Variable(label_batch)
4843

4944
if torch.cuda.is_available():
50-
img_batch = img_batch.cuda(cfg.cuda_num)
51-
label_batch = label_batch.cuda(cfg.cuda_num)
45+
img_batch = img_batch.cuda()
46+
label_batch = label_batch.cuda()
5247

5348
predict = net(img_batch)
5449
acc = calc_acc(predict.cpu().data, label_batch.cpu().data)
@@ -58,15 +53,9 @@ def calc_acc(x, y):
5853
loss.backward()
5954
opt.step()
6055

61-
# logging.info('epoch[%d/%d] batch[%d/%d] loss:%.4f acc:%.4f' %
62-
# (epoch_index, cfg.epoch, train_batch_index, train_batch_num, loss.data[0], acc))
63-
64-
opt.param_groups[0]['lr'] = cfg.LR * (cfg.LR_decay_rate ** (epoch_index // cfg.LR_decay_every_epoch))
65-
print('LR', opt.param_groups[0]['lr'])
66-
# if (train_batch_index + 1) % cfg.test_per_batch == 0:
56+
print('(LR:%f) Time of a epoch:%.4fs' % (opt.param_groups[0]['lr'], time.time()-st))
6757

6858
net.eval()
69-
7059
total_loss = 0
7160
total_acc = 0
7261

@@ -75,8 +64,8 @@ def calc_acc(x, y):
7564
label_batch = Variable(label_batch, volatile=True)
7665

7766
if torch.cuda.is_available():
78-
img_batch = img_batch.cuda(cfg.cuda_num)
79-
label_batch = label_batch.cuda(cfg.cuda_num)
67+
img_batch = img_batch.cuda()
68+
label_batch = label_batch.cuda()
8069

8170
predict = net(img_batch)
8271
acc = calc_acc(predict.cpu().data, label_batch.cpu().data)
@@ -89,7 +78,6 @@ def calc_acc(x, y):
8978

9079
mean_acc = total_acc / test_batch_num
9180
mean_loss = total_loss / test_batch_num
92-
logging.info('[Test] epoch[%d/%d] acc:%.4f loss:%.4f '
93-
% (epoch_index, cfg.epoch, mean_acc, mean_loss.data[0]))
9481

95-
torch.save(net.state_dict(), cfg.model_path)
82+
print('[Test] epoch[%d/%d] acc:%.4f loss:%.4f\n'
83+
% (epoch_index, 100, mean_acc, mean_loss.data[0]))
-3.59 KB
Binary file not shown.
Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded
2727

2828
if dimension == 3:
2929
conv_nd = nn.Conv3d
30-
max_pool = nn.MaxPool3d
30+
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
3131
bn = nn.BatchNorm3d
3232
elif dimension == 2:
3333
conv_nd = nn.Conv2d
34-
max_pool = nn.MaxPool2d
34+
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
3535
bn = nn.BatchNorm2d
3636
else:
3737
conv_nd = nn.Conv1d
38-
max_pool = nn.MaxPool1d
38+
max_pool_layer = nn.MaxPool1d(kernel_size=(2))
3939
bn = nn.BatchNorm1d
4040

4141
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
@@ -91,11 +91,13 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded
9191
# self.operation_function = self._gaussian
9292

9393
if sub_sample:
94-
self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
94+
self.g = nn.Sequential(self.g, max_pool_layer)
9595
if self.phi is None:
96-
self.phi = max_pool(kernel_size=2)
96+
self.phi = max_pool_layer
9797
else:
98-
self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
98+
self.phi = nn.Sequential(self.phi, max_pool_layer)
99+
100+
print(self.phi)
99101

100102
def forward(self, x):
101103
'''
@@ -141,20 +143,25 @@ def _embedded_gaussian(self, x):
141143

142144
def _gaussian(self, x):
143145
batch_size = x.size(0)
146+
144147
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
148+
145149
g_x = g_x.permute(0, 2, 1)
146150

147151
theta_x = x.view(batch_size, self.in_channels, -1)
148152
theta_x = theta_x.permute(0, 2, 1)
149153

150154
if self.sub_sample:
155+
print(self.phi(x).size())
151156
phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
152157
else:
153158
phi_x = x.view(batch_size, self.in_channels, -1)
154-
159+
print(phi_x.size())
155160
f = torch.matmul(theta_x, phi_x)
156161
f_div_C = F.softmax(f, dim=-1)
157162

163+
print(f_div_C.size(), g_x.size())
164+
158165
y = torch.matmul(f_div_C, g_x)
159166
y = y.permute(0, 2, 1).contiguous()
160167
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
@@ -247,23 +254,23 @@ def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', s
247254
if __name__ == '__main__':
248255
from torch.autograd import Variable
249256

250-
mode_list = ['concatenation', 'embedded_gaussian', 'gaussian', 'dot_product', ]
251-
# mode_list = ['concatenation']
257+
# mode_list = ['concatenation', 'embedded_gaussian', 'gaussian', 'dot_product', ]
258+
mode_list = ['gaussian']
252259

253260
for mode in mode_list:
254261
print(mode)
255-
img = Variable(torch.zeros(2, 4, 5))
256-
net = NONLocalBlock1D(4, mode=mode, sub_sample=True)
262+
img = Variable(torch.zeros(2, 6, 20))
263+
net = NONLocalBlock1D(6, mode=mode, sub_sample=True)
257264
out = net(img)
258265
print(out.size())
259266

260-
img = Variable(torch.zeros(2, 4, 10, 10))
261-
net = NONLocalBlock2D(4, mode=mode, sub_sample=False, bn_layer=False)
262-
out = net(img)
263-
print(out.size())
264-
265-
img = Variable(torch.zeros(2, 4, 5, 4, 5))
266-
net = NONLocalBlock3D(4, mode=mode)
267-
out = net(img)
268-
print(out.size())
267+
# img = Variable(torch.zeros(2, 4, 20, 20))
268+
# net = NONLocalBlock2D(4, mode=mode, sub_sample=False, bn_layer=False)
269+
# out = net(img)
270+
# print(out.size())
271+
#
272+
# img = Variable(torch.zeros(2, 4, 10, 20, 20))
273+
# net = NONLocalBlock3D(4, mode=mode)
274+
# out = net(img)
275+
# print(out.size())
269276

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=Tru
2222

2323
if dimension == 3:
2424
conv_nd = nn.Conv3d
25-
max_pool = nn.MaxPool3d
25+
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
2626
bn = nn.BatchNorm3d
2727
elif dimension == 2:
2828
conv_nd = nn.Conv2d
29-
max_pool = nn.MaxPool2d
29+
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
3030
bn = nn.BatchNorm2d
3131
else:
3232
conv_nd = nn.Conv1d
33-
max_pool = nn.MaxPool1d
33+
max_pool_layer = nn.MaxPool1d(kernel_size=(2))
3434
bn = nn.BatchNorm1d
3535

3636
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
@@ -56,8 +56,8 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=Tru
5656
kernel_size=1, stride=1, padding=0)
5757

5858
if sub_sample:
59-
self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
60-
self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
59+
self.g = nn.Sequential(self.g, max_pool_layer)
60+
self.phi = nn.Sequential(self.phi, max_pool_layer)
6161

6262
def forward(self, x):
6363
'''
@@ -112,20 +112,22 @@ def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=T
112112
if __name__ == '__main__':
113113
from torch.autograd import Variable
114114
import torch
115-
sub_sample = False
116115

117-
img = Variable(torch.zeros(2, 4, 5))
118-
net = NONLocalBlock1D(4, sub_sample=sub_sample, bn_layer=False)
116+
sub_sample = True
117+
bn_layer = True
118+
119+
img = Variable(torch.zeros(2, 3, 20))
120+
net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
119121
out = net(img)
120122
print(out.size())
121123

122-
img = Variable(torch.zeros(2, 4, 5, 3))
123-
net = NONLocalBlock2D(4, sub_sample=sub_sample)
124+
img = Variable(torch.zeros(2, 3, 20, 20))
125+
net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
124126
out = net(img)
125127
print(out.size())
126128

127-
img = Variable(torch.zeros(2, 4, 5, 4, 5))
128-
net = NONLocalBlock3D(4, sub_sample=sub_sample)
129+
img = Variable(torch.randn(2, 3, 10, 20, 20))
130+
net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
129131
out = net(img)
130132
print(out.size())
131133

lib/network.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from torch import nn
2-
from lib.non_local_simple_version import NONLocalBlock2D
3-
# from lib.non_local import NONLocalBlock2D
2+
# from lib.non_local_concatenation import NONLocalBlock2D
3+
# from lib.non_local_gaussian import NONLocalBlock2D
4+
from lib.non_local_embedded_gaussian import NONLocalBlock2D
5+
# from lib.non_local_dot_product import NONLocalBlock2D
46

57

68
class Network(nn.Module):

0 commit comments

Comments
 (0)