Skip to content

Commit 3f90a44

Browse files
authored
Merge pull request #14 from comet-toolkit/small_fixes
Small fixes
2 parents 251cae1 + 25a2062 commit 3f90a44

File tree

13 files changed

+336
-99
lines changed

13 files changed

+336
-99
lines changed

.github/workflows/pull_request.yml

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions
2+
name: Pull Request
3+
4+
on:
5+
pull_request:
6+
branches:
7+
- main
8+
push:
9+
branches:
10+
- main
11+
12+
jobs:
13+
lint_code:
14+
runs-on: ubuntu-latest
15+
strategy:
16+
matrix:
17+
python-version: ["3.11"]
18+
steps:
19+
- uses: actions/checkout@v3
20+
- name: Set up Python ${{ matrix.python-version }}
21+
uses: actions/setup-python@v3
22+
with:
23+
python-version: ${{ matrix.python-version }}
24+
- name: Install dependencies
25+
run: |
26+
python -m pip install --upgrade pip
27+
pip install .[dev]
28+
- name: Analysing the code with pre-commit lint checks
29+
run: |
30+
pre-commit run -a
31+
32+
test_code_python3p8:
33+
runs-on: ubuntu-latest
34+
strategy:
35+
matrix:
36+
python-version: ["3.8"]
37+
steps:
38+
- uses: actions/checkout@v4
39+
- name: Set up Python ${{ matrix.python-version }}
40+
uses: actions/setup-python@v3
41+
with:
42+
python-version: ${{ matrix.python-version }}
43+
- name: Install dependencies
44+
run: |
45+
python -m pip install --upgrade pip
46+
pip install .[dev]
47+
- name: Test code
48+
run: |
49+
mkdir test_report
50+
tox
51+
52+
53+
test_code_and_coverage_report_python3p11:
54+
runs-on: ubuntu-latest
55+
strategy:
56+
matrix:
57+
python-version: ["3.11"]
58+
steps:
59+
- uses: actions/checkout@v4
60+
- name: Set up Python ${{ matrix.python-version }}
61+
uses: actions/setup-python@v3
62+
with:
63+
python-version: ${{ matrix.python-version }}
64+
- name: Install dependencies
65+
run: |
66+
python -m pip install --upgrade pip
67+
pip install .[dev]
68+
- name: Test code
69+
run: |
70+
mkdir test_report
71+
tox
72+
- name: html to pdf
73+
uses: fifsky/html-to-pdf-action@master
74+
with:
75+
htmlFile: test_report/cov_report/index.html
76+
outputFile: test_report/cov_report/cov_report.pdf
77+
pdfOptions: '{"format": "A4", "margin": {"top": "10mm", "left": "10mm", "right": "10mm", "bottom": "10mm"}}'
78+
- name: Archive code coverage results
79+
uses: actions/upload-artifact@v4
80+
with:
81+
name: code-coverage-report
82+
path: test_report/cov_report/cov_report.pdf

.github/workflows/push_branch.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions
2+
name: Push to branch
3+
4+
on:
5+
push:
6+
branches:
7+
- '*'
8+
9+
jobs:
10+
test_code:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.12"]
15+
steps:
16+
- uses: actions/checkout@v4
17+
- name: Set up Python ${{ matrix.python-version }}
18+
uses: actions/setup-python@v3
19+
with:
20+
python-version: ${{ matrix.python-version }}
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install .[dev]
25+
- name: Test code
26+
run: |
27+
tox

obsarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from obsarray.err_corr import err_corr_forms
99
from obsarray.templater.template_util import create_ds
1010
from obsarray.templater.dstemplater import DSTemplater
11+
from obsarray.templater.dswriter import DSWriter
1112

1213
__version__ = get_versions()["version"]
1314
del get_versions

