Skip to content

Commit 68f4609

Browse files
committed
added scatter_mul back in
1 parent 1272272 commit 68f4609

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

csrc/scatter.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,37 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
7070
}
7171
};
7272

73+
class ScatterMul : public torch::autograd::Function<ScatterMul> {
74+
public:
75+
static variable_list forward(AutogradContext *ctx, Variable src,
76+
Variable index, int64_t dim,
77+
torch::optional<Variable> optional_out,
78+
torch::optional<int64_t> dim_size) {
79+
dim = dim < 0 ? src.dim() + dim : dim;
80+
ctx->saved_data["dim"] = dim;
81+
ctx->saved_data["src_shape"] = src.sizes();
82+
index = broadcast(index, src, dim);
83+
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
84+
auto out = std::get<0>(result);
85+
ctx->save_for_backward({src, index, out});
86+
if (optional_out.has_value())
87+
ctx->mark_dirty({optional_out.value()});
88+
return {out};
89+
}
90+
91+
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
92+
auto grad_out = grad_outs[0];
93+
auto saved = ctx->get_saved_variables();
94+
auto src = saved[0];
95+
auto index = saved[1];
96+
auto out = saved[2];
97+
auto dim = ctx->saved_data["dim"].toInt();
98+
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
99+
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
100+
return {grad_in, Variable(), Variable(), Variable(), Variable()};
101+
}
102+
};
103+
73104
class ScatterMean : public torch::autograd::Function<ScatterMean> {
74105
public:
75106
static variable_list forward(AutogradContext *ctx, Variable src,
@@ -197,6 +228,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
197228
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
198229
}
199230

231+
torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
232+
torch::optional<torch::Tensor> optional_out,
233+
torch::optional<int64_t> dim_size) {
234+
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
235+
}
236+
200237
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
201238
torch::optional<torch::Tensor> optional_out,
202239
torch::optional<int64_t> dim_size) {
@@ -221,6 +258,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
221258

222259
static auto registry = torch::RegisterOperators()
223260
.op("torch_scatter::scatter_sum", &scatter_sum)
261+
.op("torch_scatter::scatter_mul", &scatter_mul)
224262
.op("torch_scatter::scatter_mean", &scatter_mean)
225263
.op("torch_scatter::scatter_min", &scatter_min)
226264
.op("torch_scatter::scatter_max", &scatter_max);

test/test_scatter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
from .utils import reductions, tensor, dtypes, devices
99

10+
reductions = reductions + ['mul']
11+
1012
tests = [
1113
{
1214
'src': [1, 3, 2, 4, 5, 6],
1315
'index': [0, 1, 0, 1, 1, 3],
1416
'dim': 0,
1517
'sum': [3, 12, 0, 6],
1618
'add': [3, 12, 0, 6],
19+
'mul': [2, 60, 1, 6],
1720
'mean': [1.5, 4, 0, 6],
1821
'min': [1, 3, 0, 6],
1922
'arg_min': [0, 1, 6, 5],
@@ -26,6 +29,7 @@
2629
'dim': 0,
2730
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
2831
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
32+
'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]],
2933
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
3034
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
3135
'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
@@ -38,6 +42,7 @@
3842
'dim': 1,
3943
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
4044
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
45+
'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]],
4146
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
4247
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
4348
'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
@@ -50,6 +55,7 @@
5055
'dim': 1,
5156
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
5257
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
58+
'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]],
5359
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
5460
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
5561
'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
@@ -62,6 +68,7 @@
6268
'dim': 1,
6369
'sum': [[4], [6]],
6470
'add': [[4], [6]],
71+
'mul': [[3], [8]],
6572
'mean': [[2], [3]],
6673
'min': [[1], [2]],
6774
'arg_min': [[0], [0]],
@@ -74,6 +81,7 @@
7481
'dim': 1,
7582
'sum': [[[4, 4]], [[6, 6]]],
7683
'add': [[[4, 4]], [[6, 6]]],
84+
'mul': [[[3, 3]], [[8, 8]]],
7785
'mean': [[[2, 2]], [[3, 3]]],
7886
'min': [[[1, 1]], [[2, 2]]],
7987
'arg_min': [[[0, 0]], [[0, 0]]],
@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device):
125133

126134
if reduce == 'sum' or reduce == 'add':
127135
expected = expected - 2
136+
elif reduce == 'mul':
137+
expected = out # We can not really test this here.
128138
elif reduce == 'mean':
129139
expected = out # We can not really test this here.
130140
elif reduce == 'min':

torch_scatter/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@
5858
f'{major}.{minor}. Please reinstall the torch_scatter that '
5959
f'matches your PyTorch install.')
6060

61-
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
62-
scatter_max, scatter) # noqa
61+
from .scatter import (scatter_sum, scatter_add, scatter_mul, scatter_mean,
62+
scatter_min, scatter_max, scatter) # noqa
6363
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
6464
segment_min_csr, segment_max_csr, segment_csr,
6565
gather_csr) # noqa
@@ -72,6 +72,7 @@
7272
__all__ = [
7373
'scatter_sum',
7474
'scatter_add',
75+
'scatter_mul',
7576
'scatter_mean',
7677
'scatter_min',
7778
'scatter_max',

torch_scatter/scatter.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3131
return scatter_sum(src, index, dim, out, dim_size)
3232

3333

34+
@torch.jit.script
35+
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
36+
out: Optional[torch.Tensor] = None,
37+
dim_size: Optional[int] = None) -> torch.Tensor:
38+
return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
39+
40+
3441
@torch.jit.script
3542
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3643
out: Optional[torch.Tensor] = None,
@@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
127134
with size :attr:`dim_size` at dimension :attr:`dim`.
128135
If :attr:`dim_size` is not given, a minimal sized output tensor
129136
according to :obj:`index.max() + 1` is returned.
130-
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
131-
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
137+
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
138+
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
132139
133140
:rtype: :class:`Tensor`
134141
@@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
150157
"""
151158
if reduce == 'sum' or reduce == 'add':
152159
return scatter_sum(src, index, dim, out, dim_size)
160+
if reduce == 'mul':
161+
return scatter_mul(src, index, dim, out, dim_size)
153162
elif reduce == 'mean':
154163
return scatter_mean(src, index, dim, out, dim_size)
155164
elif reduce == 'min':

0 commit comments

Comments
 (0)