Skip to content

Commit c846629

Browse files
committed
adding propagate_ds
fix small bug for flags (check for string)
1 parent 70c63c8 commit c846629

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

obsarray/templater/template_util.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111

1212

1313
def create_ds(
14-
template: Dict[str, Dict], size: Dict[str, int], metadata: Optional[Dict] = None
14+
template: Dict[str, Dict], size: Dict[str, int], metadata: Optional[Dict] = None,
15+
propagate_ds: Optional[xarray.Dataset] = None,
1516
) -> xarray.Dataset:
1617
"""
1718
Returns template dataset
1819
1920
:param template: dictionary defining ds variable structure, as defined below.
2021
:param size: dictionary of dataset dimensions, entry per dataset dimension with value of size as int
2122
:param metadata: dictionary of dataset metadata
22-
23+
:param propagate_ds: template dataset is populated with data from propagate_ds for their variables with
24+
common names and dimensions. Useful for transferring common data between datasets at different processing levels
25+
(e.g. times, etc.).
2326
:returns: template dataset
2427
2528
For the ``template`` dictionary each key/value pair defines one variable, where the key is the variable name and the value is a dictionary with the following entries:
@@ -44,6 +47,10 @@ def create_ds(
4447
if metadata is not None:
4548
ds = TemplateUtil.add_metadata(ds, metadata)
4649

50+
# Propagate variable data
51+
if propagate_ds is not None:
52+
TemplateUtil.propagate_values(ds, propagate_ds)
53+
4754
return ds
4855

4956

@@ -78,7 +85,6 @@ def add_variables(
7885
7986
:returns: dataset with defined variables
8087
"""
81-
8288
for var_name in template.keys():
8389
var = TemplateUtil._create_var(var_name, template[var_name], size)
8490

@@ -117,7 +123,7 @@ def _create_var(
117123
)
118124

119125
# Create variable and add to dataset
120-
if dtype == str:
126+
if isinstance(dtype,str):
121127
if dtype == "flag":
122128
flag_meanings = attributes.pop("flag_meanings")
123129
variable = du.create_flags_variable(
@@ -195,6 +201,39 @@ def add_metadata(ds: xarray.Dataset, metadata: Dict) -> xarray.Dataset:
195201

196202
return ds
197203

204+
@staticmethod
205+
def propagate_values(target_ds, source_ds, exclude=None):
206+
"""
207+
Populates target_ds in-place with data from source_ds for their variables with common names and dimensions.
208+
Useful for transferring common data between datasets at different processing levels (e.g. times, etc.).
209+
210+
N.B. propagates data only, not variables as a whole with attributes etc.
211+
212+
:type target_ds: xarray.Dataset
213+
:param target_ds: ds to populate (perhaps data at new processing level)
214+
215+
:type source_ds: xarray.Dataset
216+
:param source_ds: ds to take data from (perhaps data at previous processing level)
217+
"""
218+
219+
# Find variable names common to target_ds and source_ds, excluding specified exclude variables
220+
common_variable_names = list(set(target_ds).intersection(source_ds))
221+
#common_variable_names = list(set(target_ds.variables).intersection(source_ds.variables))
222+
#print(common_variable_names)
223+
224+
if exclude is not None:
225+
common_variable_names = [name for name in common_variable_names if name not in exclude]
226+
227+
# Remove any common variables that have different dimensions in target_ds and source_ds
228+
common_variable_names = [name for name in common_variable_names if target_ds[name].dims == source_ds[name].dims]
229+
230+
# Propagate data
231+
for common_variable_name in common_variable_names:
232+
if target_ds[common_variable_name].shape == source_ds[common_variable_name].shape:
233+
target_ds[common_variable_name].values = source_ds[common_variable_name].values
234+
235+
#to do - add method to propagate common unpopulated metadata
236+
198237

199238
if __name__ == "__main__":
200239
pass

0 commit comments

Comments
 (0)