Skip to content

Commit e8039d0

Browse files
authored
Merge pull request #2 from comet-toolkit/small_fixes
Small fixes for use in punpy
2 parents 195a69c + e2aa0bb commit e8039d0

File tree

6 files changed

+70
-54
lines changed

6 files changed

+70
-54
lines changed

notebook/test_register_err_corr_form.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,32 @@ def build_matrix(self, idx: np.ndarray) -> np.ndarray:
1414
return "abc"
1515

1616

17-
ds = xr.open_dataset("obs_example.nc")
18-
19-
ds.unc["temperature"]["u_sys_temperature"] = (
20-
["x", "y", "time"],
21-
ds.temperature * 0.03,
22-
{
23-
"err_corr": [
24-
{
25-
"dim": "x",
26-
"form": "new",
27-
"params": [],
28-
},
29-
{
30-
"dim": "y",
31-
"form": "systematic",
32-
"params": [],
33-
},
34-
{
35-
"dim": "time",
36-
"form": "systematic",
37-
"params": [],
38-
},
39-
]
40-
},
41-
)
42-
43-
p = ds.unc["temperature"]["u_sys_temperature"]
44-
45-
pass
17+
# ds = xr.open_dataset("obs_example.nc")
18+
#
19+
# ds.unc["temperature"]["u_sys_temperature"] = (
20+
# ["x", "y", "time"],
21+
# ds.temperature * 0.03,
22+
# {
23+
# "err_corr": [
24+
# {
25+
# "dim": "x",
26+
# "form": "new",
27+
# "params": [],
28+
# },
29+
# {
30+
# "dim": "y",
31+
# "form": "systematic",
32+
# "params": [],
33+
# },
34+
# {
35+
# "dim": "time",
36+
# "form": "systematic",
37+
# "params": [],
38+
# },
39+
# ]
40+
# },
41+
# )
42+
#
43+
# p = ds.unc["temperature"]["u_sys_temperature"]
44+
#
45+
# pass

obsarray/err_corr.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class BaseErrCorrForm(abc.ABC):
4949

5050
def __init__(self, xarray_obj, unc_var_name, dims, params, units):
5151
self._obj = xarray_obj
52-
self.unc_var_name = unc_var_name
52+
self._unc_var_name = unc_var_name
5353
self.dims = dims if isinstance(dims, list) else [dims]
5454
self.params = params if isinstance(params, list) else [params]
5555
self.units = units
@@ -69,18 +69,18 @@ def form(self) -> str:
6969
"""Form name"""
7070
pass
7171

72-
def expand_dim_matrix(self, submatrix):
72+
def expand_dim_matrix(self, submatrix, sli):
7373
return expand_errcorr_dims(
7474
in_corr=submatrix,
7575
in_dim=self.dims,
76-
out_dim=list(self._obj[self.unc_var_name].dims),
76+
out_dim=list(self._obj[self._unc_var_name][sli].dims),
7777
dim_sizes={
78-
dim: self._obj.dims[dim] for dim in self._obj[self.unc_var_name].dims
78+
dim: self._obj.dims[dim] for dim in self._obj[self._unc_var_name][sli].dims
7979
},
8080
)
8181

8282
def slice_full_cov(self, full_matrix, sli):
83-
mask_array = np.ones(self._obj[self.unc_var_name].shape, dtype=bool)
83+
mask_array = np.ones(self._obj[self._unc_var_name].shape, dtype=bool)
8484
mask_array[sli] = False
8585

