File tree Expand file tree Collapse file tree 4 files changed +50
-10
lines changed
Expand file tree Collapse file tree 4 files changed +50
-10
lines changed Original file line number Diff line number Diff line change @@ -121,6 +121,16 @@ def _run_dask_numpy(data: da.Array) -> da.Array:
121121 return out
122122
123123
124+ def _run_dask_cupy (data : da .Array ) -> da .Array :
125+ data = data .astype (cupy .float32 )
126+ _func = partial (_run_cupy )
127+ out = data .map_overlap (_func ,
128+ depth = (1 , 1 ),
129+ boundary = cupy .nan ,
130+ meta = cupy .array (()))
131+ return out
132+
133+
124134def aspect (agg : xr .DataArray ,
125135 name : Optional [str ] = 'aspect' ) -> xr .DataArray :
126136 """
@@ -249,8 +259,7 @@ def aspect(agg: xr.DataArray,
249259 numpy_func = _run_numpy ,
250260 dask_func = _run_dask_numpy ,
251261 cupy_func = _run_cupy ,
252- dask_cupy_func = lambda * args : not_implemented_func (
253- * args , messages = 'aspect() does not support dask with cupy backed DataArray' ) # noqa
262+ dask_cupy_func = _run_dask_cupy ,
254263 )
255264
256265 out = mapper (agg )(agg .data )
Original file line number Diff line number Diff line change @@ -373,15 +373,24 @@ def _convolve_2d_cupy(data, kernel):
373373 _convolve_2d_cuda [griddim , blockdim ](data , kernel , cupy .asarray (out ))
374374 return out
375375
376+ def _convolve_2d_dask_cupy (data , kernel ):
377+ data = data .astype (cupy .float32 )
378+ pad_h = kernel .shape [0 ] // 2
379+ pad_w = kernel .shape [1 ] // 2
380+ _func = partial (_convolve_2d_cupy , kernel = kernel )
381+ out = data .map_overlap (_func ,
382+ depth = (pad_h , pad_w ),
383+ boundary = cupy .nan ,
384+ meta = cupy .array (()))
385+ return out
386+
376387
377388def convolve_2d (data , kernel ):
378389 mapper = ArrayTypeFunctionMapping (
379390 numpy_func = _convolve_2d_numpy ,
380391 cupy_func = _convolve_2d_cupy ,
381392 dask_func = _convolve_2d_dask_numpy ,
382- dask_cupy_func = lambda * args : not_implemented_func (
383- * args , messages = 'convolution_2d() does not support dask with cupy backed xr.DataArray' # noqa
384- )
393+ dask_cupy_func = _convolve_2d_dask_cupy
385394 )
386395 out = mapper (xr .DataArray (data ))(data , kernel )
387396 return out
Original file line number Diff line number Diff line change @@ -84,6 +84,16 @@ def _run_cupy(data: cupy.ndarray,
8484
8585 return out
8686
87+ def _run_dask_cupy (data : da .Array ,
88+ cellsize : Union [int , float ]) -> da .Array :
89+ data = data .astype (cupy .float32 )
90+ _func = partial (_cpu , cellsize = cellsize )
91+ out = data .map_overlap (_func ,
92+ depth = (1 , 1 ),
93+ boundary = cupy .nan ,
94+ meta = cupy .array (()))
95+ return out
96+
8797
8898def curvature (agg : xr .DataArray ,
8999 name : Optional [str ] = 'curvature' ) -> xr .DataArray :
@@ -209,8 +219,7 @@ def curvature(agg: xr.DataArray,
209219 numpy_func = _run_numpy ,
210220 cupy_func = _run_cupy ,
211221 dask_func = _run_dask_numpy ,
212- dask_cupy_func = lambda * args : not_implemented_func (
213- * args , messages = 'curvature() does not support dask with cupy backed DataArray.' ), # noqa
222+ dask_cupy_func = _run_dask_cupy
214223 )
215224 out = mapper (agg )(agg .data , cellsize )
216225 return xr .DataArray (out ,
Original file line number Diff line number Diff line change @@ -64,6 +64,21 @@ def _run_dask_numpy(data: da.Array,
6464 meta = np .array (()))
6565 return out
6666
67+ def _run_dask_cupy (data : da .Array ,
68+ cellsize_x : Union [int , float ],
69+ cellsize_y : Union [int , float ]) -> da .Array :
70+ data = data .astype (cupy .float32 )
71+ _func = partial (_run_cupy ,
72+ cellsize_x = cellsize_x ,
73+ cellsize_y = cellsize_y )
74+
75+ out = data .map_overlap (_func ,
76+ depth = (1 , 1 ),
77+ boundary = cupy .nan ,
78+ meta = cupy .array (()))
79+ return out
80+
81+
6782
6883@cuda .jit (device = True )
6984def _gpu (arr , cellsize_x , cellsize_y ):
@@ -163,9 +178,7 @@ def slope(agg: xr.DataArray,
163178 numpy_func = _run_numpy ,
164179 cupy_func = _run_cupy ,
165180 dask_func = _run_dask_numpy ,
166- dask_cupy_func = lambda * args : not_implemented_func (
167- * args , messages = 'slope() does not support dask with cupy backed DataArray' # noqa
168- ),
181+ dask_cupy_func = _run_dask_cupy ,
169182 )
170183 out = mapper (agg )(agg .data , cellsize_x , cellsize_y )
171184
You can’t perform that action at this time.
0 commit comments