From 79ff1be9fcee70648f8b6973138392a4fea9ce64 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 15 Dec 2025 10:39:34 -0800 Subject: [PATCH] added in dask-cupy convolve_2d test --- xrspatial/convolution.py | 1 + xrspatial/tests/test_focal.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xrspatial/convolution.py b/xrspatial/convolution.py index 3bcc751f..2c46e95c 100644 --- a/xrspatial/convolution.py +++ b/xrspatial/convolution.py @@ -373,6 +373,7 @@ def _convolve_2d_cupy(data, kernel): _convolve_2d_cuda[griddim, blockdim](data, kernel, cupy.asarray(out)) return out + def _convolve_2d_dask_cupy(data, kernel): data = data.astype(cupy.float32) pad_h = kernel.shape[0] // 2 diff --git a/xrspatial/tests/test_focal.py b/xrspatial/tests/test_focal.py index 61f3b54f..1ab72287 100644 --- a/xrspatial/tests/test_focal.py +++ b/xrspatial/tests/test_focal.py @@ -264,9 +264,11 @@ def test_2d_convolution_gpu( dask_cupy_agg = xr.DataArray( da.from_array(cupy.asarray(convolve_2d_data), chunks=(3, 3)) ) - with pytest.raises(NotImplementedError) as e_info: - convolution_2d(dask_cupy_agg, kernel_custom) - assert e_info + result_kernel_annulus = convolve_2d(dask_cupy_agg.data, kernel_annulus_2_2_2_1) + assert isinstance(result_kernel_annulus, da.Array) + np.testing.assert_allclose( + result_kernel_annulus.compute().get(), convolution_kernel_annulus_2_2_1, equal_nan=True + ) def test_calc_cellsize_unit_input_attrs(convolve_2d_data):