Skip to content

Commit d5f6ea4

Browse files
edoyangotennlee
andauthored
test coverage for filters (#229)
* make healpix importable * add no healpy import test * ensure fallback HEALPix works with args * test numpy dropvalue * add missing brackets * cover numpy filters * _check -> filter * add dataarray capability to DropAnyNan * add test for xarray dropanynan filter * add check for invalid types * add dataarray capability for dropallnan * add tests for dropallnan * add dataset functionality for dropvalue * add tests for dropvalue * add coverage for Shape * remove not in DropAnyNan and DropAllNan * check_shape -> filter * fix mismatched tuple length error * add dask filter tests * Remove unused import * Simplify CI/CD install requirements to bring install under disk space requirements * Test reduced dependencies * Test reduced-complexity requirements installation for CI/CD needs * Test further reduction of dependencies for CI/CD * Test tweak --------- Co-authored-by: Tennessee Leeuwenburg <tennessee.leeuwenburg@bom.gov.au>
1 parent 005aff4 commit d5f6ea4

File tree

13 files changed

+466
-38
lines changed

13 files changed

+466
-38
lines changed

.github/workflows/python-app.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,18 @@ jobs:
2727
- name: Install dependencies
2828
run: |
2929
python -m pip install --upgrade pip
30-
pip install -r requirements.txt
30+
pip install -r requirements_cicd.txt
3131
- name: Test with pytest
3232
run: |
3333
# editable is necessary as pytest will run against the installed
3434
# package rather than the local files creating a coverage report of 0%
3535
pip install -e packages/utils
36-
pip install -e packages/data[all]
37-
pip install -e packages/training[all]
38-
pip install -e packages/pipeline[all]
36+
pip install -e packages/data
37+
pip install -e packages/training
38+
pip install -e packages/pipeline
3939
pip install -e packages/zoo
4040
pip install -e packages/bundled_models/fourcastnext
41-
pip install -e packages/tutorial
42-
pip install -e .[test,docs]
41+
pip install -e .[test]
4342
4443
pytest -m="not noci" --cov=packages/data --cov=packages/utils --cov=packages/pipeline --cov=packages/training --cov=packages/zoo --cov=packages/bundled_models/fourcastnext --ignore=packages/nci_site_archive
4544
- name: Coveralls GitHub Action

packages/data/tests/transform/test_derive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
import math
1818

19-
from numpy import nan, isnan
19+
from numpy import isnan
2020
from pyearthtools.data.transforms.derive import evaluate, EquationException
2121

2222

packages/pipeline/src/pyearthtools/pipeline/operations/dask/filters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def filter(self, sample: da.Array):
5959
(bool):
6060
If sample contains nan's
6161
"""
62-
if not bool(da.array(list(da.isnan(sample))).any()):
62+
if da.array(list(da.isnan(sample))).any():
6363
raise PipelineFilterException(sample, "Data contained nan's.")
6464

6565

@@ -85,7 +85,7 @@ def filter(self, sample: da.Array):
8585
(bool):
8686
If sample contains nan's
8787
"""
88-
if not bool(da.array(list(da.isnan(sample))).all()):
88+
if da.array(list(da.isnan(sample))).all():
8989
raise PipelineFilterException(sample, "Data contained all nan's.")
9090

9191

@@ -164,9 +164,9 @@ def _find_shape(self, data: Union[tuple[da.Array, ...], da.Array]) -> tuple[Unio
164164
return tuple(map(self._find_shape, data))
165165
return data.shape
166166

167-
def check_shape(self, sample: Union[tuple[da.Array, ...], da.Array]):
167+
def filter(self, sample: Union[tuple[da.Array, ...], da.Array]):
168168
if isinstance(sample, (list, tuple)):
169-
if not isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample):
169+
if not (isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample)):
170170
raise RuntimeError(
171171
f"If sample is tuple, shape must also be, and of the same length. {self._shape} != {tuple(self._find_shape(i) for i in sample)}"
172172
)

packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def _find_shape(self, data: Union[tuple[np.ndarray, ...], np.ndarray]) -> tuple[
168168
return tuple(map(self._find_shape, data))
169169
return data.shape
170170

171-
def check_shape(self, sample: Union[tuple[np.ndarray, ...], np.ndarray]):
171+
def filter(self, sample: Union[tuple[np.ndarray, ...], np.ndarray]):
172172
if isinstance(sample, (list, tuple)):
173-
if not isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample):
173+
if not (isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample)):
174174
raise RuntimeError(
175175
f"If sample is tuple, shape must also be, and of the same length. {self._shape} != {tuple(self._find_shape(i) for i in sample)}"
176176
)

packages/pipeline/src/pyearthtools/pipeline/operations/xarray/filters.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import xarray as xr
20-
20+
import warnings
2121
import math
2222

2323
from pyearthtools.pipeline.filters import Filter, PipelineFilterException
@@ -58,7 +58,7 @@ def __init__(self, variables: Optional[list] = None) -> None:
5858

5959
self.variables = variables
6060

61-
def _check(self, sample: xr.Dataset):
61+
def filter(self, sample: xr.Dataset):
6262
"""Check if any of the sample is nan
6363
6464
Args:
@@ -68,10 +68,21 @@ def _check(self, sample: xr.Dataset):
6868
(bool):
6969
If sample contains nan's
7070
"""
71+
7172
if self.variables:
72-
sample = sample[self.variables]
73+
if isinstance(sample, xr.DataArray):
74+
warnings.warn("input sample is xr.DataArray - ignoring filter variables.")
75+
else:
76+
sample = sample[self.variables]
77+
78+
if isinstance(sample, xr.DataArray):
79+
has_nan = np.isnan(sample).any()
80+
elif isinstance(sample, xr.Dataset):
81+
has_nan = np.array(list(np.isnan(sample).values())).any()
82+
else:
83+
raise TypeError("This filter only accepts xr.DataArray or xr.Dataset")
7384

74-
if not bool(np.array(list(np.isnan(sample).values())).any()):
85+
if has_nan:
7586
raise PipelineFilterException(sample, "Data contained nan's.")
7687

7788

@@ -95,7 +106,7 @@ def __init__(self, variables: Optional[list] = None) -> None:
95106

96107
self.variables = variables
97108

98-
def _check(self, sample: xr.Dataset):
109+
def filter(self, sample: xr.Dataset):
99110
"""Check if all of the sample is nan
100111
101112
Args:
@@ -106,9 +117,19 @@ def _check(self, sample: xr.Dataset):
106117
If sample contains nan's
107118
"""
108119
if self.variables:
109-
sample = sample[self.variables]
120+
if isinstance(sample, xr.DataArray):
121+
warnings.warn("input sample is xr.DataArray - ignoring filter variables.")
122+
else:
123+
sample = sample[self.variables]
124+
125+
if isinstance(sample, xr.DataArray):
126+
all_nan = np.isnan(sample).all()
127+
elif isinstance(sample, xr.Dataset):
128+
all_nan = np.array(list(np.isnan(sample).values())).all()
129+
else:
130+
raise TypeError("This filter only accepts xr.DataArray or xr.Dataset")
110131

111-
if not bool(np.array(list(np.isnan(sample).values())).all()):
132+
if all_nan:
112133
raise PipelineFilterException(sample, "Data contained all nan's.")
113134

114135

