Skip to content

Commit a6d7729

Browse files
authored
Merge pull request #177 from Jammy2211/feature/jax_interferometer
Feature/jax interferometer
2 parents aac0fbd + ff0a76a commit a6d7729

16 files changed

Lines changed: 628 additions & 590 deletions

File tree

autoarray/dataset/interferometer/dataset.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer
1111
from autoarray.dataset.grids import GridsDataset
1212
from autoarray.operators.transformer import TransformerNUFFT
13-
13+
from autoarray.mask.mask_2d import Mask2D
1414
from autoarray.structures.visibilities import Visibilities
1515
from autoarray.structures.visibilities import VisibilitiesNoiseMap
1616

@@ -25,8 +25,9 @@ def __init__(
2525
data: Visibilities,
2626
noise_map: VisibilitiesNoiseMap,
2727
uv_wavelengths: np.ndarray,
28-
real_space_mask,
28+
real_space_mask: Mask2D,
2929
transformer_class=TransformerNUFFT,
30+
dft_preload_transform: bool = True,
3031
preprocessing_directory=None,
3132
):
3233
"""
@@ -73,6 +74,9 @@ def __init__(
7374
transformer_class
7475
The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
7576
the uv-plane.
77+
dft_preload_transform
78+
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
79+
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
7680
"""
7781
self.real_space_mask = real_space_mask
7882

