Skip to content

Commit 0c449ba

Browse files
committed
add type annotations
1 parent cf7f7fa commit 0c449ba

File tree

1 file changed

+47
-24
lines changed

1 file changed

+47
-24
lines changed

obsarray/err_corr.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""err_corr_forms - module for the defintion of error-correlation parameterisation forms"""
22

33
import abc
4-
from typing import Callable, Type, Union
4+
from typing import Callable, Type, Union, List
55
import numpy as np
66
from comet_maths.linear_algebra.matrix_conversion import expand_errcorr_dims
77

@@ -69,61 +69,72 @@ def form(self) -> str:
6969
"""Form name"""
7070
pass
7171

72-
def expand_dim_matrix(self, submatrix, submatrix_dim, sli):
72+
def expand_dim_matrix(
73+
self, submatrix: np.ndarray, submatrix_dim: Union[str, List[str]], sli: tuple
74+
):
7375
return expand_errcorr_dims(
7476
in_corr=submatrix,
7577
in_dim=submatrix_dim,
7678
out_dim=list(self._obj[self._unc_var_name][sli].dims),
77-
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli)
79+
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli),
7880
)
7981

80-
def get_sliced_dim_sizes_uncvar(self, sli) -> dict:
82+
def get_sliced_dim_sizes_uncvar(self, sli: tuple) -> dict:
8183
"""
82-
return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
84+
Return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
8385
8486
:param sli: slice (tuple with slice for each dimension)
8587
:return: shape of included sliced dimensions
8688
"""
8789
uncvar_dims = self._obj[self._unc_var_name][sli].dims
8890
uncvar_shape = self._obj[self._unc_var_name][sli].shape
89-
return {uncvar_dims[idim]: uncvar_shape[idim] for idim in range(len(uncvar_dims))}
91+
return {
92+
uncvar_dims[idim]: uncvar_shape[idim] for idim in range(len(uncvar_dims))
93+
}
9094

91-
def get_sliced_dim_sizes_errcorr(self, sli) -> dict:
95+
def get_sliced_dim_sizes_errcorr(self, sli: tuple) -> dict:
9296
"""
93-
return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are included in the current error correlation form.
97+
Return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are
98+
included in the current error correlation form.
9499
95100
:param sli: slice (tuple with slice for each dimension)
96101
:return: shape of included sliced dimensions
97102
"""
98-
uncvar_sizes=self.get_sliced_dim_sizes_uncvar(sli)
103+
uncvar_sizes = self.get_sliced_dim_sizes_uncvar(sli)
99104
sliced_dims = self.get_sliced_dims_errcorr(sli)
100105

101106
return {dim: uncvar_sizes[dim] for dim in sliced_dims}
102107

103-
def get_sliced_dims_errcorr(self, sli) -> list:
108+
def get_sliced_dims_errcorr(self, sli: tuple) -> list:
104109
"""
105-
return dimensions which are within the slice and included in the current error correlation form.
110+
Return dimensions which are within the slice and included in the current error correlation form.
106111
107112
:param sli: slice (tuple with slice for each dimension)
108113
:return: list with sliced dimensions
109114
"""
110115
all_dims = self._obj[self._unc_var_name].dims
111-
return [all_dims[idim] for idim in range(len(all_dims)) if (isinstance(sli[idim],slice) and all_dims[idim] in self.dims)]
116+
return [
117+
all_dims[idim]
118+
for idim in range(len(all_dims))
119+
if (isinstance(sli[idim], slice) and all_dims[idim] in self.dims)
120+
]
112121

113-
def get_sliced_shape_errcorr(self, sli)->tuple:
122+
def get_sliced_shape_errcorr(self, sli: tuple) -> tuple:
114123
"""
115124
return shape of sliced uncertainty variable, including only dimensions which are included in the current error correlation form.
116125
117126
:param sli: slice (tuple with slice for each dimension)
118127
:return: shape of included sliced dimensions
119128
"""
120-
uncvar_sizes=self.get_sliced_dim_sizes_uncvar(sli)
129+
uncvar_sizes = self.get_sliced_dim_sizes_uncvar(sli)
121130
sliced_dims = self.get_sliced_dims_errcorr(sli)
122131

123132
return tuple([uncvar_sizes[dim] for dim in sliced_dims])
124133

125-
def slice_full_cov(self, full_matrix, sli):
126-
return self.slice_flattened_matrix(full_matrix,self._obj[self._unc_var_name].shape,sli)
134+
def slice_full_cov(self, full_matrix: np.ndarrary, sli: tuple) -> np.ndarray:
135+
return self.slice_flattened_matrix(
136+
full_matrix, self._obj[self._unc_var_name].shape, sli
137+
)
127138

128139
def slice_flattened_matrix(self, flattened_matrix, variable_shape, sli):
129140
mask_array = np.ones(variable_shape, dtype=bool)
@@ -155,7 +166,9 @@ def build_dot_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
155166
:return: populated error-correlation matrix
156167
"""
157168

