Skip to content

Commit c1ac176

Browse files
committed
reset
1 parent 5909acc commit c1ac176

File tree

4 files changed

+125
-139
lines changed

4 files changed

+125
-139
lines changed

cpu/scatter.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44

55
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
66
int64_t dim) {
7-
printf("HUHUHUHU");
8-
// int64_t elems_per_row = index.size(dim), i, idx;
9-
// printf("elems_per_row: %lli\n", elems_per_row);
10-
// AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
11-
// DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
12-
// for (i = 0; i < elems_per_row; i++) {
13-
// idx = index_data[i * index_stride];
14-
// printf("i: %lli, idx: %lli\n", i, idx);
15-
// printf("src: %lli\n", (int64_t)src_data[i * src_stride]);
16-
// out_data[idx * out_stride] *= src_data[i * src_stride];
17-
// }
18-
// });
19-
// });
7+
int64_t elems_per_row = index.size(dim), i, idx;
8+
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
9+
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
10+
for (i = 0; i < elems_per_row; i++) {
11+
idx = index_data[i * index_stride];
12+
out_data[idx * out_stride] *= src_data[i * src_stride];
13+
}
14+
});
15+
});
2016
}
2117

2218
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,

test/test_backward.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
# from torch.autograd import gradcheck
5+
from torch.autograd import gradcheck
66
import torch_scatter
77

88
from .utils import grad_dtypes as dtypes, devices, tensor
@@ -13,14 +13,13 @@
1313

1414
@pytest.mark.parametrize('func,device', product(funcs, devices))
1515
def test_backward(func, device):
16-
pass
17-
# index = torch.tensor(indices, dtype=torch.long, device=device)
18-
# src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
19-
# src.requires_grad_()
16+
index = torch.tensor(indices, dtype=torch.long, device=device)
17+
src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
18+
src.requires_grad_()
2019

21-
# op = getattr(torch_scatter, 'scatter_{}'.format(func))
22-
# data = (src, index, 0)
23-
# assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
20+
op = getattr(torch_scatter, 'scatter_{}'.format(func))
21+
data = (src, index, 0)
22+
assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
2423

2524