obsarray/err_corr.py

Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""err_corr_forms - module for the defintion of error-correlation parameterisation forms"""
22

33
import abc
4-
from typing import Callable, Type, Union
4+
from typing import Callable, Type, Union, List
55
import numpy as np
66
from comet_maths.linear_algebra.matrix_conversion import expand_errcorr_dims
77

@@ -69,23 +69,79 @@ def form(self) -> str:
6969
"""Form name"""
7070
pass
7171

72-
def expand_dim_matrix(self, submatrix, sli):
72+
def expand_dim_matrix(
73+
self, submatrix: np.ndarray, submatrix_dim: Union[str, List[str]], sli: tuple
74+
):
7375
return expand_errcorr_dims(
7476
in_corr=submatrix,
75-
in_dim=self.dims,
77+
in_dim=submatrix_dim,
7678
out_dim=list(self._obj[self._unc_var_name][sli].dims),
77-
dim_sizes={
78-
dim: self._obj.dims[dim]
79-
for dim in self._obj[self._unc_var_name][sli].dims
80-
},
79+
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli),
8180
)
8281

83-
def slice_full_cov(self, full_matrix, sli):
84-
mask_array = np.ones(self._obj[self._unc_var_name].shape, dtype=bool)
82+
def get_sliced_dim_sizes_uncvar(self, sli: tuple) -> dict:
83+
"""
84+
Return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
85+
86+
:param sli: slice (tuple with slice for each dimension)
87+
:return: shape of included sliced dimensions
88+
"""
89+
uncvar_dims = self._obj[self._unc_var_name][sli].dims
90+
uncvar_shape = self._obj[self._unc_var_name][sli].shape
91+
return {
92+
uncvar_dims[idim]: uncvar_shape[idim] for idim in range(len(uncvar_dims))
93+
}
94+
95+
def get_sliced_dim_sizes_errcorr(self, sli: tuple) -> dict:
96+
"""
97+
Return dictionary with sizes of sliced dimensions of unc variable, including only dimensions which are
98+
included in the current error correlation form.
99+
100+
:param sli: slice (tuple with slice for each dimension)
101+
:return: shape of included sliced dimensions
102+
"""
103+
uncvar_sizes = self.get_sliced_dim_sizes_uncvar(sli)
104+
sliced_dims = self.get_sliced_dims_errcorr(sli)
105+
106+
return {dim: uncvar_sizes[dim] for dim in sliced_dims}
107+
108+
def get_sliced_dims_errcorr(self, sli: tuple) -> list:
109+
"""
110+
Return dimensions which are within the slice and included in the current error correlation form.
111+
112+
:param sli: slice (tuple with slice for each dimension)
113+
:return: list with sliced dimensions
114+
"""
115+
all_dims = self._obj[self._unc_var_name].dims
116+
return [
117+
all_dims[idim]
118+
for idim in range(len(all_dims))
119+
if (isinstance(sli[idim], slice) and all_dims[idim] in self.dims)
120+
]
121+
122+
def get_sliced_shape_errcorr(self, sli: tuple) -> tuple:
123+
"""
124+
return shape of sliced uncertainty variable, including only dimensions which are included in the current error correlation form.
125+
126+
:param sli: slice (tuple with slice for each dimension)
127+
:return: shape of included sliced dimensions
128+
"""
129+
uncvar_sizes = self.get_sliced_dim_sizes_uncvar(sli)
130+
sliced_dims = self.get_sliced_dims_errcorr(sli)
131+
132+
return tuple([uncvar_sizes[dim] for dim in sliced_dims])
133+
134+
def slice_full_cov(self, full_matrix: np.ndarray, sli: tuple) -> np.ndarray:
135+
return self.slice_flattened_matrix(
136+
full_matrix, self._obj[self._unc_var_name].shape, sli
137+
)
138+
139+
def slice_flattened_matrix(self, flattened_matrix, variable_shape, sli):
140+
mask_array = np.ones(variable_shape, dtype=bool)
85141
mask_array[sli] = False
86142

87143
return np.delete(
88-
np.delete(full_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
144+
np.delete(flattened_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
89145
)
90146

91147
@abc.abstractmethod
@@ -100,6 +156,20 @@ def build_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
100156
"""
101157
pass
102158

159+
def build_dot_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
160+
"""
161+
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
162+
in this parameterisation
163+
164+
:param sli: slice of observation variable to return error-correlation matrix for
165+
166+
:return: populated error-correlation matrix
167+
"""
168+
169+
return self.expand_dim_matrix(
170+
self.build_matrix(sli), self.get_sliced_dims_errcorr(sli), sli
171+
)
172+
103173