8686
return np.delete(
@@ -138,10 +138,10 @@ def build_matrix(self, sli):
138138
dims_matrix = np.eye(n_elems)
139139

140140
# expand to correlation matrix over all variable dims
141-
full_matrix = self.expand_dim_matrix(dims_matrix)
141+
return self.expand_dim_matrix(dims_matrix, sli)
142142

143-
# subset to slice
144-
return self.slice_full_cov(full_matrix, sli)
143+
# # subset to slice
144+
# return self.slice_full_cov(full_matrix, sli)
145145

146146

147147
@register_err_corr_form("systematic")
@@ -167,10 +167,10 @@ def build_matrix(self, sli):
167167
dims_matrix = np.ones((n_elems, n_elems))
168168

169169
# expand to correlation matrix over all variable dims
170-
full_matrix = self.expand_dim_matrix(dims_matrix)
170+
return self.expand_dim_matrix(dims_matrix,sli)
171171

172172
# subset to slice
173-
return self.slice_full_cov(full_matrix, sli)
173+
# return self.slice_full_cov(full_matrix, sli)
174174

175175

176176
@register_err_corr_form("err_corr_matrix")
@@ -189,10 +189,10 @@ def build_matrix(self, sli):
189189
"""
190190

191191
# expand to correlation matrix over all variable dims
192-
full_matrix = self.expand_dim_matrix(self._obj[self.params[0]])
192+
return self.expand_dim_matrix(self._obj[self.params[0]],sli)
193193

194-
# subset to slice
195-
return self.slice_full_cov(full_matrix, sli)
194+
# # subset to slice
195+
# return self.slice_full_cov(full_matrix, sli)
196196

197197

198198
@register_err_corr_form("ensemble")

obsarray/templater/template_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _create_var(
129129
)
130130

131131
# Create variable and add to dataset
132-
if dtype == str:
132+
if isinstance(dtype,str):
133133
if dtype == "flag":
134134
flag_meanings = attributes.pop("flag_meanings")
135135
variable = du.create_flags_variable(

obsarray/test/test_err_corr_forms.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,35 @@ class TestSystematicUnc(unittest.TestCase):
121121
def setUp(self) -> None:
122122
self.ds = create_ds()
123123

124-
def test_build_matrix_1stdim(self):
124+
def build_matrix_1stdim(self):
125125
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["x"], [], [])
126126

127127
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
128+
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
128129

129-
np.testing.assert_equal(ecrm, np.ones((12, 12)))
130+
return ecrm
130131

131-
def test_build_matrix_2nddim(self):
132+
def build_matrix_2nddim(self):
132133
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["y"], [], [])
133134

134135
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
136+
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
137+
138+
return ecrm
139+
140+
def build_matrix_3ddim(self):
141+
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["time"], [], [])
142+
143+
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
144+
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
135145

136-
np.testing.assert_equal(ecrm, np.ones((12, 12)))
146+
return ecrm
137147

148+
def test_build_matrix(self):
149+
x=self.build_matrix_1stdim()
150+
y=self.build_matrix_2nddim()
151+
time=self.build_matrix_3ddim()
152+
np.testing.assert_equal((x.dot(y)).dot(time), np.ones((12, 12)))
138153

139154
if __name__ == "main":
140155
unittest.main()

obsarray/test/test_unc_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def compare_err_corr_form(self, form, exp_form):
3131
self.assertEqual(form.form, exp_form.form)
3232
self.assertCountEqual(form.params, exp_form.params)
3333
self.assertCountEqual(form.units, exp_form.units)
34-
self.assertCountEqual(form.unc_var_name, exp_form.unc_var_name)
34+
self.assertCountEqual(form._unc_var_name, exp_form._unc_var_name)
3535

3636

3737
def create_ds():
@@ -99,7 +99,7 @@ def create_ds():
9999
"err_corr": [
100100
{
101101
"dim": "x",
102-
"form": "custom",
102+
"form": "err_corr_matrix",
103103
"params": ["err_corr_str_temperature"],
104104
},
105105
{

obsarray/unc_accessor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,10 @@ def err_corr_matrix(self) -> xr.DataArray:
291291
# populate with error-correlation matrices built be each error-correlation
292292
# parameterisation object
293293
for dim_err_corr in self.err_corr:
294-
err_corr_matrix.values = err_corr_matrix.values.dot(
295-
dim_err_corr[1].build_matrix(self._sli)
296-
)
294+
if np.all([dim in self._obj[self._unc_var_name][self._sli].dims for dim in dim_err_corr[1].dims]):
295+
err_corr_matrix.values = err_corr_matrix.values.dot(
296+
dim_err_corr[1].build_matrix(self._sli)
297+
)
297298

298299
return err_corr_matrix
299300

@@ -563,7 +564,7 @@ def total_err_cov_matrix(self) -> xr.DataArray:
563564
)
564565
covs_sum = np.zeros(total_err_cov_matrix.shape)
565566
for unc in self:
566-
covs_sum += unc.err_cov_matrix().values
567+
covs_sum += unc[self._sli].err_cov_matrix().values
567568

568569
total_err_cov_matrix.values = covs_sum
569570

@@ -582,7 +583,7 @@ def structured_err_cov_matrix(self):
582583

583584
covs_sum = np.zeros(structured_err_cov_matrix.shape)
584585
for unc in self:
585-
covs_sum += unc.err_cov_matrix().values
586+
covs_sum += unc[self._sli].err_cov_matrix().values
586587

587588
structured_err_cov_matrix.values = covs_sum
588589

0 commit comments

Comments
 (0)