Skip to content

Commit cf7f7fa

Browse files
committed
Fix bug that was happening when slicing data
1 parent 24e305c commit cf7f7fa

File tree

4 files changed

+158
-53
lines changed

4 files changed

+158
-53
lines changed

obsarray/err_corr.py

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -69,23 +69,68 @@ def form(self) -> str:
6969
"""Form name"""
7070
pass
7171

72-
def expand_dim_matrix(self, submatrix, sli):
72+
def expand_dim_matrix(self, submatrix, submatrix_dim, sli):
7373
return expand_errcorr_dims(
7474
in_corr=submatrix,
75-
in_dim=self.dims,
75+
in_dim=submatrix_dim,
7676
out_dim=list(self._obj[self._unc_var_name][sli].dims),
77-
dim_sizes={
78-
dim: self._obj.dims[dim]
79-
for dim in self._obj[self._unc_var_name][sli].dims
80-
},
77+
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli)
8178
)
8279

80+
def get_sliced_dim_sizes_uncvar(self, sli) -> dict:
81+
"""
82+
return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
83+
84+
:param sli: slice (tuple with slice for each dimension)
85+
:return: shape of included sliced dimensions
86+
"""
87+
uncvar_dims = self._obj[self._unc_var_name][sli].dims
88+
uncvar_shape = self._obj[self._unc_var_name][sli].shape
89+
return {uncvar_dims[idim]: uncvar_shape[idim] for idim in range(len(uncvar_dims))}
90+
91+
def get_sliced_dim_sizes_errcorr(self, sli) -> dict:
92+
"""
93+
return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are included in the current error correlation form.
94+
95+
:param sli: slice (tuple with slice for each dimension)
96+
:return: shape of included sliced dimensions
97+
"""
98+
uncvar_sizes=self.get_sliced_dim_sizes_uncvar(sli)
99+
sliced_dims = self.get_sliced_dims_errcorr(sli)
100+
101+
return {dim: uncvar_sizes[dim] for dim in sliced_dims}
102+
103+
def get_sliced_dims_errcorr(self, sli) -> list:
104+
"""
105+
return dimensions which are within the slice and included in the current error correlation form.
106+
107+
:param sli: slice (tuple with slice for each dimension)
108+
:return: list with sliced dimensions
109+
"""
110+
all_dims = self._obj[self._unc_var_name].dims
111+
return [all_dims[idim] for idim in range(len(all_dims)) if (isinstance(sli[idim],slice) and all_dims[idim] in self.dims)]
112+
113+
def get_sliced_shape_errcorr(self, sli)->tuple:
114+
"""
115+
return shape of sliced uncertainty variable, including only dimensions which are included in the current error correlation form.
116+
117+
:param sli: slice (tuple with slice for each dimension)
118+
:return: shape of included sliced dimensions
119+
"""
120+
uncvar_sizes=self.get_sliced_dim_sizes_uncvar(sli)
121+
sliced_dims = self.get_sliced_dims_errcorr(sli)
122+
123+
return tuple([uncvar_sizes[dim] for dim in sliced_dims])
124+
83125
def slice_full_cov(self, full_matrix, sli):
84-
mask_array = np.ones(self._obj[self._unc_var_name].shape, dtype=bool)
126+
return self.slice_flattened_matrix(full_matrix,self._obj[self._unc_var_name].shape,sli)
127+
128+
def slice_flattened_matrix(self, flattened_matrix, variable_shape, sli):
129+
mask_array = np.ones(variable_shape, dtype=bool)
85130
mask_array[sli] = False
86131

