Skip to content

Commit 0a25723

Browse files
committed
small bug fixes
adding flag functionality back in from hypernets_dsutil
1 parent c846629 commit 0a25723

File tree

2 files changed

+190
-7
lines changed

2 files changed

+190
-7
lines changed

obsarray/templater/dataset_util.py

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def create_variable(
8686
fill_value = DatasetUtil.get_default_fill_value(dtype)
8787

8888
default_array = DatasetUtil.create_default_array(
89-
dim_sizes, dtype, fill_value=fill_value
89+
dim_sizes, dtype, dim_names, fill_value=fill_value
9090
)
9191

9292
if dim_names is None:
@@ -153,7 +153,10 @@ def create_unc_variable(
153153
# set undefined dims as random
154154
defined_err_corr_dims = []
155155
for erd in err_corr:
156-
defined_err_corr_dims.append(erd["dim"])
156+
if isinstance(erd["dim"],str):
157+
defined_err_corr_dims.append(erd["dim"])
158+
else:
159+
defined_err_corr_dims.extend(erd["dim"])
157160

158161
missing_err_corr_dims = [
159162
dim for dim in dim_names if dim not in defined_err_corr_dims
@@ -346,6 +349,184 @@ def get_default_fill_value(dtype: numpy.typecodes) -> Union[int, float]:
346349
elif dtype == numpy.float64:
347350
return numpy.float64(9.969209968386869e36)
348351

352+
@staticmethod
353+
def _get_flag_encoding(da):
354+
"""
355+
Returns flag encoding for flag type data array
356+
:type da: xarray.DataArray
357+
:param da: data array
358+
:return: flag meanings
359+
:rtype: list
360+
:return: flag masks
361+
:rtype: list
362+
"""
363+
364+
try:
365+
flag_meanings = da.attrs["flag_meanings"].split()
366+
flag_masks = [int(fm) for fm in da.attrs["flag_masks"].split(",")]
367+
except KeyError:
368+
raise KeyError(da.name + " not a flag variable")
369+
370+
return flag_meanings, flag_masks
371+
372+
@staticmethod
373+
def unpack_flags(da):
374+
"""
375+
Breaks down flag data array into dataset of boolean masks for each flag
376+
:type da: xarray.DataArray
377+
:param da: dataset
378+
:return: flag masks
379+
:rtype: xarray.Dataset
380+
"""
381+
382+
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
383+
384+
ds = xarray.Dataset()
385+
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
386+
ds[flag_meaning] = DatasetUtil.create_variable(list(da.shape), bool, dim_names=list(da.dims))
387+
ds[flag_meaning] = (da & flag_mask).astype(bool)
388+
389+
return ds
390+
391+
@staticmethod
392+
def get_flags_mask_or(da, flags=None):
393+
"""
394+
Returns boolean mask for set of flags, defined as logical or of flags
395+
396+
:type da: xarray.DataArray
397+
:param da: dataset
398+
399+
:type flags: list
400+
:param flags: list of flags (if unset all data flags selected)
401+
402+
:return: flag masks
403+
:rtype: numpy.ndarray
404+
"""
405+
406+
flags_ds = DatasetUtil.unpack_flags(da)
407+
408+
flags = flags if flags is not None else flags_ds.variables
409+
mask_flags = [flags_ds[flag].values for flag in flags]
410+
411+
return numpy.logical_or.reduce(mask_flags)
412+
413+
@staticmethod
414+
def get_flags_mask_and(da, flags=None):
415+
"""
416+
Returns boolean mask for set of flags, defined as logical and of flags
417+
418+
:type da: xarray.DataArray
419+
:param da: dataset
420+
421+
:type flags: list
422+
:param flags: list of flags (if unset all data flags selected)
423+
424+
:return: flag masks
425+
:rtype: numpy.ndarray
426+
"""
427+
428+
flags_ds = DatasetUtil.unpack_flags(da)
429+
430+
flags = flags if flags is not None else flags_ds.variables
431+
mask_flags = [flags_ds[flag].values for flag in flags]
432+
433+
return numpy.logical_and.reduce(mask_flags)
434+
435+
@staticmethod
436+
def set_flag(da, flag_name, error_if_set=False):
437+
"""
438+
Sets named flag for elements in data array
439+
:type da: xarray.DataArray
440+
:param da: dataset
441+
:type flag_name: str
442+
:param flag_name: name of flag to set
443+
:type error_if_set: bool
444+
:param error_if_set: raises error if chosen flag is already set for any element
445+
"""
446+
447+
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
448+
449+
if numpy.any(set_flags == True) and error_if_set:
450+
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
451+
452+
# Find flag mask
453+
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
454+
flag_bit = flag_meanings.index(flag_name)
455+
flag_mask = flag_masks[flag_bit]
456+
457+
da.values = da.values | flag_mask
458+
459+
return da
460+
461+
@staticmethod
462+
def unset_flag(da, flag_name, error_if_unset=False):
463+
"""
464+
Unsets named flag for specified index of dataset variable
465+
:type da: xarray.DataArray
466+
:param da: data array
467+
:type flag_name: str
468+
:param flag_name: name of flag to unset
469+
:type error_if_unset: bool
470+
:param error_if_unset: raises error if chosen flag is already set at specified index
471+
"""
472+
473+
set_flags = DatasetUtil.unpack_flags(da)[flag_name]
474+
475+
if numpy.any(set_flags == False) and error_if_unset:
476+
raise ValueError("Flag " + flag_name + " already set for variable " + da.name)
477+
478+
# Find flag mask
479+
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
480+
flag_bit = flag_meanings.index(flag_name)
481+
flag_mask = flag_masks[flag_bit]
482+
483+
da.values = da.values & ~flag_mask
484+
485+
return da
486+
487+
@staticmethod
488+
def get_set_flags(da):
489+
"""
490+
Return list of set flags for single element data array
491+
:type da: xarray.DataArray
492+
:param da: single element data array
493+
:return: set flags
494+
:rtype: list
495+
"""
496+
497+
if da.shape != ():
498+
raise ValueError("Must pass single element data array")
499+
500+
flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
501+
502+
set_flags = []
503+
for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
504+
if (da & flag_mask):
505+
set_flags.append(flag_meaning)
506+
507+
return set_flags
508+
509+
@staticmethod
510+
def check_flag_set(da, flag_name):
511+
"""
512+
Returns if flag for single element data array
513+
:type da: xarray.DataArray
514+
:param da: single element data array
515+
:type flag_name: str
516+
:param flag_name: name of flag to set
517+
:return: set flags
518+
:rtype: list
519+
"""
520+
521+
if da.shape != ():
522+
raise ValueError("Must pass single element data array")
523+
524+
set_flags = DatasetUtil.get_set_flags(da)
525+
526+
if flag_name in set_flags:
527+
return True
528+
return False
529+
349530

350531
if __name__ == "__main__":
351532
pass

obsarray/unc_accessor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def expand_sli(self, sli: Optional[tuple] = None) -> tuple:
7979
if sli is None:
8080
out_sli = tuple([slice(None)] * self._obj[self._unc_var_name].ndim)
8181

82-
# if the sli tuple has the correct shape, it can be used directly
83-
elif self._obj[self._unc_var_name].ndim == len(sli):
84-
out_sli = sli
82+
# # if the sli tuple has the correct shape, it can be used directly
83+
# elif self._obj[self._unc_var_name].ndim == len(sli):
84+
# out_sli = sli
8585

8686
# If different shape, set each dimension to slice(None) and then change the
8787
# ones provided in the new slice. E.g. if providing [0] for a variable with
@@ -246,8 +246,10 @@ def is_random(self) -> bool:
246246
247247
:return: random uncertainty flag
248248
"""
249-
250-
return all(e[1].is_random is True for e in self.err_corr)
249+
if len(self.err_corr)>0:
250+
return all(e[1].is_random is True for e in self.err_corr)
251+
else:
252+
return False
251253

252254
@property
253255
def is_structured(self) -> bool:

0 commit comments

Comments
 (0)