Skip to content

Commit e3ce0f7

Browse files
authored
Merge pull request #202 from Jammy2211/feature/jax_w_tilde_preload
Feature/jax w tilde preload
2 parents 4ca111f + 6145a2a commit e3ce0f7

5 files changed

Lines changed: 624 additions & 105 deletions

File tree

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import util
1010
from . import fixtures
1111
from . import mock as m
12+
from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible
1213
from .dataset import preprocess
1314
from .dataset.abstract.dataset import AbstractDataset
1415
from .dataset.abstract.w_tilde import AbstractWTilde

autoarray/dataset/interferometer/dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def apply_w_tilde(
164164
batch_size: int = 128,
165165
show_progress: bool = False,
166166
show_memory: bool = False,
167+
use_jax: bool = False,
167168
):
168169
"""
169170
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
@@ -192,7 +193,7 @@ def apply_w_tilde(
192193

193194
if curvature_preload is None:
194195

195-
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
196+
logger.info("INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")
196197

197198
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
198199
noise_map_real=self.noise_map.array.real,
@@ -201,6 +202,7 @@ def apply_w_tilde(
201202
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
202203
show_memory=show_memory,
203204
show_progress=show_progress,
205+
use_jax=use_jax,
204206
)
205207

206208
dirty_image = self.transformer.image_from(

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,207 @@
1+
import json
2+
import hashlib
13
import numpy as np
4+
from pathlib import Path
5+
from typing import Any, Dict, Optional, Tuple, Union
26

37
from autoarray.dataset.abstract.w_tilde import AbstractWTilde
48
from autoarray.mask.mask_2d import Mask2D
59

610

11+
def _bbox_from_mask(mask_bool: np.ndarray) -> Tuple[int, int, int, int]:
12+
"""
13+
Return bbox (y_min, y_max, x_min, x_max) of the unmasked region.
14+
mask_bool: True=masked, False=unmasked
15+
"""
16+
ys, xs = np.where(~mask_bool)
17+
if ys.size == 0:
18+
raise ValueError("Mask has no unmasked pixels; cannot compute bbox.")
19+
return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())
20+
21+
22+
def _mask_sha256(mask_bool: np.ndarray) -> str:
23+
"""
24+
Stable hash of the full boolean mask content (not just bbox).
25+
"""
26+
# Ensure contiguous, stable dtype
27+
arr = np.ascontiguousarray(mask_bool.astype(np.uint8))
28+
return hashlib.sha256(arr.tobytes()).hexdigest()
29+
30+
31+
def _as_pixel_scales_tuple(pixel_scales) -> Tuple[float, float]:
32+
"""
33+
Normalize pixel_scales to a stable 2-tuple of float.
34+
Works with AutoArray pixel_scales objects or raw tuples.
35+
"""
36+
try:
37+
# autoarray typically stores pixel_scales as tuple-like
38+
return (float(pixel_scales[0]), float(pixel_scales[1]))
39+
except Exception:
40+
# fallback: treat as scalar
41+
s = float(pixel_scales)
42+
return (s, s)
43+
44+
45+
def _np_float_tuple(x) -> Tuple[float, float]:
46+
return (float(x[0]), float(x[1]))
47+
48+
49+
def curvature_preload_metadata_from(real_space_mask) -> Dict[str, Any]:
    """
    Describe a real-space mask with the metadata stored alongside a saved
    curvature_preload, so a later load can decide whether reuse is safe.

    The metadata records:

    - a format identifier string
    - the full mask shape
    - the pixel scales, normalized to a 2-tuple of floats
    - the inclusive bounding box of the unmasked region and its extent
    - a SHA-256 digest of the mask values

    Parameters
    ----------
    real_space_mask
        Mask object convertible to a boolean array (True = masked) that also
        exposes a ``pixel_scales`` attribute.

    Returns
    -------
    dict
        Metadata built from ints, floats, strings and tuples.

    Raises
    ------
    ValueError
        If the mask has no unmasked pixels.
    """
    as_bool = np.asarray(real_space_mask, dtype=bool)

    bbox = _bbox_from_mask(as_bool)
    y_lo, y_hi, x_lo, x_hi = bbox

    scales = _as_pixel_scales_tuple(real_space_mask.pixel_scales)

    return {
        "format": "autoarray.w_tilde.curvature_preload.v1",
        "mask_shape": tuple(as_bool.shape),
        "pixel_scales": scales,
        "bbox_unmasked": (y_lo, y_hi, x_lo, x_hi),
        # Bounds are inclusive, hence the +1 on each extent.
        "rect_shape": (y_hi - y_lo + 1, x_hi - x_lo + 1),
        # Digest of the mask values guards against reusing a preload for a
        # different mask that happens to share the same bounding box.
        "mask_sha256": _mask_sha256(as_bool),
    }
81+
82+
83+
def is_preload_metadata_compatible(
    real_space_mask,
    meta: Dict[str, Any],
    *,
    require_mask_hash: bool = True,
    atol: float = 0.0,
) -> Tuple[bool, str]:
    """
    Compare metadata loaded from disk against the metadata of the current mask.

    Checks are made in order (format, mask shape, pixel scales, bounding box,
    rectangular extent, optional mask hash) and the first mismatch is reported.

    Parameters
    ----------
    real_space_mask
        The current real-space mask; its metadata is computed via
        `curvature_preload_metadata_from`.
    meta
        Metadata dict loaded from disk.
    require_mask_hash
        If True, require the full mask sha256 to match (safest).
        If False, only check bbox + shape + pixel scales.
    atol
        Absolute tolerance for the pixel scale comparison. No relative
        tolerance is applied, so the default of 0.0 requires exact equality.

    Returns
    -------
    (ok, reason)
        ok: bool, True if compatible
        reason: str, "ok" when compatible, otherwise a human-readable
        description of the first mismatch found.
    """
    current = curvature_preload_metadata_from(real_space_mask=real_space_mask)

    # 1) format version
    if meta.get("format") != current["format"]:
        return False, f"format mismatch: {meta.get('format')} != {current['format']}"

    # 2) mask shape (tuple() so JSON lists compare equal to tuples)
    if tuple(meta.get("mask_shape", ())) != tuple(current["mask_shape"]):
        return (
            False,
            f"mask_shape mismatch: {meta.get('mask_shape')} != {current['mask_shape']}",
        )

    # 3) pixel scales
    try:
        # A missing entry becomes NaN, which never compares close and so is
        # reported as a mismatch below.
        ps_saved = _np_float_tuple(meta.get("pixel_scales", (np.nan, np.nan)))
    except (TypeError, ValueError, IndexError, KeyError):
        # Present but unusable (e.g. null, wrong length, non-numeric): report
        # incompatibility instead of crashing.
        return (
            False,
            f"pixel_scales mismatch: malformed saved value {meta.get('pixel_scales')!r}",
        )
    ps_curr = _np_float_tuple(current["pixel_scales"])

    # rtol=0.0 is explicit: np.isclose otherwise adds a default relative
    # tolerance of 1e-5, which would silently loosen the documented `atol`.
    if not (
        np.isclose(ps_saved[0], ps_curr[0], rtol=0.0, atol=atol)
        and np.isclose(ps_saved[1], ps_curr[1], rtol=0.0, atol=atol)
    ):
        return False, f"pixel_scales mismatch: {ps_saved} != {ps_curr}"

    # 4) bbox / rect shape
    if tuple(meta.get("bbox_unmasked", ())) != tuple(current["bbox_unmasked"]):
        return (
            False,
            f"bbox_unmasked mismatch: {meta.get('bbox_unmasked')} != {current['bbox_unmasked']}",
        )

    if tuple(meta.get("rect_shape", ())) != tuple(current["rect_shape"]):
        return (
            False,
            f"rect_shape mismatch: {meta.get('rect_shape')} != {current['rect_shape']}",
        )

    # 5) full mask hash (optional but recommended)
    if require_mask_hash and meta.get("mask_sha256") != current["mask_sha256"]:
        return False, "mask_sha256 mismatch (mask content differs)"

    return True, "ok"
153+
154+
155+
def load_curvature_preload_if_compatible(
    file: Union[str, Path],
    real_space_mask,
    *,
    require_mask_hash: bool = True,
) -> np.ndarray:
    """
    Load a saved curvature_preload if (and only if) it is compatible with the current mask geometry.

    Parameters
    ----------
    file
        Path to a previously saved NPZ. Any other suffix is replaced by
        ".npz" before the file is looked up.
    real_space_mask
        The current real-space mask the stored metadata is checked against.
    require_mask_hash
        If True, require the full mask content hash to match (safest).
        If False, only bbox + shape + pixel scales are checked.

    Returns
    -------
    np.ndarray
        The loaded curvature_preload. This function never returns None; every
        failure is reported by raising.

    Raises
    ------
    FileNotFoundError
        If the (suffix-normalized) file does not exist.
    ValueError
        If the file lacks the required fields, its metadata is not a JSON
        object, or the metadata is incompatible with `real_space_mask`.
    """
    file = Path(file)
    if file.suffix.lower() != ".npz":
        file = file.with_suffix(".npz")

    if not file.exists():
        raise FileNotFoundError(str(file))

    # allow_pickle=False: only plain arrays are read, never pickled objects.
    with np.load(file, allow_pickle=False) as npz:
        if "curvature_preload" not in npz or "meta_json" not in npz:
            msg = f"File does not contain required fields: {file}"
            raise ValueError(msg)

        meta_json = str(npz["meta_json"].item())
        meta = json.loads(meta_json)

        # The compatibility check needs a mapping; reject other JSON values
        # here with a clear error rather than failing inside the check.
        if not isinstance(meta, dict):
            raise ValueError(f"File metadata is not a JSON object: {file}")

        ok, reason = is_preload_metadata_compatible(
            meta=meta,
            real_space_mask=real_space_mask,
            require_mask_hash=require_mask_hash,
            atol=1.0e-8,
        )

        if not ok:
            raise ValueError(f"curvature_preload incompatible: {reason}")

        # Metadata is validated before the preload array is read from the archive.
        return np.asarray(npz["curvature_preload"])
203+
204+
7205
class WTildeInterferometer(AbstractWTilde):
8206
def __init__(
9207
self,
@@ -122,3 +320,49 @@ def rect_index_for_mask_index(self) -> np.ndarray:
122320
)
123321

124322
return rect_indices
323+
324+
def save_curvature_preload(
    self,
    file: Union[str, Path],
    *,
    overwrite: bool = False,
) -> Path:
    """
    Write this object's curvature_preload to a compressed NPZ archive together
    with the mask metadata used to check later whether it can be reused.

    The archive holds two entries:

    - curvature_preload (array)
    - meta_json (JSON string of the metadata, keys sorted)

    Parameters
    ----------
    file
        Destination path. Any suffix other than ".npz" (for example ".npy")
        is replaced by ".npz".
    overwrite
        If False and the destination already exists, raise FileExistsError.

    Returns
    -------
    Path
        The path actually written (always ends with ".npz").

    Raises
    ------
    FileExistsError
        If the destination exists and `overwrite` is False.
    """
    target = Path(file)

    # Metadata is stored next to the array, which requires the NPZ container.
    if target.suffix.lower() != ".npz":
        target = target.with_suffix(".npz")

    already_there = target.exists()
    if already_there and not overwrite:
        raise FileExistsError(f"File already exists: {target}")

    meta_json = json.dumps(
        curvature_preload_metadata_from(self.real_space_mask),
        sort_keys=True,
    )

    np.savez_compressed(
        target,
        curvature_preload=np.asarray(self.curvature_preload),
        meta_json=np.asarray(meta_json),
    )
    return target

0 commit comments

Comments
 (0)