104174
def register_err_corr_form(form_name: str) -> Callable:
105175
"""
@@ -122,7 +192,7 @@ class RandomCorrelation(BaseErrCorrForm):
122192
form = "random"
123193
is_random = True
124194

125-
def build_matrix(self, sli):
195+
def build_matrix(self, sli: tuple) -> np.ndarray:
126196
"""
127197
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
128198
in this parameterisation
@@ -133,16 +203,12 @@ def build_matrix(self, sli):
133203
"""
134204

135205
# evaluate correlation over matrices in form defintion
136-
dim_lens = [len(self._obj[dim]) for dim in self.dims]
206+
dim_lens = self.get_sliced_shape_errcorr(sli)
137207
n_elems = int(np.prod(dim_lens))
138208

139-
dims_matrix = np.eye(n_elems)
140-
141-
# expand to correlation matrix over all variable dims
142-
return self.expand_dim_matrix(dims_matrix, sli)
209+
submatrix = np.eye(n_elems)
143210

144-
# # subset to slice
145-
# return self.slice_full_cov(full_matrix, sli)
211+
return submatrix
146212

147213

148214
@register_err_corr_form("systematic")
@@ -151,7 +217,7 @@ class SystematicCorrelation(BaseErrCorrForm):
151217
form = "systematic"
152218
is_systematic = True
153219

154-
def build_matrix(self, sli):
220+
def build_matrix(self, sli: tuple) -> np.ndarray:
155221
"""
156222
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
157223
in this parameterisation
@@ -162,24 +228,20 @@ def build_matrix(self, sli):
162228
"""
163229

164230
# evaluate correlation over matrices in form defintion
165-
dim_lens = [len(self._obj[dim]) for dim in self.dims]
231+
dim_lens = self.get_sliced_shape_errcorr(sli)
166232
n_elems = int(np.prod(dim_lens))
167233

168-
dims_matrix = np.ones((n_elems, n_elems))
169-
170-
# expand to correlation matrix over all variable dims
171-
return self.expand_dim_matrix(dims_matrix, sli)
234+
submatrix = np.ones((n_elems, n_elems))
172235

173-
# subset to slice
174-
# return self.slice_full_cov(full_matrix, sli)
236+
return submatrix
175237

176238

177239
@register_err_corr_form("err_corr_matrix")
178240
class ErrCorrMatrixCorrelation(BaseErrCorrForm):
179241

180242
form = "err_corr_matrix"
181243

182-
def build_matrix(self, sli):
244+
def build_matrix(self, sli: tuple) -> np.ndarray:
183245
"""
184246
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
185247
in this parameterisation
@@ -189,19 +251,34 @@ def build_matrix(self, sli):
189251
:return: populated error-correlation matrix
190252
"""
191253

192-
# expand to correlation matrix over all variable dims
193-
return self.expand_dim_matrix(self._obj[self.params[0]], sli)
254+
all_dims = self._obj[self._unc_var_name].dims
255+
all_dims_sizes = self._obj.sizes
256+
257+
sli_submatrix = tuple(
258+
[sli[i] for i in range(len(all_dims)) if all_dims[i] in self.dims]
259+
)
260+
261+
sliced_shape = tuple(
262+
[
263+
all_dims_sizes[all_dims[i]]
264+
for i in range(len(all_dims))
265+
if all_dims[i] in self.dims
266+
]
267+
)
268+
269+
submatrix = self.slice_flattened_matrix(
270+
self._obj[self.params[0]], sliced_shape, sli_submatrix
271+
)
194272

195-
# # subset to slice
196-
# return self.slice_full_cov(full_matrix, sli)
273+
return submatrix
197274

198275

199276
@register_err_corr_form("ensemble")
200277
class EnsembleCorrelation(BaseErrCorrForm):
201278

202279
form = "ensemble"
203280

204-
def build_matrix(self, sli):
281+
def build_matrix(self, sli: tuple) -> np.ndarray:
205282
"""
206283
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
207284
in this parameterisation

obsarray/flag_accessor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def __setitem__(self, flag_meaning: str, flag_value: Union[bool, np.ndarray]):
194194
)
195195

196196
if flag_meaning not in flag_meanings:
197-
self._obj[
198-
self._flag_var_name
199-
].attrs = DatasetUtil.add_flag_meaning_to_attrs(
200-
self._obj[self._flag_var_name].attrs,
201-
flag_meaning,
202-
self._obj[self._flag_var_name].dtype,
197+
self._obj[self._flag_var_name].attrs = (
198+
DatasetUtil.add_flag_meaning_to_attrs(
199+
self._obj[self._flag_var_name].attrs,
200+
flag_meaning,
201+
self._obj[self._flag_var_name].dtype,
202+
)
203203
)
204204

205205
self[flag_meaning][:] = flag_value

0 commit comments

Comments
 (0)