Skip to content

Commit 335525a

Browse files
committed
fix nan values
1 parent 68f4609 commit 335525a

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

csrc/scatter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class ScatterMul : public torch::autograd::Function<ScatterMul> {
9797
auto dim = ctx->saved_data["dim"].toInt();
9898
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
9999
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
100+
grad_in.masked_fill_(grad_in.isnan(), 0);
100101
return {grad_in, Variable(), Variable(), Variable(), Variable()};
101102
}
102103
};

0 commit comments

Comments
 (0)