@@ -147,16 +168,24 @@ def filter(self, sample: T):
147168
(bool):
148169
If sample contains nan's
149170
"""
150-
if np.isnan(self._value):
151-
function = ( # noqa
152-
lambda x: ((np.count_nonzero(np.isnan(x)) / math.prod(x.shape)) * 100) >= self._percentage
153-
) # noqa
171+
if isinstance(sample, xr.DataArray):
172+
if np.isnan(self._value):
173+
drop = ((np.count_nonzero(np.isnan(sample)) / math.prod(sample.shape)) * 100) >= self._percentage
174+
else:
175+
drop = ((np.count_nonzero(sample == self._value) / math.prod(sample.shape)) * 100) >= self._percentage
176+
elif isinstance(sample, xr.Dataset):
177+
if np.isnan(self._value):
178+
nmatches = np.sum(list(np.isnan(sample).sum().values()))
179+
nvalues = np.sum([math.prod(v.shape) for v in sample.values()])
180+
drop = nmatches / nvalues * 100 >= self._percentage
181+
else:
182+
nmatches = np.sum(list((sample == 1).sum().values()))
183+
nvalues = np.sum([math.prod(v.shape) for v in sample.values()])
184+
drop = nmatches / nvalues * 100 >= self._percentage
154185
else:
155-
function = ( # noqa
156-
lambda x: ((np.count_nonzero(x == self._value) / math.prod(x.shape)) * 100) >= self._percentage
157-
) # noqa
186+
raise TypeError("This filter only accepts xr.DataArray or xr.Dataset")
158187

159-
if not function(sample):
188+
if not drop:
160189
raise PipelineFilterException(sample, f"Data contained more than {self._percentage}% of {self._value}.")
161190

162191

@@ -198,7 +227,7 @@ def _find_shape(self, data: T) -> tuple[int, ...]:
198227

199228
def filter(self, sample: Union[tuple[T, ...], T]):
200229
if isinstance(sample, (list, tuple)):
201-
if not isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample):
230+
if not (isinstance(self._shape, (list, tuple)) and len(self._shape) == len(sample)):
202231
raise RuntimeError(
203232
f"If sample is tuple, shape must also be, and of the same length. {self._shape} != {tuple(self._find_shape(i) for i in sample)}"
204233
)

packages/pipeline/src/pyearthtools/pipeline/operations/xarray/remapping/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
except ImportError:
2828

2929
class HEALPix:
30-
def __init__(self):
30+
def __init__(self, *args, **kwargs):
3131
warnings.warn(
3232
"Could not import the healpix projection, please install the 'healpy' and 'reproject' optional dependencies"
3333
)

packages/pipeline/src/pyearthtools/pipeline/operations/xarray/remapping/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Type, TypeVar
2222
import xarray as xr
2323

24-
from pyearthtools.pipeline import Operation
24+
from pyearthtools.pipeline.operation import Operation
2525

2626
XR_TYPE = TypeVar("XR_TYPE", xr.Dataset, xr.DataArray)
2727

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pyearthtools.pipeline.operations.dask import filters
16+
from pyearthtools.pipeline.exceptions import PipelineFilterException
17+
18+
import numpy as np
19+
import dask.array as da
20+
import pytest
21+
22+
23+
def test_DropAnyNan():
24+
"""Tests DropAnyNan dask filter."""
25+
26+
original = da.ones((2, 2))
27+
28+
# no nans - should succeed quietly
29+
drop = filters.DropAnyNan()
30+
drop.filter(original)
31+
32+
# one nan - should raise exception
33+
original[0, 0] = np.nan
34+
drop = filters.DropAnyNan()
35+
with pytest.raises(PipelineFilterException):
36+
drop.filter(original)
37+
38+
39+
# xfailed since the result seems to be inverted to documented requirements
40+
@pytest.mark.xfail
41+
def test_DropAllNan():
42+
"""Tests DropAllNan dask filter."""
43+
44+
original = da.empty((2, 2))
45+
46+
# no nans - should succeed quietly
47+
drop = filters.DropAllNan()
48+
drop.filter(original)
49+
50+
# one nan - should succeed quietly
51+
original[0, 0] = np.nan
52+
drop.filter(original)
53+
54+
# all nans - should raise exception
55+
original[:, :] = np.nan
56+
with pytest.raises(PipelineFilterException):
57+
drop.filter(original)
58+
59+
60+
def test_DropValue():
61+
"""Tests DropValue dask filter."""
62+
63+
original = da.from_array([[0, 0], [1, 2]])
64+
65+
# drop case (num zeros < threshold)
66+
drop = filters.DropValue(0, 75)
67+
with pytest.raises(PipelineFilterException):
68+
drop.filter(original)
69+
70+
# non-drop case (num zeros >= threshold)
71+
drop = filters.DropValue(0, 50)
72+
drop.filter(original)
73+
74+
# drop case (num nans < threshold)
75+
original = da.from_array([[np.nan, np.nan], [1, 2]])
76+
drop = filters.DropValue("nan", 75)
77+
with pytest.raises(PipelineFilterException):
78+
drop.filter(original)
79+
80+
# non-drop case (num nans >= threshold)
81+
drop = filters.DropValue("nan", 50)
82+
drop.filter(original)
83+
84+
85+
def test_Shape():
86+
"""Tests Shape dask filter."""
87+
88+
originals = (da.empty((2, 2)), da.empty((2, 3)))
89+
90+
# check drop case
91+
drop = filters.Shape((2, 3))
92+
with pytest.raises(PipelineFilterException):
93+
drop.filter(originals[0])
94+
95+
# check non-drop case
96+
drop = filters.Shape((2, 2))
97+
drop.filter(originals[0])
98+
99+
# check tuple inputs drop cases
100+
drop = filters.Shape(((2, 3), (2, 3)))
101+
with pytest.raises(PipelineFilterException):
102+
drop.filter(originals)
103+
104+
# check tuple inputs non-drop cases
105+
drop = filters.Shape(((2, 2), (2, 3)))
106+
drop.filter(originals)
107+
108+
# invalid mismatched shape and input
109+
drop = filters.Shape(((2, 2),))
110+
with pytest.raises(RuntimeError):
111+
drop.filter(originals)

packages/pipeline/tests/operations/numpy/test_numpy_filter.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_DropAnyNan_true():
3535
drop = filters.DropAnyNan()
3636

3737
with pytest.raises(PipelineFilterException):
38-
result = drop.filter(original)
38+
drop.filter(original)
3939

4040

4141
def test_DropAllNan_false():
@@ -54,4 +54,59 @@ def test_DropAllNan_true():
5454
drop = filters.DropAllNan()
5555

5656
with pytest.raises(PipelineFilterException):
57-
result = drop.filter(original)
57+
drop.filter(original)
58+
59+
60+
def test_DropValue():
61+
62+
# test drop case
63+
original = np.array([[1, 1], [np.nan, np.nan]])
64+
65+
drop = filters.DropValue(value=1, percentage=75)
66+
67+
with pytest.raises(PipelineFilterException):
68+
drop.filter(original)
69+
70+
# test no drop case
71+
drop = filters.DropValue(value=1, percentage=50)
72+
drop.filter(original)
73+
74+
# test with nan - drop case
75+
drop = filters.DropValue(value="nan", percentage=75)
76+
77+
with pytest.raises(PipelineFilterException):
78+
drop.filter(original)
79+
80+
# no drop case
81+
drop = filters.DropValue(value="nan", percentage=50)
82+
drop.filter(original)
83+
84+
85+
def test_Shape():
86+
87+
# test drop case
88+
original = np.empty((2, 3))
89+
drop = filters.Shape((2, 2))
90+
91+
with pytest.raises(PipelineFilterException):
92+
drop.filter(original)
93+
94+
# test non-drop case
95+
original = np.empty((2, 2))
96+
drop.filter(original)
97+
98+
# test with multiple inputs
99+
originals = (np.empty((2, 3)), np.empty((2, 2)))
100+
drop = filters.Shape(((2, 2), (2, 3)))
101+
102+
with pytest.raises(PipelineFilterException):
103+
drop.filter(originals)
104+
105+
# test non drop case
106+
drop = filters.Shape(((2, 3), (2, 2)))
107+
drop.filter(originals)
108+
109+
# test mismatched number of input shapes
110+
drop = filters.Shape(((1, 2), (3, 4), (5, 6)))
111+
with pytest.raises(RuntimeError):
112+
drop.filter(originals)

0 commit comments

Comments
 (0)