Skip to content

Commit b0c87a9

Browse files
authored
classify.binary: handle NaNs and infinite values (#763)
* update docstring * cpu_binary: set nans as nans * gpu_binary * gpu_binary revert * check finite with cmath
1 parent 3df3ec7 commit b0c87a9

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

xrspatial/classify.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import List, Optional
44

5+
import cmath
6+
57
import xarray as xr
68

79
try:
@@ -20,12 +22,13 @@ class cupy(object):
2022
@ngjit
2123
def _cpu_binary(data, values):
2224
out = np.zeros_like(data)
25+
out[:] = np.nan
2326
rows, cols = data.shape
2427
for y in range(0, rows):
2528
for x in range(0, cols):
2629
if np.any(values == data[y, x]):
2730
out[y, x] = 1
28-
else:
31+
elif np.isfinite(data[y, x]):
2932
out[y, x] = 0
3033
return out
3134

@@ -43,8 +46,7 @@ def _run_dask_numpy_binary(data, values):
4346

4447

4548
@nb.cuda.jit(device=True)
46-
def _gpu_binary(data, values):
47-
val = data[0, 0]
49+
def _gpu_binary(val, values):
4850
for v in values:
4951
if val == v:
5052
return 1
@@ -55,7 +57,8 @@ def _gpu_binary(data, values):
5557
def _run_gpu_binary(data, values, out):
5658
i, j = nb.cuda.grid(2)
5759
if i >= 0 and i < out.shape[0] and j >= 0 and j < out.shape[1]:
58-
out[i, j] = _gpu_binary(data[i:i+1, j:j+1], values)
60+
if cmath.isfinite(data[i, j]):
61+
out[i, j] = _gpu_binary(data[i, j], values)
5962

6063

6164
def _run_cupy_binary(data, values):
@@ -76,6 +79,7 @@ def binary(agg, values, name='binary'):
7679
"""
7780
Binarize a data array based on a set of values. Data that equals to a value in the set will be
7881
set to 1. In contrast, data that does not equal to any value in the set will be set to 0.
82+
Note that NaNs and infinite values will be set to NaNs.
7983
8084
Parameters
8185
----------
@@ -114,10 +118,10 @@ def binary(agg, values, name='binary'):
114118
>>> agg_binary = binary(agg, values)
115119
>>> print(agg_binary)
116120
<xarray.DataArray 'binary' (dim_0: 4, dim_1: 5)>
117-
array([[0., 1., 1., 1., 0.],
121+
array([[np.nan, 1., 1., 1., 0.],
118122
[0., 0., 0., 0., 0.],
119123
[0., 0., 0., 0., 0.],
120-
[0., 0., 0., 0., 0.]], dtype=float32)
124+
[0., 0., 0., 0., np.nan]], dtype=float32)
121125
Dimensions without coordinates: dim_0, dim_1
122126
"""
123127

xrspatial/tests/test_classify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def input_data(backend='numpy'):
2222
def result_binary():
2323
values = [1, 2, 3]
2424
expected_result = np.asarray([
25-
[0, 1, 1, 0, 0],
25+
[np.nan, 1, 1, 0, np.nan],
2626
[0, 0, 0, 0, 0],
2727
[0, 0, 0, 0, 0],
28-
[0, 0, 0, 0, 0]
28+
[0, 0, 0, 0, np.nan]
2929
], dtype=np.float32)
3030
return values, expected_result
3131

0 commit comments

Comments
 (0)