11"""err_corr_forms - module for the defintion of error-correlation parameterisation forms"""
22
33import abc
4- from typing import Callable , Type , Union
4+ from typing import Callable , Type , Union , List
55import numpy as np
66from comet_maths .linear_algebra .matrix_conversion import expand_errcorr_dims
77
@@ -69,23 +69,79 @@ def form(self) -> str:
6969 """Form name"""
7070 pass
7171
72- def expand_dim_matrix (self , submatrix , sli ):
72+ def expand_dim_matrix (
73+ self , submatrix : np .ndarray , submatrix_dim : Union [str , List [str ]], sli : tuple
74+ ):
7375 return expand_errcorr_dims (
7476 in_corr = submatrix ,
75- in_dim = self . dims ,
77+ in_dim = submatrix_dim ,
7678 out_dim = list (self ._obj [self ._unc_var_name ][sli ].dims ),
77- dim_sizes = {
78- dim : self ._obj .dims [dim ]
79- for dim in self ._obj [self ._unc_var_name ][sli ].dims
80- },
79+ dim_sizes = self .get_sliced_dim_sizes_uncvar (sli ),
8180 )
8281
83- def slice_full_cov (self , full_matrix , sli ):
84- mask_array = np .ones (self ._obj [self ._unc_var_name ].shape , dtype = bool )
82+ def get_sliced_dim_sizes_uncvar (self , sli : tuple ) -> dict :
83+ """
84+ Return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
85+
86+ :param sli: slice (tuple with slice for each dimension)
87+ :return: shape of included sliced dimensions
88+ """
89+ uncvar_dims = self ._obj [self ._unc_var_name ][sli ].dims
90+ uncvar_shape = self ._obj [self ._unc_var_name ][sli ].shape
91+ return {
92+ uncvar_dims [idim ]: uncvar_shape [idim ] for idim in range (len (uncvar_dims ))
93+ }
94+
95+ def get_sliced_dim_sizes_errcorr (self , sli : tuple ) -> dict :
96+ """
97+ Return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are
98+ included in the current error correlation form.
99+
100+ :param sli: slice (tuple with slice for each dimension)
101+ :return: shape of included sliced dimensions
102+ """
103+ uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
104+ sliced_dims = self .get_sliced_dims_errcorr (sli )
105+
106+ return {dim : uncvar_sizes [dim ] for dim in sliced_dims }
107+
108+ def get_sliced_dims_errcorr (self , sli : tuple ) -> list :
109+ """
110+ Return dimensions which are within the slice and included in the current error correlation form.
111+
112+ :param sli: slice (tuple with slice for each dimension)
113+ :return: list with sliced dimensions
114+ """
115+ all_dims = self ._obj [self ._unc_var_name ].dims
116+ return [
117+ all_dims [idim ]
118+ for idim in range (len (all_dims ))
119+ if (isinstance (sli [idim ], slice ) and all_dims [idim ] in self .dims )
120+ ]
121+
122+ def get_sliced_shape_errcorr (self , sli : tuple ) -> tuple :
123+ """
124+ return shape of sliced uncertainty variable, including only dimensions which are included in the current error correlation form.
125+
126+ :param sli: slice (tuple with slice for each dimension)
127+ :return: shape of included sliced dimensions
128+ """
129+ uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
130+ sliced_dims = self .get_sliced_dims_errcorr (sli )
131+
132+ return tuple ([uncvar_sizes [dim ] for dim in sliced_dims ])
133+
134+ def slice_full_cov (self , full_matrix : np .ndarray , sli : tuple ) -> np .ndarray :
135+ return self .slice_flattened_matrix (
136+ full_matrix , self ._obj [self ._unc_var_name ].shape , sli
137+ )
138+
139+ def slice_flattened_matrix (self , flattened_matrix , variable_shape , sli ):
140+ mask_array = np .ones (variable_shape , dtype = bool )
85141 mask_array [sli ] = False
86142
87143 return np .delete (
88- np .delete (full_matrix , mask_array .ravel (), 0 ), mask_array .ravel (), 1
144+ np .delete (flattened_matrix , mask_array .ravel (), 0 ), mask_array .ravel (), 1
89145 )
90146
91147 @abc .abstractmethod
@@ -100,6 +156,20 @@ def build_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
100156 """
101157 pass
102158
159+ def build_dot_matrix (self , sli : Union [np .ndarray , tuple ]) -> np .ndarray :
160+ """
161+ Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
162+ in this parameterisation
163+
164+ :param sli: slice of observation variable to return error-correlation matrix for
165+
166+ :return: populated error-correlation matrix
167+ """
168+
169+ return self .expand_dim_matrix (
170+ self .build_matrix (sli ), self .get_sliced_dims_errcorr (sli ), sli
171+ )
172+
103173
104174def register_err_corr_form (form_name : str ) -> Callable :
105175 """
@@ -122,7 +192,7 @@ class RandomCorrelation(BaseErrCorrForm):
122192 form = "random"
123193 is_random = True
124194
125- def build_matrix (self , sli ) :
195+ def build_matrix (self , sli : tuple ) -> np . ndarray :
126196 """
127197 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
128198 in this parameterisation
@@ -133,16 +203,12 @@ def build_matrix(self, sli):
133203 """
134204
135205 # evaluate correlation over matrices in form defintion
136- dim_lens = [ len ( self ._obj [ dim ]) for dim in self . dims ]
206+ dim_lens = self .get_sliced_shape_errcorr ( sli )
137207 n_elems = int (np .prod (dim_lens ))
138208
139- dims_matrix = np .eye (n_elems )
140-
141- # expand to correlation matrix over all variable dims
142- return self .expand_dim_matrix (dims_matrix , sli )
209+ submatrix = np .eye (n_elems )
143210
144- # # subset to slice
145- # return self.slice_full_cov(full_matrix, sli)
211+ return submatrix
146212
147213
148214@register_err_corr_form ("systematic" )
@@ -151,7 +217,7 @@ class SystematicCorrelation(BaseErrCorrForm):
151217 form = "systematic"
152218 is_systematic = True
153219
154- def build_matrix (self , sli ) :
220+ def build_matrix (self , sli : tuple ) -> np . ndarray :
155221 """
156222 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
157223 in this parameterisation
@@ -162,24 +228,20 @@ def build_matrix(self, sli):
162228 """
163229
164230 # evaluate correlation over matrices in form defintion
165- dim_lens = [ len ( self ._obj [ dim ]) for dim in self . dims ]
231+ dim_lens = self .get_sliced_shape_errcorr ( sli )
166232 n_elems = int (np .prod (dim_lens ))
167233
168- dims_matrix = np .ones ((n_elems , n_elems ))
169-
170- # expand to correlation matrix over all variable dims
171- return self .expand_dim_matrix (dims_matrix , sli )
234+ submatrix = np .ones ((n_elems , n_elems ))
172235
173- # subset to slice
174- # return self.slice_full_cov(full_matrix, sli)
236+ return submatrix
175237
176238
177239@register_err_corr_form ("err_corr_matrix" )
178240class ErrCorrMatrixCorrelation (BaseErrCorrForm ):
179241
180242 form = "err_corr_matrix"
181243
182- def build_matrix (self , sli ) :
244+ def build_matrix (self , sli : tuple ) -> np . ndarray :
183245 """
184246 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
185247 in this parameterisation
@@ -189,19 +251,34 @@ def build_matrix(self, sli):
189251 :return: populated error-correlation matrix
190252 """
191253
192- # expand to correlation matrix over all variable dims
193- return self .expand_dim_matrix (self ._obj [self .params [0 ]], sli )
254+ all_dims = self ._obj [self ._unc_var_name ].dims
255+ all_dims_sizes = self ._obj .sizes
256+
257+ sli_submatrix = tuple (
258+ [sli [i ] for i in range (len (all_dims )) if all_dims [i ] in self .dims ]
259+ )
260+
261+ sliced_shape = tuple (
262+ [
263+ all_dims_sizes [all_dims [i ]]
264+ for i in range (len (all_dims ))
265+ if all_dims [i ] in self .dims
266+ ]
267+ )
268+
269+ submatrix = self .slice_flattened_matrix (
270+ self ._obj [self .params [0 ]], sliced_shape , sli_submatrix
271+ )
194272
195- # # subset to slice
196- # return self.slice_full_cov(full_matrix, sli)
273+ return submatrix
197274
198275
199276@register_err_corr_form ("ensemble" )
200277class EnsembleCorrelation (BaseErrCorrForm ):
201278
202279 form = "ensemble"
203280
204- def build_matrix (self , sli ) :
281+ def build_matrix (self , sli : tuple ) -> np . ndarray :
205282 """
206283 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
207284 in this parameterisation
0 commit comments