Skip to content

Commit f8a12e8

Browse files
committed
Merge branch 'main' of github.com:comet-toolkit/obsarray into main
2 parents d3034ed + f1c68e4 commit f8a12e8

File tree

6 files changed

+129
-21
lines changed

6 files changed

+129
-21
lines changed

obsarray/templater/dataset_util.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def create_unc_variable(
153153
# set undefined dims as random
154154
defined_err_corr_dims = []
155155
for erd in err_corr:
156-
if isinstance(erd["dim"],str):
156+
if isinstance(erd["dim"], str):
157157
defined_err_corr_dims.append(erd["dim"])
158158
else:
159159
defined_err_corr_dims.extend(erd["dim"])
@@ -173,7 +173,10 @@ def create_unc_variable(
173173
units_str = DatasetUtil.return_err_corr_units_str(idx)
174174

175175
form = ecdef["form"]
176-
attributes[dim_str] = ecdef["dim"]
176+
if isinstance(ecdef["dim"], list) and len(ecdef["dim"]) == 1:
177+
attributes[dim_str] = ecdef["dim"]
178+
else:
179+
attributes[dim_str] = ecdef["dim"]
177180
attributes[form_str] = ecdef["form"]
178181
attributes[units_str] = ecdef["units"] if "units" in ecdef else []
179182

@@ -383,7 +386,9 @@ def unpack_flags(da):
383386

384387
ds = xarray.Dataset()
385388
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
386-
ds[flag_meaning] = DatasetUtil.create_variable(list(da.shape), bool, dim_names=list(da.dims))
389+
ds[flag_meaning] = DatasetUtil.create_variable(
390+
list(da.shape), bool, dim_names=list(da.dims)
391+
)
387392
ds[flag_meaning] = (da & flag_mask).astype(bool)
388393

389394
return ds
@@ -447,7 +452,9 @@ def set_flag(da, flag_name, error_if_set=False):
447452
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
448453

449454
if numpy.any(set_flags == True) and error_if_set:
450-
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
455+
raise ValueError(
456+
"Flag " + flag_name + " already set for variable " + da.name
457+
)
451458

452459
# Find flag mask
453460
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
@@ -473,7 +480,9 @@ def unset_flag(da, flag_name, error_if_unset=False):
473480
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
474481

475482
if numpy.any(set_flags == False) and error_if_unset:
476-
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
483+
raise ValueError(
484+
"Flag " + flag_name + " already set for variable " + da.name
485+
)
477486

478487
# Find flag mask
479488
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
@@ -501,7 +510,7 @@ def get_set_flags(da):
501510

502511
set_flags = []
503512
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
504-
if (da & flag_mask):
513+
if da & flag_mask:
505514
set_flags.append(flag_meaning)
506515

507516
return set_flags

obsarray/templater/template_util.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@
1111

1212

1313
def create_ds(
14-
template: Dict[str, Dict], size: Dict[str, int], metadata: Optional[Dict] = None,
15-
propagate_ds: Optional[xarray.Dataset] = None,
14+
template: Dict[str, Dict],
15+
size: Dict[str, int],
16+
metadata: Optional[Dict] = None,
17+
append_ds: Optional[xarray.Dataset] = None,
18+
propagate_ds: Optional[xarray.Dataset] = None,
1619
) -> xarray.Dataset:
1720
"""
1821
Returns template dataset
1922
2023
:param template: dictionary defining ds variable structure, as defined below.
2124
:param size: dictionary of dataset dimensions, entry per dataset dimension with value of size as int
2225
:param metadata: dictionary of dataset metadata
26+
:param append_ds: base dataset to append with template variables
2327
:param propagate_ds: template dataset is populated with data from propagate_ds for their variables with
2428
common names and dimensions. Useful for transferring common data between datasets at different processing levels
2529
(e.g. times, etc.).
@@ -38,7 +42,7 @@ def create_ds(
3842
"""
3943

4044
# Create dataset
41-
ds = xarray.Dataset()
45+
ds = append_ds if append_ds is not None else xarray.Dataset()
4246

4347
# Add variables
4448
ds = TemplateUtil.add_variables(ds, template, size)
@@ -85,7 +89,9 @@ def add_variables(
8589
8690
:returns: dataset with defined variables
8791
"""
92+
8893
for var_name in template.keys():
94+
8995
var = TemplateUtil._create_var(var_name, template[var_name], size)
9096

9197
ds[var_name] = var
@@ -123,7 +129,7 @@ def _create_var(
123129
)
124130

125131
# Create variable and add to dataset
126-
if isinstance(dtype,str):
132+
if dtype == str:
127133
if dtype == "flag":
128134
flag_meanings = attributes.pop("flag_meanings")
129135
variable = du.create_flags_variable(
@@ -218,21 +224,32 @@ def propagate_values(target_ds, source_ds, exclude=None):
218224

219225
# Find variable names common to target_ds and source_ds, excluding specified exclude variables
220226
common_variable_names = list(set(target_ds).intersection(source_ds))
221-
#common_variable_names = list(set(target_ds.variables).intersection(source_ds.variables))
222-
#print(common_variable_names)
227+
# common_variable_names = list(set(target_ds.variables).intersection(source_ds.variables))
228+
# print(common_variable_names)
223229

224230
if exclude is not None:
225-
common_variable_names = [name for name in common_variable_names if name not in exclude]
231+
common_variable_names = [
232+
name for name in common_variable_names if name not in exclude
233+
]
226234

227235
# Remove any common variables that have different dimensions in target_ds and source_ds
228-
common_variable_names = [name for name in common_variable_names if target_ds[name].dims == source_ds[name].dims]
236+
common_variable_names = [
237+
name
238+
for name in common_variable_names
239+
if target_ds[name].dims == source_ds[name].dims
240+
]
229241

230242
# Propagate data
231243
for common_variable_name in common_variable_names:
232-
if target_ds[common_variable_name].shape == source_ds[common_variable_name].shape:
233-
target_ds[common_variable_name].values = source_ds[common_variable_name].values
234-
235-
#to do - add method to propagate common unpopulated metadata
244+
if (
245+
target_ds[common_variable_name].shape
246+
== source_ds[common_variable_name].shape
247+
):
248+
target_ds[common_variable_name].values = source_ds[
249+
common_variable_name
250+
].values
251+
252+
# to do - add method to propagate common unpopulated metadata
236253

237254

238255
if __name__ == "__main__":

obsarray/templater/tests/test_dataset_util.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,36 @@ def test_create_variable_3D_int_attributes(self):
173173
self.assertEqual(-127, array_variable[2, 4, 2])
174174
self.assertEqual("std", array_variable.attrs["standard_name"])
175175

176-
def test_create_unc_variable(self):
176+
def test_create_unc_variable_1dimundef(self):
177+
err_corr = [
178+
{
179+
"dim": ["x"],
180+
"form": "rectangle_absolute",
181+
"params": [1, 2],
182+
"units": ["m", "m"],
183+
},
184+
]
185+
186+
unc_variable = DatasetUtil.create_unc_variable(
187+
[7],
188+
np.int8,
189+
["x"],
190+
pdf_shape="gaussian",
191+
err_corr=err_corr,
192+
)
193+
194+
expected_attrs = {
195+
"err_corr_1_dim": ["x"],
196+
"err_corr_1_form": "rectangle_absolute",
197+
"err_corr_1_units": ["m", "m"],
198+
"err_corr_1_params": [1, 2],
199+
}
200+
201+
actual_attrs = unc_variable.attrs
202+
203+
self.assertTrue(expected_attrs.items() <= actual_attrs.items())
204+
205+
def test_create_unc_variable_1undef(self):
177206
err_corr = [
178207
{
179208
"dim": "x",

obsarray/templater/tests/test_template_util.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,50 @@ def test_create_template_dataset(self):
285285
self.assertEqual(type(ds["array_variable"]), xarray.DataArray)
286286
self.assertEqual("value", ds.attrs["metadata1"])
287287

288+
def test_create_template_dataset_withunc(self):
289+
dim_sizes = {"dim1": 25}
290+
291+
test_variables = {
292+
"array_variable": {
293+
"dim": ["dim1"],
294+
"dtype": np.float32,
295+
"attributes": {
296+
"standard_name": "array_variable_std_name",
297+
"long_name": "array_variable_long_name",
298+
"units": "units",
299+
"preferred_symbol": "av",
300+
"unc_comps": ["u_array_variable"],
301+
},
302+
"encoding": {"dtype": np.uint16, "scale_factor": 1.0, "offset": 0.0},
303+
},
304+
"u_array_variable": {
305+
"dim": ["dim1"],
306+
"dtype": np.float32,
307+
"attributes": {
308+
"err_corr": [
309+
{
310+
"dim": ["dim1"],
311+
"form": "rectangle_absolute",
312+
"params": [1, 2],
313+
"units": ["m", "m"],
314+
},
315+
],
316+
"standard_name": "array_variable_std_name",
317+
"long_name": "array_variable_long_name",
318+
"units": "units",
319+
"preferred_symbol": "av",
320+
},
321+
},
322+
}
323+
324+
test_metadata = {"metadata1": "value"}
325+
326+
ds = create_ds(test_variables, dim_sizes, test_metadata)
327+
328+
self.assertEqual(type(ds), xarray.Dataset)
329+
self.assertEqual(type(ds["array_variable"]), xarray.DataArray)
330+
self.assertEqual("value", ds.attrs["metadata1"])
331+
288332

289333
if __name__ == "__main__":
290334
unittest.main()

obsarray/test/test_unc_accessor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,15 @@ def unitcall(ds):
648648

649649
self.assertRaises(ValueError, unitcall, self.ds)
650650

651+
def test_abs_value_uncnounits(self):
652+
653+
del self.ds["u_str_temperature"].attrs["units"]
654+
655+
xr.testing.assert_equal(
656+
self.ds.unc["temperature"]["u_str_temperature"].value,
657+
self.ds.unc["temperature"]["u_str_temperature"].abs_value,
658+
)
659+
651660
def test_pdf_shape(self):
652661
self.assertEqual(
653662
self.ds.unc["temperature"]["u_ran_temperature"].pdf_shape, "gaussian"

obsarray/unc_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def abs_value(self) -> xr.DataArray:
219219
if self.units == "%":
220220
return self.value / 100 * self.var_value
221221

222-
elif self.units != self.var_units:
222+
elif (self.units != self.var_units) and (self.units is not None):
223223
raise ValueError(
224224
"Unit mismatch between observation variable and uncertainty variable:\n"
225225
"* '{}' - '{}'\n"
@@ -246,7 +246,7 @@ def is_random(self) -> bool:
246246
247247
:return: random uncertainty flag
248248
"""
249-
if len(self.err_corr)>0:
249+
if len(self.err_corr) > 0:
250250
return all(e[1].is_random is True for e in self.err_corr)
251251
else:
252252
return False

0 commit comments

Comments
 (0)