11"""err_corr_forms - module for the defintion of error-correlation parameterisation forms"""
22
33import abc
4- from typing import Callable , Type , Union
4+ from typing import Callable , Type , Union , List
55import numpy as np
66from comet_maths .linear_algebra .matrix_conversion import expand_errcorr_dims
77
@@ -69,61 +69,72 @@ def form(self) -> str:
6969 """Form name"""
7070 pass
7171
72- def expand_dim_matrix (self , submatrix , submatrix_dim , sli ):
72+ def expand_dim_matrix (
73+ self , submatrix : np .ndarray , submatrix_dim : Union [str , List [str ]], sli : tuple
74+ ):
7375 return expand_errcorr_dims (
7476 in_corr = submatrix ,
7577 in_dim = submatrix_dim ,
7678 out_dim = list (self ._obj [self ._unc_var_name ][sli ].dims ),
77- dim_sizes = self .get_sliced_dim_sizes_uncvar (sli )
79+ dim_sizes = self .get_sliced_dim_sizes_uncvar (sli ),
7880 )
7981
80- def get_sliced_dim_sizes_uncvar (self , sli ) -> dict :
82+ def get_sliced_dim_sizes_uncvar (self , sli : tuple ) -> dict :
8183 """
82- return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
84+ Return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
8385
8486 :param sli: slice (tuple with slice for each dimension)
8587 :return: shape of included sliced dimensions
8688 """
8789 uncvar_dims = self ._obj [self ._unc_var_name ][sli ].dims
8890 uncvar_shape = self ._obj [self ._unc_var_name ][sli ].shape
89- return {uncvar_dims [idim ]: uncvar_shape [idim ] for idim in range (len (uncvar_dims ))}
91+ return {
92+ uncvar_dims [idim ]: uncvar_shape [idim ] for idim in range (len (uncvar_dims ))
93+ }
9094
91- def get_sliced_dim_sizes_errcorr (self , sli ) -> dict :
95+ def get_sliced_dim_sizes_errcorr (self , sli : tuple ) -> dict :
9296 """
93- return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are included in the current error correlation form.
97+ Return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are
98+ included in the current error correlation form.
9499
95100 :param sli: slice (tuple with slice for each dimension)
96101 :return: shape of included sliced dimensions
97102 """
98- uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
103+ uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
99104 sliced_dims = self .get_sliced_dims_errcorr (sli )
100105
101106 return {dim : uncvar_sizes [dim ] for dim in sliced_dims }
102107
103- def get_sliced_dims_errcorr (self , sli ) -> list :
108+ def get_sliced_dims_errcorr (self , sli : tuple ) -> list :
104109 """
105- return dimensions which are within the slice and included in the current error correlation form.
110+ Return dimensions which are within the slice and included in the current error correlation form.
106111
107112 :param sli: slice (tuple with slice for each dimension)
108113 :return: list with sliced dimensions
109114 """
110115 all_dims = self ._obj [self ._unc_var_name ].dims
111- return [all_dims [idim ] for idim in range (len (all_dims )) if (isinstance (sli [idim ],slice ) and all_dims [idim ] in self .dims )]
116+ return [
117+ all_dims [idim ]
118+ for idim in range (len (all_dims ))
119+ if (isinstance (sli [idim ], slice ) and all_dims [idim ] in self .dims )
120+ ]
112121
113- def get_sliced_shape_errcorr (self , sli ) -> tuple :
122+ def get_sliced_shape_errcorr (self , sli : tuple ) -> tuple :
114123 """
115124 return shape of sliced uncertainty variable, including only dimensions which are included in the current error correlation form.
116125
117126 :param sli: slice (tuple with slice for each dimension)
118127 :return: shape of included sliced dimensions
119128 """
120- uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
129+ uncvar_sizes = self .get_sliced_dim_sizes_uncvar (sli )
121130 sliced_dims = self .get_sliced_dims_errcorr (sli )
122131
123132 return tuple ([uncvar_sizes [dim ] for dim in sliced_dims ])
124133
125- def slice_full_cov (self , full_matrix , sli ):
126- return self .slice_flattened_matrix (full_matrix ,self ._obj [self ._unc_var_name ].shape ,sli )
134+ def slice_full_cov (self , full_matrix : np .ndarrary , sli : tuple ) -> np .ndarray :
135+ return self .slice_flattened_matrix (
136+ full_matrix , self ._obj [self ._unc_var_name ].shape , sli
137+ )
127138
128139 def slice_flattened_matrix (self , flattened_matrix , variable_shape , sli ):
129140 mask_array = np .ones (variable_shape , dtype = bool )
@@ -155,7 +166,9 @@ def build_dot_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
155166 :return: populated error-correlation matrix
156167 """
157168
158- return self .expand_dim_matrix (self .build_matrix (sli ), self .get_sliced_dims_errcorr (sli ), sli )
169+ return self .expand_dim_matrix (
170+ self .build_matrix (sli ), self .get_sliced_dims_errcorr (sli ), sli
171+ )
159172
160173
161174def register_err_corr_form (form_name : str ) -> Callable :
@@ -179,7 +192,7 @@ class RandomCorrelation(BaseErrCorrForm):
179192 form = "random"
180193 is_random = True
181194
182- def build_matrix (self , sli ) :
195+ def build_matrix (self , sli : tuple ) -> np . ndarray :
183196 """
184197 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
185198 in this parameterisation
@@ -204,7 +217,7 @@ class SystematicCorrelation(BaseErrCorrForm):
204217 form = "systematic"
205218 is_systematic = True
206219
207- def build_matrix (self , sli ) :
220+ def build_matrix (self , sli : tuple ) -> np . ndarray :
208221 """
209222 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
210223 in this parameterisation
@@ -228,7 +241,7 @@ class ErrCorrMatrixCorrelation(BaseErrCorrForm):
228241
229242 form = "err_corr_matrix"
230243
231- def build_matrix (self , sli ) :
244+ def build_matrix (self , sli : tuple ) -> np . ndarray :
232245 """
233246 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
234247 in this parameterisation
@@ -241,11 +254,21 @@ def build_matrix(self, sli):
241254 all_dims = self ._obj [self ._unc_var_name ].dims
242255 all_dims_sizes = self ._obj .sizes
243256
244- sli_submatrix = tuple ([sli [i ] for i in range (len (all_dims )) if all_dims [i ] in self .dims ])
257+ sli_submatrix = tuple (
258+ [sli [i ] for i in range (len (all_dims )) if all_dims [i ] in self .dims ]
259+ )
245260
246- sliced_shape = tuple ([all_dims_sizes [all_dims [i ]] for i in range (len (all_dims )) if all_dims [i ] in self .dims ])
261+ sliced_shape = tuple (
262+ [
263+ all_dims_sizes [all_dims [i ]]
264+ for i in range (len (all_dims ))
265+ if all_dims [i ] in self .dims
266+ ]
267+ )
247268
248- submatrix = self .slice_flattened_matrix (self ._obj [self .params [0 ]],sliced_shape ,sli_submatrix )
269+ submatrix = self .slice_flattened_matrix (
270+ self ._obj [self .params [0 ]], sliced_shape , sli_submatrix
271+ )
249272
250273 return submatrix
251274
@@ -255,7 +278,7 @@ class EnsembleCorrelation(BaseErrCorrForm):
255278
256279 form = "ensemble"
257280
258- def build_matrix (self , sli ) :
281+ def build_matrix (self , sli : tuple ) -> np . ndarray :
259282 """
260283 Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
261284 in this parameterisation
0 commit comments