87132
return np.delete(
88-
np.delete(full_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
133+
np.delete(flattened_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
89134
)
90135

91136
@abc.abstractmethod
@@ -100,6 +145,18 @@ def build_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
100145
"""
101146
pass
102147

148+
def build_dot_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
149+
"""
150+
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
151+
in this parameterisation
152+
153+
:param sli: slice of observation variable to return error-correlation matrix for
154+
155+
:return: populated error-correlation matrix
156+
"""
157+
158+
return self.expand_dim_matrix(self.build_matrix(sli), self.get_sliced_dims_errcorr(sli), sli)
159+
103160

104161
def register_err_corr_form(form_name: str) -> Callable:
105162
"""
@@ -133,16 +190,12 @@ def build_matrix(self, sli):
133190
"""
134191

135192
# evaluate correlation over matrices in form defintion
136-
dim_lens = [len(self._obj[dim]) for dim in self.dims]
193+
dim_lens = self.get_sliced_shape_errcorr(sli)
137194
n_elems = int(np.prod(dim_lens))
138195

139-
dims_matrix = np.eye(n_elems)
140-
141-
# expand to correlation matrix over all variable dims
142-
return self.expand_dim_matrix(dims_matrix, sli)
196+
submatrix = np.eye(n_elems)
143197

144-
# # subset to slice
145-
# return self.slice_full_cov(full_matrix, sli)
198+
return submatrix
146199

147200

148201
@register_err_corr_form("systematic")
@@ -162,16 +215,12 @@ def build_matrix(self, sli):
162215
"""
163216

164217
# evaluate correlation over matrices in form defintion
165-
dim_lens = [len(self._obj[dim]) for dim in self.dims]
218+
dim_lens = self.get_sliced_shape_errcorr(sli)
166219
n_elems = int(np.prod(dim_lens))
167220

168-
dims_matrix = np.ones((n_elems, n_elems))
221+
submatrix = np.ones((n_elems, n_elems))
169222

170-
# expand to correlation matrix over all variable dims
171-
return self.expand_dim_matrix(dims_matrix, sli)
172-
173-
# subset to slice
174-
# return self.slice_full_cov(full_matrix, sli)
223+
return submatrix
175224

176225

177226
@register_err_corr_form("err_corr_matrix")
@@ -189,11 +238,16 @@ def build_matrix(self, sli):
189238
:return: populated error-correlation matrix
190239
"""
191240

192-
# expand to correlation matrix over all variable dims
193-
return self.expand_dim_matrix(self._obj[self.params[0]], sli)
241+
all_dims = self._obj[self._unc_var_name].dims
242+
all_dims_sizes = self._obj.sizes
243+
244+
sli_submatrix = tuple([sli[i] for i in range(len(all_dims)) if all_dims[i] in self.dims])
245+
246+
sliced_shape = tuple([all_dims_sizes[all_dims[i]] for i in range(len(all_dims)) if all_dims[i] in self.dims])
247+
248+
submatrix = self.slice_flattened_matrix(self._obj[self.params[0]],sliced_shape,sli_submatrix)
194249

195-
# # subset to slice
196-
# return self.slice_full_cov(full_matrix, sli)
250+
return submatrix
197251

198252

199253
@register_err_corr_form("ensemble")

obsarray/test/test_err_corr_forms.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,55 @@ def test_slice_full_cov_full(self):
8181

8282
np.testing.assert_equal(full_matrix, slice_matrix)
8383

84+
def test_get_sliced_dim_sizes_uncvar(self):
85+
basicerrcorr = self.BasicErrCorrForm(
86+
self.ds, "u_ran_temperature", ["x"], [], []
87+
)
88+
dim_sizes=basicerrcorr.get_sliced_dim_sizes_uncvar((slice(None),0,slice(0,2,1)))
89+
assert dim_sizes == {"x": 2, "time": 2}
90+
91+
def test_get_sliced_dim_sizes_errcorr(self):
92+
basicerrcorr = self.BasicErrCorrForm(
93+
self.ds, "u_ran_temperature", ["x"], [], []
94+
)
95+
dim_sizes = basicerrcorr.get_sliced_dim_sizes_errcorr((slice(None), 0, slice(0, 2, 1)))
96+
assert dim_sizes == {"x": 2}
97+
98+
def test_get_sliced_dims_errcorr(self):
99+
basicerrcorr = self.BasicErrCorrForm(
100+
self.ds, "u_ran_temperature", ["x"], [], []
101+
)
102+
dims = basicerrcorr.get_sliced_dims_errcorr((slice(None), 0, slice(0, 2, 1)))
103+
assert dims == ["x"]
104+
105+
def test_get_sliced_shape_errcorr(self):
106+
basicerrcorr = self.BasicErrCorrForm(
107+
self.ds, "u_ran_temperature", ["x"], [], []
108+
)
109+
shape = basicerrcorr.get_sliced_shape_errcorr((slice(None), 0, slice(0, 2, 1)))
110+
assert shape == (2,)
111+
basicerrcorr = self.BasicErrCorrForm(
112+
self.ds, "u_ran_temperature", ["x", "time"], [], []
113+
)
114+
shape = basicerrcorr.get_sliced_shape_errcorr((slice(None), 0, slice(0, 2, 1)))
115+
assert shape == (2,2)
116+
117+
def test_slice_flattened_matrix(self):
118+
basicerrcorr = self.BasicErrCorrForm(
119+
self.ds, "u_ran_temperature", ["x"], [], []
120+
)
121+
122+
full_matrix = np.arange(144).reshape((12, 12))
123+
slice_matrix = basicerrcorr.slice_flattened_matrix(
124+
full_matrix, (2,2,3), (slice(None), slice(None), 0)
125+
)
126+
127+
exp_slice_matrix = np.array(
128+
[[0, 3, 6, 9], [36, 39, 42, 45], [72, 75, 78, 81], [108, 111, 114, 117]]
129+
)
130+
131+
np.testing.assert_equal(slice_matrix, exp_slice_matrix)
132+
84133
def test_slice_full_cov_slice(self):
85134
basicerrcorr = self.BasicErrCorrForm(
86135
self.ds, "u_ran_temperature", ["x"], [], []
@@ -97,22 +146,21 @@ def test_slice_full_cov_slice(self):
97146

98147
np.testing.assert_equal(slice_matrix, exp_slice_matrix)
99148

100-
101149
class TestRandomUnc(unittest.TestCase):
102150
def setUp(self) -> None:
103151
self.ds = create_ds()
104152

105153
def test_build_matrix_1stdim(self):
106154
rc = RandomCorrelation(self.ds, "u_ran_temperature", ["x"], [], [])
107155

108-
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
156+
ecrm = rc.build_dot_matrix((slice(None), slice(None), slice(None)))
109157

110158
np.testing.assert_equal(ecrm, np.eye(12))
111159

112160
def test_build_matrix_2nddim(self):
113161
rc = RandomCorrelation(self.ds, "u_ran_temperature", ["y"], [], [])
114162

115-
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
163+
ecrm = rc.build_dot_matrix((slice(None), slice(None), slice(None)))
116164

117165
np.testing.assert_equal(ecrm, np.eye(12))
118166

@@ -124,33 +172,33 @@ def setUp(self) -> None:
124172
def build_matrix_1stdim(self):
125173
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["x"], [], [])
126174

127-
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
175+
ecrm = rc.build_dot_matrix((slice(None), slice(None), slice(None)))
128176
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
129177

130178
return ecrm
131179

132180
def build_matrix_2nddim(self):
133181
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["y"], [], [])
134182

135-
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
183+
ecrm = rc.build_dot_matrix((slice(None), slice(None), slice(None)))
136184
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
137185

138186
return ecrm
139187

140188
def build_matrix_3ddim(self):
141189
rc = SystematicCorrelation(self.ds, "u_sys_temperature", ["time"], [], [])
142190

143-
ecrm = rc.build_matrix((slice(None), slice(None), slice(None)))
191+
ecrm = rc.build_dot_matrix((slice(None), slice(None), slice(None)))
144192
# np.testing.assert_equal(ecrm, np.ones((12, 12)))
145193

146194
return ecrm
147195

148-
def test_build_matrix(self):
196+
def test_build_dot_matrix(self):
149197
x = self.build_matrix_1stdim()
150198
y = self.build_matrix_2nddim()
151199
time = self.build_matrix_3ddim()
200+
print(x.dot(y),x,y)
152201
np.testing.assert_equal((x.dot(y)).dot(time), np.ones((12, 12)))
153202

154-
155203
if __name__ == "main":
156204
unittest.main()

obsarray/test/test_unc_accessor.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ def create_ds():
117117
},
118118
)
119119

120+
ds["err_corr_str_temperature"] = (
121+
["x", "x"],
122+
np.eye(temperature.shape[0]),
123+
)
124+
120125
return ds
121126

122127

@@ -474,13 +479,15 @@ def test_systematic_unc(self, mock):
474479
self.ds.unc["temperature"][:, :, 0].systematic_unc()
475480
mock.assert_called_once_with(["u_sys_temperature"])
476481

477-
@patch(
478-
"obsarray.unc_accessor.Uncertainty.err_corr_matrix",
479-
return_value=xr.DataArray(np.ones((12, 12)), dims=["x.y.time", "x.y.time"]),
480-
)
481-
def test_total_err_corr_matrix(self, mock_err_corr_matrix):
482-
pass
483-
# tercm = self.ds.unc["temperature"].total_err_corr_matrix()
482+
def test_total_err_corr_matrix(self, ):
483+
tercm = self.ds.unc["temperature"].total_err_corr_matrix()
484+
assert tercm.shape==(12,12)
485+
tercm = self.ds.unc["temperature"][:,:,0].total_err_corr_matrix()
486+
assert tercm.shape==(4,4)
487+
tercm = self.ds.unc["temperature"][:,0,0:2].total_err_corr_matrix()
488+
assert tercm.shape==(4,4)
489+
tercm = self.ds.unc["temperature"][0,0,:].total_err_corr_matrix()
490+
assert tercm.shape==(3,3)
484491

485492
def test_structured_err_corr_matrix(self):
486493
pass
@@ -496,7 +503,7 @@ class TestUncertainty(unittest.TestCase):
496503
def setUp(self):
497504
self.ds = create_ds()
498505

499-
@patch("obsarray.unc_accessor.Uncertainty.expand_sli", return_value="slice")
506+
@patch("obsarray.unc_accessor.Uncertainty._expand_sli", return_value="slice")
500507
def test___getitem__(self, m):
501508
self.assertEqual(
502509
self.ds.unc["temperature"]["u_ran_temperature"]["in_slice"]._sli, "slice"
@@ -513,19 +520,19 @@ def test_expand_slice_1d_full(self):
513520
def test_expand_slice_1d_None(self):
514521
self.ds["new"] = (["time"], np.ones(3), {})
515522
self.ds.unc["new"]["u_new"] = (["time"], np.ones(3), {})
516-
sli = self.ds.unc["temperature"]["u_ran_temperature"]._expand_sli()
523+
sli = self.ds.unc["new"]["u_new"]._expand_sli()
517524
self.assertEqual((slice(None),), sli)
518525

519526
def test_expand_slice_full(self):
520-
sli = self.ds.unc["temperature"]["u_ran_temperature"].expand_sli((1, 1, 1))
527+
sli = self.ds.unc["temperature"]["u_ran_temperature"]._expand_sli((1, 1, 1))
521528
self.assertEqual((1, 1, 1), sli)
522529

523530
def test_expand_slice_None(self):
524-
sli = self.ds.unc["temperature"]["u_ran_temperature"].expand_sli()
531+
sli = self.ds.unc["temperature"]["u_ran_temperature"]._expand_sli()
525532
self.assertEqual((slice(None), slice(None), slice(None)), sli)
526533

527534
def test_expand_slice_first(self):
528-
sli = self.ds.unc["temperature"]["u_ran_temperature"].expand_sli((0,))
535+
sli = self.ds.unc["temperature"]["u_ran_temperature"]._expand_sli((0,))
529536
self.assertEqual((0, slice(None), slice(None)), sli)
530537

531538
def test_err_corr(self):

obsarray/unc_accessor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,10 @@ def err_corr_matrix(self) -> xr.DataArray:
320320
# populate with error-correlation matrices built be each error-correlation
321321
# parameterisation object
322322
for dim_err_corr in self.err_corr:
323-
if np.all(
324-
[
325-
dim in self._obj[self._unc_var_name][self._sli].dims
326-
for dim in dim_err_corr[1].dims
327-
]
328-
):
323+
sliced_dims=dim_err_corr[1].get_sliced_dims_errcorr(self._sli)
324+
if len(dim_err_corr[1].get_sliced_dims_errcorr(self._sli))>0:
329325
err_corr_matrix.values = err_corr_matrix.values.dot(
330-
dim_err_corr[1].build_matrix(self._sli)
326+
dim_err_corr[1].build_dot_matrix(self._sli)
331327
)
332328

333329
return err_corr_matrix

0 commit comments

Comments
 (0)