Skip to content

Commit c2eb492

Browse files
committed
if err_corr defined for a len 1 list of dims, simplify to str in attrs
1 parent 14ee36b commit c2eb492

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

obsarray/templater/dataset_util.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def create_unc_variable(
153153
# set undefined dims as random
154154
defined_err_corr_dims = []
155155
for erd in err_corr:
156-
if isinstance(erd["dim"],str):
156+
if isinstance(erd["dim"], str):
157157
defined_err_corr_dims.append(erd["dim"])
158158
else:
159159
defined_err_corr_dims.extend(erd["dim"])
@@ -173,7 +173,10 @@ def create_unc_variable(
173173
units_str = DatasetUtil.return_err_corr_units_str(idx)
174174

175175
form = ecdef["form"]
176-
attributes[dim_str] = ecdef["dim"]
176+
if isinstance(ecdef["dim"], list) and len(ecdef["dim"]) == 1:
177+
attributes[dim_str] = ecdef["dim"]
178+
else:
179+
attributes[dim_str] = ecdef["dim"]
177180
attributes[form_str] = ecdef["form"]
178181
attributes[units_str] = ecdef["units"] if "units" in ecdef else []
179182

@@ -383,7 +386,9 @@ def unpack_flags(da):
383386

384387
ds = xarray.Dataset()
385388
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
386-
ds[flag_meaning] = DatasetUtil.create_variable(list(da.shape), bool, dim_names=list(da.dims))
389+
ds[flag_meaning] = DatasetUtil.create_variable(
390+
list(da.shape), bool, dim_names=list(da.dims)
391+
)
387392
ds[flag_meaning] = (da & flag_mask).astype(bool)
388393

389394
return ds
@@ -447,7 +452,9 @@ def set_flag(da, flag_name, error_if_set=False):
447452
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
448453

449454
if numpy.any(set_flags == True) and error_if_set:
450-
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
455+
raise ValueError(
456+
"Flag " + flag_name + " already set for variable " + da.name
457+
)
451458

452459
# Find flag mask
453460
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
@@ -473,7 +480,9 @@ def unset_flag(da, flag_name, error_if_unset=False):
473480
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
474481

475482
if numpy.any(set_flags == False) and error_if_unset:
476-
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
483+
raise ValueError(
484+
"Flag " + flag_name + " already set for variable " + da.name
485+
)
477486

478487
# Find flag mask
479488
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
@@ -501,7 +510,7 @@ def get_set_flags(da):
501510

502511
set_flags = []
503512
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
504-
if (da & flag_mask):
513+
if da & flag_mask:
505514
set_flags.append(flag_meaning)
506515

507516
return set_flags

obsarray/templater/tests/test_dataset_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_create_unc_variable_1dimundef(self):
192192
)
193193

194194
expected_attrs = {
195-
"err_corr_1_dim": "x",
195+
"err_corr_1_dim": ["x"],
196196
"err_corr_1_form": "rectangle_absolute",
197197
"err_corr_1_units": ["m", "m"],
198198
"err_corr_1_params": [1, 2],

0 commit comments

Comments
 (0)