66
77from .utils import dtypes , devices , tensor
88
9- dtypes = [torch .float ]
10-
119tests = [{
12- # 'name': 'add',
13- # 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
14- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
15- # 'dim': -1,
16- # 'fill_value': 0,
17- # 'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
18- # }, {
19- # 'name': 'add',
20- # 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
21- # 'index': [0, 1, 1, 0],
22- # 'dim': 0,
23- # 'fill_value': 0,
24- # 'expected': [[6, 5], [6, 8]],
25- # }, {
26- # 'name': 'sub',
27- # 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
28- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
29- # 'dim': -1,
30- # 'fill_value': 9,
31- # 'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
32- # }, {
33- # 'name': 'sub',
34- # 'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
35- # 'index': [0, 1, 1, 0],
36- # 'dim': 0,
37- # 'fill_value': 9,
38- # 'expected': [[3, 4], [3, 5]],
39- # }, {
10+ 'name' : 'add' ,
11+ 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
12+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
13+ 'dim' : - 1 ,
14+ 'fill_value' : 0 ,
15+ 'expected' : [[0 , 0 , 4 , 3 , 3 , 0 ], [2 , 4 , 4 , 0 , 0 , 0 ]],
16+ }, {
17+ 'name' : 'add' ,
18+ 'src' : [[5 , 2 ], [2 , 5 ], [4 , 3 ], [1 , 3 ]],
19+ 'index' : [0 , 1 , 1 , 0 ],
20+ 'dim' : 0 ,
21+ 'fill_value' : 0 ,
22+ 'expected' : [[6 , 5 ], [6 , 8 ]],
23+ }, {
24+ 'name' : 'sub' ,
25+ 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
26+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
27+ 'dim' : - 1 ,
28+ 'fill_value' : 9 ,
29+ 'expected' : [[9 , 9 , 5 , 6 , 6 , 9 ], [7 , 5 , 5 , 9 , 9 , 9 ]],
30+ }, {
31+ 'name' : 'sub' ,
32+ 'src' : [[5 , 2 ], [2 , 2 ], [4 , 2 ], [1 , 3 ]],
33+ 'index' : [0 , 1 , 1 , 0 ],
34+ 'dim' : 0 ,
35+ 'fill_value' : 9 ,
36+ 'expected' : [[3 , 4 ], [3 , 5 ]],
37+ }, {
4038 'name' : 'mul' ,
4139 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
4240 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
4341 'dim' : - 1 ,
4442 'fill_value' : 1 ,
4543 'expected' : [[1 , 1 , 4 , 3 , 2 , 0 ], [0 , 4 , 3 , 1 , 1 , 1 ]],
46- # }, {
47- # 'name': 'mul',
48- # 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
49- # 'index': [0, 1, 1, 0],
50- # 'dim': 0,
51- # 'fill_value': 1,
52- # 'expected': [[5, 6], [8, 15]],
53- # }, {
54- # 'name': 'div',
55- # 'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
56- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
57- # 'dim': -1,
58- # 'fill_value': 1,
59- # 'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
60- # }, {
61- # 'name': 'div',
62- # 'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
63- # 'index': [0, 1, 1, 0],
64- # 'dim': 0,
65- # 'fill_value': 1,
66- # 'expected': [[0.25, 0.25], [0.125, 0.5]],
67- # }, {
68- # 'name': 'mean',
69- # 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
70- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
71- # 'dim': -1,
72- # 'fill_value': 0,
73- # 'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
74- # }, {
75- # 'name': 'mean',
76- # 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
77- # 'index': [0, 1, 1, 0],
78- # 'dim': 0,
79- # 'fill_value': 0,
80- # 'expected': [[3, 2.5], [3, 4]],
81- # }, {
82- # 'name': 'max',
83- # 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
84- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
85- # 'dim': -1,
86- # 'fill_value': 0,
87- # 'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
88- # 'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
89- # }, {
90- # 'name': 'max',
91- # 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
92- # 'index': [0, 1, 1, 0],
93- # 'dim': 0,
94- # 'fill_value': 0,
95- # 'expected': [[5, 3], [4, 5]],
96- # 'expected_arg': [[0, 3], [2, 1]],
97- # }, {
98- # 'name': 'min',
99- # 'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
100- # 'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
101- # 'dim': -1,
102- # 'fill_value': 9,
103- # 'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
104- # 'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
105- # }, {
106- # 'name': 'min',
107- # 'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
108- # 'index': [0, 1, 1, 0],
109- # 'dim': 0,
110- # 'fill_value': 9,
111- # 'expected': [[1, 2], [2, 3]],
112- # 'expected_arg': [[3, 0], [1, 2]],
44+ }, {
45+ 'name' : 'mul' ,
46+ 'src' : [[5 , 2 ], [2 , 5 ], [4 , 3 ], [1 , 3 ]],
47+ 'index' : [0 , 1 , 1 , 0 ],
48+ 'dim' : 0 ,
49+ 'fill_value' : 1 ,
50+ 'expected' : [[5 , 6 ], [8 , 15 ]],
51+ }, {
52+ 'name' : 'div' ,
53+ 'src' : [[2 , 1 , 1 , 4 , 2 ], [1 , 2 , 1 , 2 , 4 ]],
54+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
55+ 'dim' : - 1 ,
56+ 'fill_value' : 1 ,
57+ 'expected' : [[1 , 1 , 0.25 , 0.5 , 0.5 , 1 ], [0.5 , 0.25 , 0.5 , 1 , 1 , 1 ]],
58+ }, {
59+ 'name' : 'div' ,
60+ 'src' : [[4 , 2 ], [2 , 1 ], [4 , 2 ], [1 , 2 ]],
61+ 'index' : [0 , 1 , 1 , 0 ],
62+ 'dim' : 0 ,
63+ 'fill_value' : 1 ,
64+ 'expected' : [[0.25 , 0.25 ], [0.125 , 0.5 ]],
65+ }, {
66+ 'name' : 'mean' ,
67+ 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
68+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
69+ 'dim' : - 1 ,
70+ 'fill_value' : 0 ,
71+ 'expected' : [[0 , 0 , 4 , 3 , 1.5 , 0 ], [1 , 4 , 2 , 0 , 0 , 0 ]],
72+ }, {
73+ 'name' : 'mean' ,
74+ 'src' : [[5 , 2 ], [2 , 5 ], [4 , 3 ], [1 , 3 ]],
75+ 'index' : [0 , 1 , 1 , 0 ],
76+ 'dim' : 0 ,
77+ 'fill_value' : 0 ,
78+ 'expected' : [[3 , 2.5 ], [3 , 4 ]],
79+ }, {
80+ 'name' : 'max' ,
81+ 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
82+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
83+ 'dim' : - 1 ,
84+ 'fill_value' : 0 ,
85+ 'expected' : [[0 , 0 , 4 , 3 , 2 , 0 ], [2 , 4 , 3 , 0 , 0 , 0 ]],
86+ 'expected_arg' : [[- 1 , - 1 , 3 , 4 , 0 , 1 ], [1 , 4 , 3 , - 1 , - 1 , - 1 ]],
87+ }, {
88+ 'name' : 'max' ,
89+ 'src' : [[5 , 2 ], [2 , 5 ], [4 , 3 ], [1 , 3 ]],
90+ 'index' : [0 , 1 , 1 , 0 ],
91+ 'dim' : 0 ,
92+ 'fill_value' : 0 ,
93+ 'expected' : [[5 , 3 ], [4 , 5 ]],
94+ 'expected_arg' : [[0 , 3 ], [2 , 1 ]],
95+ }, {
96+ 'name' : 'min' ,
97+ 'src' : [[2 , 0 , 1 , 4 , 3 ], [0 , 2 , 1 , 3 , 4 ]],
98+ 'index' : [[4 , 5 , 4 , 2 , 3 ], [0 , 0 , 2 , 2 , 1 ]],
99+ 'dim' : - 1 ,
100+ 'fill_value' : 9 ,
101+ 'expected' : [[9 , 9 , 4 , 3 , 1 , 0 ], [0 , 4 , 1 , 9 , 9 , 9 ]],
102+ 'expected_arg' : [[- 1 , - 1 , 3 , 4 , 2 , 1 ], [0 , 4 , 2 , - 1 , - 1 , - 1 ]],
103+ }, {
104+ 'name' : 'min' ,
105+ 'src' : [[5 , 2 ], [2 , 5 ], [4 , 3 ], [1 , 3 ]],
106+ 'index' : [0 , 1 , 1 , 0 ],
107+ 'dim' : 0 ,
108+ 'fill_value' : 9 ,
109+ 'expected' : [[1 , 2 ], [2 , 3 ]],
110+ 'expected_arg' : [[3 , 0 ], [1 , 2 ]],
113111}]
114112
115113
@@ -118,15 +116,12 @@ def test_forward(test, dtype, device):
118116 src = tensor (test ['src' ], dtype , device )
119117 index = tensor (test ['index' ], torch .long , device )
120118 expected = tensor (test ['expected' ], dtype , device )
121- print (src )
122- print (index )
123119
124120 op = getattr (torch_scatter , 'scatter_{}' .format (test ['name' ]))
125121 out = op (src , index , test ['dim' ], fill_value = test ['fill_value' ])
126- print (out )
127122
128- # if isinstance(out, tuple):
129- # assert out[0].tolist() == expected.tolist()
130- # assert out[1].tolist() == test['expected_arg']
131- # else:
132- # assert out.tolist() == expected.tolist()
123+ if isinstance (out , tuple ):
124+ assert out [0 ].tolist () == expected .tolist ()
125+ assert out [1 ].tolist () == test ['expected_arg' ]
126+ else :
127+ assert out .tolist () == expected .tolist ()
0 commit comments