Skip to content

Commit f1c68e4

Browse files
committed
add ability to append template variables to existing ds
1 parent c2eb492 commit f1c68e4

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

obsarray/templater/template_util.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@
1111

1212

1313
def create_ds(
14-
template: Dict[str, Dict], size: Dict[str, int], metadata: Optional[Dict] = None,
15-
propagate_ds: Optional[xarray.Dataset] = None,
14+
template: Dict[str, Dict],
15+
size: Dict[str, int],
16+
metadata: Optional[Dict] = None,
17+
append_ds: Optional[xarray.Dataset] = None,
18+
propagate_ds: Optional[xarray.Dataset] = None,
1619
) -> xarray.Dataset:
1720
"""
1821
Returns template dataset
1922
2023
:param template: dictionary defining ds variable structure, as defined below.
2124
:param size: dictionary of dataset dimensions, entry per dataset dimension with value of size as int
2225
:param metadata: dictionary of dataset metadata
26+
:param append_ds: base dataset to append with template variables
2327
:param propagate_ds: template dataset is populated with data from propagate_ds for their variables with
2428
common names and dimensions. Useful for transferring common data between datasets at different processing levels
2529
(e.g. times, etc.).
@@ -38,7 +42,7 @@ def create_ds(
3842
"""
3943

4044
# Create dataset
41-
ds = xarray.Dataset()
45+
ds = append_ds if append_ds is not None else xarray.Dataset()
4246

4347
# Add variables
4448
ds = TemplateUtil.add_variables(ds, template, size)
@@ -85,7 +89,9 @@ def add_variables(
8589
8690
:returns: dataset with defined variables
8791
"""
92+
8893
for var_name in template.keys():
94+
8995
var = TemplateUtil._create_var(var_name, template[var_name], size)
9096

9197
ds[var_name] = var
@@ -123,7 +129,7 @@ def _create_var(
123129
)
124130

125131
# Create variable and add to dataset
126-
if isinstance(dtype,str):
132+
if dtype == str:
127133
if dtype == "flag":
128134
flag_meanings = attributes.pop("flag_meanings")
129135
variable = du.create_flags_variable(
@@ -218,21 +224,32 @@ def propagate_values(target_ds, source_ds, exclude=None):
218224

219225
# Find variable names common to target_ds and source_ds, excluding specified exclude variables
220226
common_variable_names = list(set(target_ds).intersection(source_ds))
221-
#common_variable_names = list(set(target_ds.variables).intersection(source_ds.variables))
222-
#print(common_variable_names)
227+
# common_variable_names = list(set(target_ds.variables).intersection(source_ds.variables))
228+
# print(common_variable_names)
223229

224230
if exclude is not None:
225-
common_variable_names = [name for name in common_variable_names if name not in exclude]
231+
common_variable_names = [
232+
name for name in common_variable_names if name not in exclude
233+
]
226234

227235
# Remove any common variables that have different dimensions in target_ds and source_ds
228-
common_variable_names = [name for name in common_variable_names if target_ds[name].dims == source_ds[name].dims]
236+
common_variable_names = [
237+
name
238+
for name in common_variable_names
239+
if target_ds[name].dims == source_ds[name].dims
240+
]
229241

230242
# Propagate data
231243
for common_variable_name in common_variable_names:
232-
if target_ds[common_variable_name].shape == source_ds[common_variable_name].shape:
233-
target_ds[common_variable_name].values = source_ds[common_variable_name].values
234-
235-
#to do - add method to propagate common unpopulated metadata
244+
if (
245+
target_ds[common_variable_name].shape
246+
== source_ds[common_variable_name].shape
247+
):
248+
target_ds[common_variable_name].values = source_ds[
249+
common_variable_name
250+
].values
251+
252+
# to do - add method to propagate common unpopulated metadata
236253

237254

238255
if __name__ == "__main__":

0 commit comments

Comments
 (0)