@@ -86,7 +86,7 @@ def create_variable(
8686 fill_value = DatasetUtil .get_default_fill_value (dtype )
8787
8888 default_array = DatasetUtil .create_default_array (
89- dim_sizes , dtype , fill_value = fill_value
89+ dim_sizes , dtype , dim_names , fill_value = fill_value
9090 )
9191
9292 if dim_names is None :
@@ -153,7 +153,10 @@ def create_unc_variable(
153153 # set undefined dims as random
154154 defined_err_corr_dims = []
155155 for erd in err_corr :
156- defined_err_corr_dims .append (erd ["dim" ])
156+ if isinstance (erd ["dim" ],str ):
157+ defined_err_corr_dims .append (erd ["dim" ])
158+ else :
159+ defined_err_corr_dims .extend (erd ["dim" ])
157160
158161 missing_err_corr_dims = [
159162 dim for dim in dim_names if dim not in defined_err_corr_dims
@@ -346,6 +349,184 @@ def get_default_fill_value(dtype: numpy.typecodes) -> Union[int, float]:
346349 elif dtype == numpy .float64 :
347350 return numpy .float64 (9.969209968386869e36 )
348351
352+ @staticmethod
353+ def _get_flag_encoding (da ):
354+ """
355+ Returns flag encoding for flag type data array
356+ :type da: xarray.DataArray
357+ :param da: data array
358+ :return: flag meanings
359+ :rtype: list
360+ :return: flag masks
361+ :rtype: list
362+ """
363+
364+ try :
365+ flag_meanings = da .attrs ["flag_meanings" ].split ()
366+ flag_masks = [int (fm ) for fm in da .attrs ["flag_masks" ].split ("," )]
367+ except KeyError :
368+ raise KeyError (da .name + " not a flag variable" )
369+
370+ return flag_meanings , flag_masks
371+
372+ @staticmethod
373+ def unpack_flags (da ):
374+ """
375+ Breaks down flag data array into dataset of boolean masks for each flag
376+ :type da: xarray.DataArray
377+ :param da: dataset
378+ :return: flag masks
379+ :rtype: xarray.Dataset
380+ """
381+
382+ flag_meanings , flag_masks = DatasetUtil ._get_flag_encoding (da )
383+
384+ ds = xarray .Dataset ()
385+ for flag_meaning , flag_mask in zip (flag_meanings , flag_masks ):
386+ ds [flag_meaning ] = DatasetUtil .create_variable (list (da .shape ), bool , dim_names = list (da .dims ))
387+ ds [flag_meaning ] = (da & flag_mask ).astype (bool )
388+
389+ return ds
390+
391+ @staticmethod
392+ def get_flags_mask_or (da , flags = None ):
393+ """
394+ Returns boolean mask for set of flags, defined as logical or of flags
395+
396+ :type da: xarray.DataArray
397+ :param da: dataset
398+
399+ :type flags: list
400+ :param flags: list of flags (if unset all data flags selected)
401+
402+ :return: flag masks
403+ :rtype: numpy.ndarray
404+ """
405+
406+ flags_ds = DatasetUtil .unpack_flags (da )
407+
408+ flags = flags if flags is not None else flags_ds .variables
409+ mask_flags = [flags_ds [flag ].values for flag in flags ]
410+
411+ return numpy .logical_or .reduce (mask_flags )
412+
413+ @staticmethod
414+ def get_flags_mask_and (da , flags = None ):
415+ """
416+ Returns boolean mask for set of flags, defined as logical and of flags
417+
418+ :type da: xarray.DataArray
419+ :param da: dataset
420+
421+ :type flags: list
422+ :param flags: list of flags (if unset all data flags selected)
423+
424+ :return: flag masks
425+ :rtype: numpy.ndarray
426+ """
427+
428+ flags_ds = DatasetUtil .unpack_flags (da )
429+
430+ flags = flags if flags is not None else flags_ds .variables
431+ mask_flags = [flags_ds [flag ].values for flag in flags ]
432+
433+ return numpy .logical_and .reduce (mask_flags )
434+
435+ @staticmethod
436+ def set_flag (da , flag_name , error_if_set = False ):
437+ """
438+ Sets named flag for elements in data array
439+ :type da: xarray.DataArray
440+ :param da: dataset
441+ :type flag_name: str
442+ :param flag_name: name of flag to set
443+ :type error_if_set: bool
444+ :param error_if_set: raises error if chosen flag is already set for any element
445+ """
446+
447+ set_flags = DatasetUtil .unpack_flags (da )[flag_name ]
448+
449+ if numpy .any (set_flags == True ) and error_if_set :
450+ raise ValueError ("Flag " + flag_name + " already set for variable " + da .name )
451+
452+ # Find flag mask
453+ flag_meanings , flag_masks = DatasetUtil ._get_flag_encoding (da )
454+ flag_bit = flag_meanings .index (flag_name )
455+ flag_mask = flag_masks [flag_bit ]
456+
457+ da .values = da .values | flag_mask
458+
459+ return da
460+
461+ @staticmethod
462+ def unset_flag (da , flag_name , error_if_unset = False ):
463+ """
464+ Unsets named flag for specified index of dataset variable
465+ :type da: xarray.DataArray
466+ :param da: data array
467+ :type flag_name: str
468+ :param flag_name: name of flag to unset
469+ :type error_if_unset: bool
470+ :param error_if_unset: raises error if chosen flag is already set at specified index
471+ """
472+
473+ set_flags = DatasetUtil .unpack_flags (da )[flag_name ]
474+
475+ if numpy .any (set_flags == False ) and error_if_unset :
476+ raise ValueError ("Flag " + flag_name + " already set for variable " + da .name )
477+
478+ # Find flag mask
479+ flag_meanings , flag_masks = DatasetUtil ._get_flag_encoding (da )
480+ flag_bit = flag_meanings .index (flag_name )
481+ flag_mask = flag_masks [flag_bit ]
482+
483+ da .values = da .values & ~ flag_mask
484+
485+ return da
486+
487+ @staticmethod
488+ def get_set_flags (da ):
489+ """
490+ Return list of set flags for single element data array
491+ :type da: xarray.DataArray
492+ :param da: single element data array
493+ :return: set flags
494+ :rtype: list
495+ """
496+
497+ if da .shape != ():
498+ raise ValueError ("Must pass single element data array" )
499+
500+ flag_meanings , flag_masks = DatasetUtil ._get_flag_encoding (da )
501+
502+ set_flags = []
503+ for flag_meaning , flag_mask in zip (flag_meanings , flag_masks ):
504+ if (da & flag_mask ):
505+ set_flags .append (flag_meaning )
506+
507+ return set_flags
508+
509+ @staticmethod
510+ def check_flag_set (da , flag_name ):
511+ """
512+ Returns if flag for single element data array
513+ :type da: xarray.DataArray
514+ :param da: single element data array
515+ :type flag_name: str
516+ :param flag_name: name of flag to set
517+ :return: set flags
518+ :rtype: list
519+ """
520+
521+ if da .shape != ():
522+ raise ValueError ("Must pass single element data array" )
523+
524+ set_flags = DatasetUtil .get_set_flags (da )
525+
526+ if flag_name in set_flags :
527+ return True
528+ return False
529+
349530
350531if __name__ == "__main__" :
351532 pass
0 commit comments