Skip to content

Commit 8b01c59

Browse files
authored
added in dask-cupy convolve_2d test (#823)
1 parent b2b71fd commit 8b01c59

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

xrspatial/convolution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def _convolve_2d_cupy(data, kernel):
373373
_convolve_2d_cuda[griddim, blockdim](data, kernel, cupy.asarray(out))
374374
return out
375375

376+
376377
def _convolve_2d_dask_cupy(data, kernel):
377378
data = data.astype(cupy.float32)
378379
pad_h = kernel.shape[0] // 2

xrspatial/tests/test_focal.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,11 @@ def test_2d_convolution_gpu(
264264
dask_cupy_agg = xr.DataArray(
265265
da.from_array(cupy.asarray(convolve_2d_data), chunks=(3, 3))
266266
)
267-
with pytest.raises(NotImplementedError) as e_info:
268-
convolution_2d(dask_cupy_agg, kernel_custom)
269-
assert e_info
267+
result_kernel_annulus = convolve_2d(dask_cupy_agg.data, kernel_annulus_2_2_2_1)
268+
assert isinstance(result_kernel_annulus, da.Array)
269+
np.testing.assert_allclose(
270+
result_kernel_annulus.compute().get(), convolution_kernel_annulus_2_2_1, equal_nan=True
271+
)
270272

271273

272274
def test_calc_cellsize_unit_input_attrs(convolve_2d_data):

0 commit comments

Comments
 (0)