@@ -86,7 +90,9 @@ def __init__(
8690
self.uv_wavelengths = uv_wavelengths
8791

8892
self.transformer = transformer_class(
89-
uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask
93+
uv_wavelengths=uv_wavelengths,
94+
real_space_mask=real_space_mask,
95+
preload_transform=dft_preload_transform,
9096
)
9197

9298
self.preprocessing_directory = (
@@ -114,6 +120,7 @@ def from_fits(
114120
noise_map_hdu=0,
115121
uv_wavelengths_hdu=0,
116122
transformer_class=TransformerNUFFT,
123+
dft_preload_transform: bool = True,
117124
):
118125
"""
119126
Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
@@ -139,6 +146,7 @@ def from_fits(
139146
noise_map=noise_map,
140147
uv_wavelengths=uv_wavelengths,
141148
transformer_class=transformer_class,
149+
dft_preload_transform=dft_preload_transform,
142150
)
143151

144152
def w_tilde_preprocessing(self):

autoarray/fit/fit_interferometer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def chi_squared(self) -> float:
113113
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
114114
"""
115115
return fit_util.chi_squared_complex_from(
116-
chi_squared_map=self.chi_squared_map,
116+
chi_squared_map=self.chi_squared_map.array,
117117
)
118118

119119
@property

autoarray/fit/fit_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float:
158158
chi_squared_map
159159
The chi-squared-map of values of the model-data fit to the dataset.
160160
"""
161-
chi_squared_real = jnp.sum(np.array(chi_squared_map.real))
162-
chi_squared_imag = jnp.sum(np.array(chi_squared_map.imag))
161+
chi_squared_real = jnp.sum(chi_squared_map.real)
162+
chi_squared_imag = jnp.sum(chi_squared_map.imag)
163163
return chi_squared_real + chi_squared_imag
164164

165165

autoarray/fit/plot/fit_interferometer_plotters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def figures_2d(
183183
auto_labels=AutoLabels(
184184
title="Model Visibilities", filename="model_data"
185185
),
186-
color_array=np.real(self.fit.model_data),
186+
color_array=np.real(self.fit.model_data.array),
187187
)
188188

189189
if residual_map_real:

autoarray/inversion/inversion/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def mapping_matrix(self) -> np.ndarray:
287287
If there are multiple linear objects, the mapping matrices are stacked such that their simultaneous linear
288288
equations are solved simultaneously. This property returns the stacked mapping matrix.
289289
"""
290-
return np.hstack(
290+
return jnp.hstack(
291291
[linear_obj.mapping_matrix for linear_obj in self.linear_obj_list]
292292
)
293293

autoarray/inversion/inversion/dataset_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def __init__(
3636
noise_map
3737
An array describing the RMS standard deviation error in each pixel used for computing quantities like the
3838
chi-squared in a fit (in PyAutoGalaxy and PyAutoLens the recommended units are electrons per second).
39+
grids
40+
The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting
41+
light profiles and calculations associated with a pixelization.
3942
over_sampler
4043
Performs over-sampling whereby the masked image pixels are split into sub-pixels, which are all
4144
mapped via the mapper with sub-fractional values of flux.
@@ -50,9 +53,6 @@ def __init__(
5053
w_tilde
5154
The w_tilde matrix used by the w-tilde formalism to construct the data vector and
5255
curvature matrix during an inversion efficiently..
53-
grids
54-
The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting
55-
light profiles and calculations associated with a pixelization.
5656
noise_covariance_matrix
5757
A noise-map covariance matrix representing the covariance between noise in every `data` value, which
5858
can be used via a bespoke fit to account for correlated noise in the data.

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index):
387387
return w_tilde_via_preload
388388

389389

390-
@numba_util.jit()
391390
def data_vector_via_transformed_mapping_matrix_from(
392391
transformed_mapping_matrix: np.ndarray,
393392
visibilities: np.ndarray,
@@ -406,31 +405,24 @@ def data_vector_via_transformed_mapping_matrix_from(
406405
noise_map
407406
Flattened 1D array of the noise-map used by the inversion during the fit.
408407
"""
408+
# Extract components
409+
vis_real = visibilities.real
410+
vis_imag = visibilities.imag
411+
f_real = transformed_mapping_matrix.real
412+
f_imag = transformed_mapping_matrix.imag
413+
noise_real = noise_map.real
414+
noise_imag = noise_map.imag
409415

410-
data_vector = np.zeros(transformed_mapping_matrix.shape[1])
411-
412-
visibilities_real = visibilities.real
413-
visibilities_imag = visibilities.imag
414-
transformed_mapping_matrix_real = transformed_mapping_matrix.real
415-
transformed_mapping_matrix_imag = transformed_mapping_matrix.imag
416-
noise_map_real = noise_map.real
417-
noise_map_imag = noise_map.imag
418-
419-
for vis_1d_index in range(transformed_mapping_matrix.shape[0]):
420-
for pix_1d_index in range(transformed_mapping_matrix.shape[1]):
421-
real_value = (
422-
visibilities_real[vis_1d_index]
423-
* transformed_mapping_matrix_real[vis_1d_index, pix_1d_index]
424-
/ (noise_map_real[vis_1d_index] ** 2.0)
425-
)
426-
imag_value = (
427-
visibilities_imag[vis_1d_index]
428-
* transformed_mapping_matrix_imag[vis_1d_index, pix_1d_index]
429-
/ (noise_map_imag[vis_1d_index] ** 2.0)
430-
)
431-
data_vector[pix_1d_index] += real_value + imag_value
416+
# Square noise components
417+
inv_var_real = 1.0 / (noise_real**2)
418+
inv_var_imag = 1.0 / (noise_imag**2)
432419

433-
return data_vector
420+
# Real and imaginary contributions
421+
weighted_real = (vis_real * inv_var_real)[:, None] * f_real
422+
weighted_imag = (vis_imag * inv_var_imag)[:, None] * f_imag
423+
424+
# Sum over visibilities
425+
return np.sum(weighted_real + weighted_imag, axis=0)
434426

435427

436428
@numba_util.jit()
@@ -512,7 +504,6 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from(
512504
return curvature_matrix
513505

514506

515-
@numba_util.jit()
516507
def mapped_reconstructed_visibilities_from(
517508
transformed_mapping_matrix: np.ndarray, reconstruction: np.ndarray
518509
) -> np.ndarray:
@@ -525,20 +516,7 @@ def mapped_reconstructed_visibilities_from(
525516
The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels.
526517
527518
"""
528-
mapped_reconstructed_visibilities = (0.0 + 0.0j) * np.zeros(
529-
transformed_mapping_matrix.shape[0]
530-
)
531-
532-
transformed_mapping_matrix_real = transformed_mapping_matrix.real
533-
transformed_mapping_matrix_imag = transformed_mapping_matrix.imag
534-
535-
for i in range(transformed_mapping_matrix.shape[0]):
536-
for j in range(reconstruction.shape[0]):
537-
mapped_reconstructed_visibilities[i] += (
538-
reconstruction[j] * transformed_mapping_matrix_real[i, j]
539-
) + 1.0j * (reconstruction[j] * transformed_mapping_matrix_imag[i, j])
540-
541-
return mapped_reconstructed_visibilities
519+
return transformed_mapping_matrix @ reconstruction
542520

543521

544522
"""

autoarray/inversion/inversion/interferometer/mapping.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax.numpy as jnp
12
import numpy as np
23
from typing import Dict, List, Optional, Union
34

@@ -76,8 +77,8 @@ def data_vector(self) -> np.ndarray:
7677
"""
7778

7879
return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from(
79-
transformed_mapping_matrix=np.array(self.operated_mapping_matrix),
80-
visibilities=np.array(self.data),
80+
transformed_mapping_matrix=self.operated_mapping_matrix,
81+
visibilities=self.data,
8182
noise_map=np.array(self.noise_map),
8283
)
8384

@@ -106,13 +107,13 @@ def curvature_matrix(self) -> np.ndarray:
106107
noise_map=self.noise_map.imag,
107108
)
108109

109-
curvature_matrix = np.add(real_curvature_matrix, imag_curvature_matrix)
110+
curvature_matrix = jnp.add(real_curvature_matrix, imag_curvature_matrix)
110111

111112
if len(self.no_regularization_index_list) > 0:
112113
curvature_matrix = inversion_util.curvature_matrix_with_added_to_diag_from(
113114
curvature_matrix=curvature_matrix,
114-
no_regularization_index_list=self.no_regularization_index_list,
115115
value=self.settings.no_regularization_add_to_curvature_diag_value,
116+
no_regularization_index_list=self.no_regularization_index_list,
116117
)
117118

118119
return curvature_matrix
@@ -152,10 +153,8 @@ def mapped_reconstructed_data_dict(
152153

153154
visibilities = (
154155
inversion_interferometer_util.mapped_reconstructed_visibilities_from(
155-
transformed_mapping_matrix=np.array(
156-
operated_mapping_matrix_list[index]
157-
),
158-
reconstruction=np.array(reconstruction),
156+
transformed_mapping_matrix=operated_mapping_matrix_list[index],
157+
reconstruction=reconstruction,
159158
)
160159
)
161160

autoarray/inversion/inversion/interferometer/w_tilde.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ def data_vector(self) -> np.ndarray:
8585
8686
The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`.
8787
"""
88-
return np.dot(
89-
self.linear_obj_list[0].mapping_matrix.T, self.w_tilde.dirty_image
90-
)
88+
return np.dot(self.mapping_matrix.T, self.w_tilde.dirty_image)
9189

9290
@cached_property
9391
@profile_func

0 commit comments

Comments
 (0)