Skip to content

Commit 4f89f6b

Browse files
authored
Support for certain Dask+Cupy (#815)
1 parent 6dd5faa commit 4f89f6b

File tree

4 files changed

+50
-10
lines changed

4 files changed

+50
-10
lines changed

xrspatial/aspect.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ def _run_dask_numpy(data: da.Array) -> da.Array:
121121
return out
122122

123123

124+
def _run_dask_cupy(data: da.Array) -> da.Array:
125+
data = data.astype(cupy.float32)
126+
_func = partial(_run_cupy)
127+
out = data.map_overlap(_func,
128+
depth=(1, 1),
129+
boundary=cupy.nan,
130+
meta=cupy.array(()))
131+
return out
132+
133+
124134
def aspect(agg: xr.DataArray,
125135
name: Optional[str] = 'aspect') -> xr.DataArray:
126136
"""
@@ -249,8 +259,7 @@ def aspect(agg: xr.DataArray,
249259
numpy_func=_run_numpy,
250260
dask_func=_run_dask_numpy,
251261
cupy_func=_run_cupy,
252-
dask_cupy_func=lambda *args: not_implemented_func(
253-
*args, messages='aspect() does not support dask with cupy backed DataArray') # noqa
262+
dask_cupy_func=_run_dask_cupy,
254263
)
255264

256265
out = mapper(agg)(agg.data)

xrspatial/convolution.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,24 @@ def _convolve_2d_cupy(data, kernel):
373373
_convolve_2d_cuda[griddim, blockdim](data, kernel, cupy.asarray(out))
374374
return out
375375

376+
def _convolve_2d_dask_cupy(data, kernel):
377+
data = data.astype(cupy.float32)
378+
pad_h = kernel.shape[0] // 2
379+
pad_w = kernel.shape[1] // 2
380+
_func = partial(_convolve_2d_cupy, kernel=kernel)
381+
out = data.map_overlap(_func,
382+
depth=(pad_h, pad_w),
383+
boundary=cupy.nan,
384+
meta=cupy.array(()))
385+
return out
386+
376387

377388
def convolve_2d(data, kernel):
378389
mapper = ArrayTypeFunctionMapping(
379390
numpy_func=_convolve_2d_numpy,
380391
cupy_func=_convolve_2d_cupy,
381392
dask_func=_convolve_2d_dask_numpy,
382-
dask_cupy_func=lambda *args: not_implemented_func(
383-
*args, messages='convolution_2d() does not support dask with cupy backed xr.DataArray' # noqa
384-
)
393+
dask_cupy_func=_convolve_2d_dask_cupy
385394
)
386395
out = mapper(xr.DataArray(data))(data, kernel)
387396
return out

xrspatial/curvature.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ def _run_cupy(data: cupy.ndarray,
8484

8585
return out
8686

87+
def _run_dask_cupy(data: da.Array,
88+
cellsize: Union[int, float]) -> da.Array:
89+
data = data.astype(cupy.float32)
90+
_func = partial(_cpu, cellsize=cellsize)
91+
out = data.map_overlap(_func,
92+
depth=(1, 1),
93+
boundary=cupy.nan,
94+
meta=cupy.array(()))
95+
return out
96+
8797

8898
def curvature(agg: xr.DataArray,
8999
name: Optional[str] = 'curvature') -> xr.DataArray:
@@ -209,8 +219,7 @@ def curvature(agg: xr.DataArray,
209219
numpy_func=_run_numpy,
210220
cupy_func=_run_cupy,
211221
dask_func=_run_dask_numpy,
212-
dask_cupy_func=lambda *args: not_implemented_func(
213-
*args, messages='curvature() does not support dask with cupy backed DataArray.'), # noqa
222+
dask_cupy_func=_run_dask_cupy
214223
)
215224
out = mapper(agg)(agg.data, cellsize)
216225
return xr.DataArray(out,

xrspatial/slope.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ def _run_dask_numpy(data: da.Array,
6464
meta=np.array(()))
6565
return out
6666

67+
def _run_dask_cupy(data: da.Array,
68+
cellsize_x: Union[int, float],
69+
cellsize_y: Union[int, float]) -> da.Array:
70+
data = data.astype(cupy.float32)
71+
_func = partial(_run_cupy,
72+
cellsize_x=cellsize_x,
73+
cellsize_y=cellsize_y)
74+
75+
out = data.map_overlap(_func,
76+
depth=(1, 1),
77+
boundary=cupy.nan,
78+
meta=cupy.array(()))
79+
return out
80+
81+
6782

6883
@cuda.jit(device=True)
6984
def _gpu(arr, cellsize_x, cellsize_y):
@@ -163,9 +178,7 @@ def slope(agg: xr.DataArray,
163178
numpy_func=_run_numpy,
164179
cupy_func=_run_cupy,
165180
dask_func=_run_dask_numpy,
166-
dask_cupy_func=lambda *args: not_implemented_func(
167-
*args, messages='slope() does not support dask with cupy backed DataArray' # noqa
168-
),
181+
dask_cupy_func=_run_dask_cupy,
169182
)
170183
out = mapper(agg)(agg.data, cellsize_x, cellsize_y)
171184

0 commit comments

Comments
 (0)