Skip to content

Commit bf900af

Browse files
Jammy2211claude
authored andcommitted
Rewrap point fit data + magnifications for JAX backend parity
Two fixes in AbstractFitPoint so point-source fits work under use_jax=True: 1. __init__ rewraps observed positions (`data`) with the analysis backend when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular from JSON, so without rewrapping the fit's data._xp stays np even when xp=jnp is passed in, and downstream deflection-grid propagation fails. 2. magnifications_at_positions rewraps the raw jax.Array returned by LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular wrapping on the JAX path) so callers can use .array uniformly across backends without hitting AttributeError on jax tracers. Together with PyAutoArray Grid2DIrregular xp propagation, this unblocks full JIT tracing of FitPositionsSource. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 543a3e2 commit bf900af

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

autolens/point/fit/abstract.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(
7474
"""
7575

7676
self.name = name
77+
if xp is not np and data._xp is not xp:
78+
data = aa.Grid2DIrregular(values=data.array, xp=xp)
7779
self._data = data
7880
self._noise_map = noise_map
7981
self.tracer = tracer
@@ -131,9 +133,12 @@ def magnifications_at_positions(self) -> aa.ArrayIrregular:
131133
use_multi_plane=use_multi_plane,
132134
plane_j=plane_j,
133135
)
134-
return abs(
135-
od.magnification_2d_via_hessian_from(grid=self.positions, xp=self._xp)
136+
magnifications = od.magnification_2d_via_hessian_from(
137+
grid=self.positions, xp=self._xp
136138
)
139+
if self.use_jax:
140+
magnifications = aa.ArrayIrregular(values=magnifications)
141+
return abs(magnifications)
137142

138143
@property
139144
def source_plane_coordinate(self) -> Tuple[float, float]:

0 commit comments

Comments
 (0)