2625
tests = [{
@@ -44,13 +43,12 @@ def test_backward(func, device):
4443

4544
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
4645
def test_arg_backward(test, dtype, device):
47-
pass
48-
# src = tensor(test['src'], dtype, device)
49-
# src.requires_grad_()
50-
# index = tensor(test['index'], torch.long, device)
51-
# grad = tensor(test['grad'], dtype, device)
52-
53-
# op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
54-
# out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
55-
# out.backward(grad)
56-
# assert src.grad.tolist() == test['expected']
46+
src = tensor(test['src'], dtype, device)
47+
src.requires_grad_()
48+
index = tensor(test['index'], torch.long, device)
49+
grad = tensor(test['grad'], dtype, device)
50+
51+
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
52+
out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
53+
out.backward(grad)
54+
assert src.grad.tolist() == test['expected']

test/test_forward.py

Lines changed: 100 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,110 +6,108 @@
66

77
from .utils import dtypes, devices, tensor
88

9-
dtypes = [torch.float]
10-
119
tests = [{
12-
# 'name': 'add',
13-
# 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
14-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
15-
# 'dim': -1,
16-
# 'fill_value': 0,
17-
# 'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
18-
# }, {
19-
# 'name': 'add',
20-
# 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
21-
# 'index': [0, 1, 1, 0],
22-
# 'dim': 0,
23-
# 'fill_value': 0,
24-
# 'expected': [[6, 5], [6, 8]],
25-
# }, {
26-
# 'name': 'sub',
27-
# 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
28-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
29-
# 'dim': -1,
30-
# 'fill_value': 9,
31-
# 'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
32-
# }, {
33-
# 'name': 'sub',
34-
# 'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
35-
# 'index': [0, 1, 1, 0],
36-
# 'dim': 0,
37-
# 'fill_value': 9,
38-
# 'expected': [[3, 4], [3, 5]],
39-
# }, {
10+
'name': 'add',
11+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
12+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
13+
'dim': -1,
14+
'fill_value': 0,
15+
'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
16+
}, {
17+
'name': 'add',
18+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
19+
'index': [0, 1, 1, 0],
20+
'dim': 0,
21+
'fill_value': 0,
22+
'expected': [[6, 5], [6, 8]],
23+
}, {
24+
'name': 'sub',
25+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
26+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
27+
'dim': -1,
28+
'fill_value': 9,
29+
'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
30+
}, {
31+
'name': 'sub',
32+
'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
33+
'index': [0, 1, 1, 0],
34+
'dim': 0,
35+
'fill_value': 9,
36+
'expected': [[3, 4], [3, 5]],
37+
}, {
4038
'name': 'mul',
4139
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
4240
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
4341
'dim': -1,
4442
'fill_value': 1,
4543
'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]],
46-
# }, {
47-
# 'name': 'mul',
48-
# 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
49-
# 'index': [0, 1, 1, 0],
50-
# 'dim': 0,
51-
# 'fill_value': 1,
52-
# 'expected': [[5, 6], [8, 15]],
53-
# }, {
54-
# 'name': 'div',
55-
# 'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
56-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
57-
# 'dim': -1,
58-
# 'fill_value': 1,
59-
# 'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
60-
# }, {
61-
# 'name': 'div',
62-
# 'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
63-
# 'index': [0, 1, 1, 0],
64-
# 'dim': 0,
65-
# 'fill_value': 1,
66-
# 'expected': [[0.25, 0.25], [0.125, 0.5]],
67-
# }, {
68-
# 'name': 'mean',
69-
# 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
70-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
71-
# 'dim': -1,
72-
# 'fill_value': 0,
73-
# 'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
74-
# }, {
75-
# 'name': 'mean',
76-
# 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
77-
# 'index': [0, 1, 1, 0],
78-
# 'dim': 0,
79-
# 'fill_value': 0,
80-
# 'expected': [[3, 2.5], [3, 4]],
81-
# }, {
82-
# 'name': 'max',
83-
# 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
84-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
85-
# 'dim': -1,
86-
# 'fill_value': 0,
87-
# 'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
88-
# 'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
89-
# }, {
90-
# 'name': 'max',
91-
# 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
92-
# 'index': [0, 1, 1, 0],
93-
# 'dim': 0,
94-
# 'fill_value': 0,
95-
# 'expected': [[5, 3], [4, 5]],
96-
# 'expected_arg': [[0, 3], [2, 1]],
97-
# }, {
98-
# 'name': 'min',
99-
# 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
100-
# 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
101-
# 'dim': -1,
102-
# 'fill_value': 9,
103-
# 'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
104-
# 'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
105-
# }, {
106-
# 'name': 'min',
107-
# 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
108-
# 'index': [0, 1, 1, 0],
109-
# 'dim': 0,
110-
# 'fill_value': 9,
111-
# 'expected': [[1, 2], [2, 3]],
112-
# 'expected_arg': [[3, 0], [1, 2]],
44+
}, {
45+
'name': 'mul',
46+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
47+
'index': [0, 1, 1, 0],
48+
'dim': 0,
49+
'fill_value': 1,
50+
'expected': [[5, 6], [8, 15]],
51+
}, {
52+
'name': 'div',
53+
'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
54+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
55+
'dim': -1,
56+
'fill_value': 1,
57+
'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
58+
}, {
59+
'name': 'div',
60+
'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
61+
'index': [0, 1, 1, 0],
62+
'dim': 0,
63+
'fill_value': 1,
64+
'expected': [[0.25, 0.25], [0.125, 0.5]],
65+
}, {
66+
'name': 'mean',
67+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
68+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
69+
'dim': -1,
70+
'fill_value': 0,
71+
'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
72+
}, {
73+
'name': 'mean',
74+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
75+
'index': [0, 1, 1, 0],
76+
'dim': 0,
77+
'fill_value': 0,
78+
'expected': [[3, 2.5], [3, 4]],
79+
}, {
80+
'name': 'max',
81+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
82+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
83+
'dim': -1,
84+
'fill_value': 0,
85+
'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
86+
'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
87+
}, {
88+
'name': 'max',
89+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
90+
'index': [0, 1, 1, 0],
91+
'dim': 0,
92+
'fill_value': 0,
93+
'expected': [[5, 3], [4, 5]],
94+
'expected_arg': [[0, 3], [2, 1]],
95+
}, {
96+
'name': 'min',
97+
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
98+
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
99+
'dim': -1,
100+
'fill_value': 9,
101+
'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
102+
'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
103+
}, {
104+
'name': 'min',
105+
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
106+
'index': [0, 1, 1, 0],
107+
'dim': 0,
108+
'fill_value': 9,
109+
'expected': [[1, 2], [2, 3]],
110+
'expected_arg': [[3, 0], [1, 2]],
113111
}]
114112

115113

@@ -118,15 +116,12 @@ def test_forward(test, dtype, device):
118116
src = tensor(test['src'], dtype, device)
119117
index = tensor(test['index'], torch.long, device)
120118
expected = tensor(test['expected'], dtype, device)
121-
print(src)
122-
print(index)
123119

124120
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
125121
out = op(src, index, test['dim'], fill_value=test['fill_value'])
126-
print(out)
127122

128-
# if isinstance(out, tuple):
129-
# assert out[0].tolist() == expected.tolist()
130-
# assert out[1].tolist() == test['expected_arg']
131-
# else:
132-
# assert out.tolist() == expected.tolist()
123+
if isinstance(out, tuple):
124+
assert out[0].tolist() == expected.tolist()
125+
assert out[1].tolist() == test['expected_arg']
126+
else:
127+
assert out.tolist() == expected.tolist()

torch_scatter/mul.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@
77
class ScatterMul(Function):
88
@staticmethod
99
def forward(ctx, out, src, index, dim):
10-
print("DRIN")
1110
func = get_func('scatter_mul', src)
12-
print(func)
1311
func(src, index, out, dim)
14-
print(out)
1512

1613
ctx.mark_dirty(out)
1714
ctx.save_for_backward(out, src, index)

0 commit comments

Comments
 (0)