
Commit 38c8b3a

cuda kernel
1 parent a2f18da commit 38c8b3a


21 files changed (+605, -555 lines)


cuda/atomics.cuh

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
    uint32_t old = *address_as_ui; \
    uint32_t shift = ((size_t)address & 3) * 8; \
    uint32_t sum; \
    uint32_t assumed; \
\
    do { \
      assumed = old; \
      sum = OP(val, scalar((old >> shift) & 0xff)); \
      old = (old & ~(0x000000ff << shift)) | (sum << shift); \
      old = atomicCAS(address_as_ui, assumed, old); \
    } while (assumed != old); \
  } \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    uint32_t *address_as_ui = \
        (uint32_t *)((char *)address - ((size_t)address & 2)); \
    uint32_t old = *address_as_ui; \
    uint32_t sum; \
    uint32_t newval; \
    uint32_t assumed; \
\
    do { \
      assumed = old; \
      sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
                                        : scalar(old & 0xffff)); \
      newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
                                   : (old & 0xffff0000) | sum; \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    uint32_t *address_as_ui = (uint32_t *)address; \
    uint32_t old = *address_as_ui; \
    uint32_t assumed; \
\
    do { \
      assumed = old; \
      old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
    } while (assumed != old); \
  } \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    unsigned long long *address_as_ull = (unsigned long long *)address; \
    unsigned long long old = *address_as_ull; \
    unsigned long long assumed; \
\
    do { \
      assumed = old; \
      old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
    } while (assumed != old); \
  } \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    int *address_as_i = (int *)address; \
    int old = *address_as_i; \
    int assumed; \
\
    do { \
      assumed = old; \
      old = atomicCAS(address_as_i, assumed, \
                      __float_as_int(OP(val, __int_as_float(assumed)))); \
    } while (assumed != old); \
  } \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
  inline __device__ void operator()(scalar *address, scalar val) { \
    unsigned long long int *address_as_ull = \
        (unsigned long long int *)address; \
    unsigned long long int old = *address_as_ull; \
    unsigned long long int assumed; \
\
    do { \
      assumed = old; \
      old = atomicCAS( \
          address_as_ull, assumed, \
          __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
    } while (assumed != old); \
  } \
};

#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
  atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
  AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
  AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif

#define OP(X, Y) Y * X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
  AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
  AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
  AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
  AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
  AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
  AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
  AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
  AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
  AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
  AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
  AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
  AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
  AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
  AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
  atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
  AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
  AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
  AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
  AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
  AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
  atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
  AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
  AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
  AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
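
For reference, a minimal usage sketch (not part of this commit; the kernel name and launch shape are hypothetical): each thread folds one value into an output slot via the atom* helpers above, and types without a native hardware atomic for the requested op (e.g. int64_t max) are routed through the atomicCAS retry loops generated by the ATOMIC macro.

// Hypothetical example kernel, assuming this header is available as
// "atomics.cuh". One thread per element of `src`; `seg[i]` selects the
// output slot that element i is reduced into.
#include <cstdint>
#include "atomics.cuh"

__global__ void segment_max_kernel(const int64_t *src, const int64_t *seg,
                                   int64_t *out, size_t numel) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel) {
    // Resolves to AtomicMaxIntegerImpl<int64_t, 8>, i.e. a CAS retry loop
    // on the 64-bit word rather than a native atomicMax.
    atomMax(out + seg[i], src[i]);
  }
}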

cuda/index.cuh

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    for (int64_t d = Dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
  }
};

template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    for (int64_t d = index.dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
  }
};

template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    for (int64_t d = Dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
        *t3Offset += curDimIndex * t3.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
    *t3Offset += indexValue * t3.strides[dim];
  }
};

template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    for (int64_t d = index.dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
        *t3Offset += curDimIndex * t3.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
    *t3Offset += indexValue * t3.strides[dim];
  }
};
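
To show how these helpers are meant to be consumed, here is a hedged sketch of a scatter-add kernel (illustrative only, not part of this commit; the kernel name is hypothetical and the TensorInfo arguments would be built on the host, e.g. with at::cuda::detail::getTensorInfo). IndexToScatterOffsets3 maps a thread's linear id to strided offsets into index, src and out; along dim, out is addressed by the value stored in index rather than by the thread's own coordinate, and the write is made safe with atomAdd from atomics.cuh.

// Illustrative only: names such as scatter_add_kernel are hypothetical.
#include "atomics.cuh"
#include "index.cuh"

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_add_kernel(const at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   const at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   const int64_t dim, const int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel) {
    int64_t indexOffset = 0, srcOffset = 0, outOffset = 0;
    // Walk the dimensions of `index`, accumulating strided offsets; the
    // offset into `out` along `dim` comes from index.data, not from `i`.
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomAdd(&out.data[outOffset], src.data[srcOffset]);
  }
}

Dims = -1 selects the runtime-dimension specialization (it loops over index.dims instead of a compile-time constant), which is the usual fallback when the tensor dimensionality is not known at kernel-compilation time.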

0 commit comments
