@@ -24,8 +24,8 @@ class cupy(object):
2424
2525
2626@ngjit
27- def _avri_cpu (nir_data , red_data , blue_data ):
28- out = np .zeros_like (nir_data , dtype = 'f4' )
27+ def _arvi_cpu (nir_data , red_data , blue_data ):
28+ out = np .zeros (nir_data . shape , dtype = np . float32 )
2929 rows , cols = nir_data .shape
3030 for y in range (0 , rows ):
3131 for x in range (0 , cols ):
@@ -52,15 +52,12 @@ def _arvi_gpu(nir_data, red_data, blue_data, out):
5252
5353
5454def _arvi_dask (nir_data , red_data , blue_data ):
55- out = da .map_blocks (_avri_cpu , nir_data , red_data , blue_data ,
55+ out = da .map_blocks (_arvi_cpu , nir_data , red_data , blue_data ,
5656 meta = np .array (()))
5757 return out
5858
5959
6060def _arvi_cupy (nir_data , red_data , blue_data ):
61-
62- import cupy
63-
6461 griddim , blockdim = cuda_args (nir_data .shape )
6562 out = cupy .empty (nir_data .shape , dtype = 'f4' )
6663 out [:] = cupy .nan
@@ -69,9 +66,6 @@ def _arvi_cupy(nir_data, red_data, blue_data):
6966
7067
7168def _arvi_dask_cupy (nir_data , red_data , blue_data ):
72-
73- import cupy
74-
7569 out = da .map_blocks (_arvi_cupy , nir_data , red_data , blue_data ,
7670 dtype = cupy .float32 , meta = cupy .array (()))
7771 return out
@@ -103,7 +97,7 @@ def arvi(nir_agg: DataArray, red_agg: DataArray, blue_agg: DataArray,
10397 """
10498 validate_arrays (red_agg , nir_agg , blue_agg )
10599
106- mapper = ArrayTypeFunctionMapping (numpy_func = _avri_cpu ,
100+ mapper = ArrayTypeFunctionMapping (numpy_func = _arvi_cpu ,
107101 dask_func = _arvi_dask ,
108102 cupy_func = _arvi_cupy ,
109103 dask_cupy_func = _arvi_dask_cupy )
0 commit comments