158-
return self.expand_dim_matrix(self.build_matrix(sli), self.get_sliced_dims_errcorr(sli), sli)
169+
return self.expand_dim_matrix(
170+
self.build_matrix(sli), self.get_sliced_dims_errcorr(sli), sli
171+
)
159172

160173

161174
def register_err_corr_form(form_name: str) -> Callable:
@@ -179,7 +192,7 @@ class RandomCorrelation(BaseErrCorrForm):
179192
form = "random"
180193
is_random = True
181194

182-
def build_matrix(self, sli):
195+
def build_matrix(self, sli: tuple) -> np.ndarray:
183196
"""
184197
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
185198
in this parameterisation
@@ -204,7 +217,7 @@ class SystematicCorrelation(BaseErrCorrForm):
204217
form = "systematic"
205218
is_systematic = True
206219

207-
def build_matrix(self, sli):
220+
def build_matrix(self, sli: tuple) -> np.ndarray:
208221
"""
209222
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
210223
in this parameterisation
@@ -228,7 +241,7 @@ class ErrCorrMatrixCorrelation(BaseErrCorrForm):
228241

229242
form = "err_corr_matrix"
230243

231-
def build_matrix(self, sli):
244+
def build_matrix(self, sli: tuple) -> np.ndarray:
232245
"""
233246
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
234247
in this parameterisation
@@ -241,11 +254,21 @@ def build_matrix(self, sli):
241254
all_dims = self._obj[self._unc_var_name].dims
242255
all_dims_sizes = self._obj.sizes
243256

244-
sli_submatrix = tuple([sli[i] for i in range(len(all_dims)) if all_dims[i] in self.dims])
257+
sli_submatrix = tuple(
258+
[sli[i] for i in range(len(all_dims)) if all_dims[i] in self.dims]
259+
)
245260

246-
sliced_shape = tuple([all_dims_sizes[all_dims[i]] for i in range(len(all_dims)) if all_dims[i] in self.dims])
261+
sliced_shape = tuple(
262+
[
263+
all_dims_sizes[all_dims[i]]
264+
for i in range(len(all_dims))
265+
if all_dims[i] in self.dims
266+
]
267+
)
247268

248-
submatrix = self.slice_flattened_matrix(self._obj[self.params[0]],sliced_shape,sli_submatrix)
269+
submatrix = self.slice_flattened_matrix(
270+
self._obj[self.params[0]], sliced_shape, sli_submatrix
271+
)
249272

250273
return submatrix
251274

@@ -255,7 +278,7 @@ class EnsembleCorrelation(BaseErrCorrForm):
255278

256279
form = "ensemble"
257280

258-
def build_matrix(self, sli):
281+
def build_matrix(self, sli: tuple) -> np.ndarray:
259282
"""
260283
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
261284
in this parameterisation

0 commit comments

Comments
 (0)