@@ -6,16 +6,16 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
66 int64_t dim) {
77 int64_t elems_per_row = index.size (dim), i, idx;
88 printf (" elems_per_row: %lli\n " , elems_per_row);
9- AT_DISPATCH_ALL_TYPES (src.type (), " scatter_mul" , [&] {
10- DIM_APPLY3 (scalar_t , src, int64_t , index, scalar_t , out, dim, {
11- for (i = 0 ; i < elems_per_row; i++) {
12- idx = index_data[i * index_stride];
13- printf (" i: %lli, idx: %lli\n " , i, idx);
14- printf (" src: %lli\n " , (int64_t )src_data[i * src_stride]);
15- out_data[idx * out_stride] *= src_data[i * src_stride];
16- }
17- });
18- });
9+ // AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
10+ // DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
11+ // for (i = 0; i < elems_per_row; i++) {
12+ // idx = index_data[i * index_stride];
13+ // printf("i: %lli, idx: %lli\n", i, idx);
14+ // printf("src: %lli\n", (int64_t)src_data[i * src_stride]);
15+ // out_data[idx * out_stride] *= src_data[i * src_stride];
16+ // }
17+ // });
18+ // });
1919}
2020
2121void scatter_div (at::Tensor src, at::Tensor index, at::Tensor out,
0 commit comments