@@ -65,6 +65,20 @@ def cftime_arrays(
6565 return cftime .num2date (values , units = unit , calendar = cal )
6666
6767
68+ def insert_nans (draw : st .DrawFn , array : np .ndarray ) -> np .ndarray :
69+ if array .dtype .kind in "cf" :
70+ nan_idx = draw (
71+ st .lists (
72+ st .integers (min_value = 0 , max_value = array .shape [- 1 ] - 1 ),
73+ max_size = array .shape [- 1 ] - 1 ,
74+ unique = True ,
75+ )
76+ )
77+ if nan_idx :
78+ array [..., nan_idx ] = np .nan
79+ return array
80+
81+
6882numeric_dtypes = (
6983 npst .integer_dtypes (endianness = "=" )
7084 | npst .unsigned_integer_dtypes (endianness = "=" )
@@ -96,20 +110,18 @@ def cftime_arrays(
96110SKIPPED_FUNCS = ["var" , "std" , "nanvar" , "nanstd" ]
97111
98112func_st = st .sampled_from ([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS ])
99- numeric_arrays = npst .arrays (
100- elements = {"allow_subnormal" : False }, shape = npst .array_shapes (), dtype = numeric_dtypes
101- )
102- numeric_like_arrays = npst .arrays (
103- elements = {"allow_subnormal" : False }, shape = npst .array_shapes (), dtype = numeric_like_dtypes
104- )
105- all_arrays = (
106- npst .arrays (
107- elements = {"allow_subnormal" : False },
108- shape = npst .array_shapes (),
109- dtype = numeric_like_dtypes ,
110- )
111- | cftime_arrays ()
112- )
113+
114+
115+ @st .composite
116+ def numpy_arrays (draw : st .DrawFn , * , dtype ) -> np .ndarray :
117+ array = draw (npst .arrays (elements = {"allow_subnormal" : False }, shape = npst .array_shapes (), dtype = dtype ))
118+ array = insert_nans (draw , array )
119+ return array
120+
121+
122+ numeric_arrays = numpy_arrays (dtype = numeric_dtypes )
123+ numeric_like_arrays = numpy_arrays (dtype = numeric_like_dtypes )
124+ all_arrays = numeric_like_arrays | cftime_arrays ()
113125
114126
115127def by_arrays (
@@ -153,16 +165,4 @@ def chunked_arrays(
153165) -> dask .array .Array :
154166 array = draw (arrays )
155167 chunks = draw (chunks (shape = array .shape ))
156-
157- if array .dtype .kind in "cf" :
158- nan_idx = draw (
159- st .lists (
160- st .integers (min_value = 0 , max_value = array .shape [- 1 ] - 1 ),
161- max_size = array .shape [- 1 ] - 1 ,
162- unique = True ,
163- )
164- )
165- if nan_idx :
166- array [..., nan_idx ] = np .nan
167-
168168 return from_array (array , chunks = chunks )
0 commit comments