Commit bf900af
Rewrap point fit data + magnifications for JAX backend parity
Two fixes in AbstractFitPoint so point-source fits work under use_jax=True:
1. __init__ rewraps observed positions (`data`) with the analysis backend
when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular
from JSON, so without rewrapping the fit's data._xp stays np even when
xp=jnp is passed in, and downstream deflection-grid propagation fails.
2. magnifications_at_positions rewraps the raw jax.Array returned by
LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular
wrapping on the JAX path) so callers can use .array uniformly across
backends without hitting AttributeError on jax tracers.
Together with PyAutoArray Grid2DIrregular xp propagation, this unblocks
full JIT tracing of FitPositionsSource.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>1 parent 543a3e2 commit bf900af
1 file changed
Lines changed: 7 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
74 | 74 | | |
75 | 75 | | |
76 | 76 | | |
| 77 | + | |
| 78 | + | |
77 | 79 | | |
78 | 80 | | |
79 | 81 | | |
| |||
131 | 133 | | |
132 | 134 | | |
133 | 135 | | |
134 | | - | |
135 | | - | |
| 136 | + | |
| 137 | + | |
136 | 138 | | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
137 | 142 | | |
138 | 143 | | |
139 | 144 | | |
| |||
0 commit comments