
feat: jax_likelihood_functions/interferometer/ scripts #16

@Jammy2211

Description


Overview

Port the JAX-likelihood interferometer scripts from autolens_workspace_test into autogalaxy_workspace_test, mirroring the imaging port that landed in #8 (PR #9). The library prerequisite — _register_fit_interferometer_pytrees on AnalysisInterferometer — shipped today as Jammy2211/PyAutoGalaxy#375 / PR #376, so this task is now unblocked.

Each ported script wraps analysis.fit_from in jax.jit and asserts the figure-of-merit matches the NumPy baseline (the three-step pattern from task 3), exercising the new pytree registration end-to-end. This is task 4/9 of the autogalaxy_workspace_test epic (#5).
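The three-step contract can be illustrated without autogalaxy. Below is a minimal sketch that assumes a toy stand-in for `analysis.fit_from`: the hypothetical `fit_from` returns a dict "fit" (dicts are already pytrees) holding a Gaussian log-likelihood of complex visibilities for one amplitude parameter. It is not the autogalaxy API. In the real scripts, the callable is `analysis.fit_from` and the scalar is the fit's `figure_of_merit`.

```python
import jax
import numpy as np

# Enable float64 so the JIT result can be compared with the NumPy baseline at tight
# tolerance (the scripts themselves are run with JAX_ENABLE_X64=True).
jax.config.update("jax_enable_x64", True)

# Hypothetical toy dataset: complex visibilities, a fixed model shape, per-visibility noise.
rng = np.random.default_rng(0)
uv = rng.normal(size=50)
model_shape = np.exp(-(uv**2)) + 1j * uv
noise_sigma = np.full(50, 0.1)
data = 2.0 * model_shape + noise_sigma * (rng.normal(size=50) + 1j * rng.normal(size=50))


def fit_from(amplitude):
    """Stand-in for analysis.fit_from: returns a dict 'fit' with a scalar figure of merit."""
    residual = data - amplitude * model_shape
    chi_squared = ((residual.real / noise_sigma) ** 2 + (residual.imag / noise_sigma) ** 2).sum()
    return {"figure_of_merit": -0.5 * chi_squared}


# Step 1: NumPy baseline (a plain float input keeps every operation in NumPy).
numpy_fit = fit_from(2.0)

# Step 2: the same function traced and compiled by JAX.
jax_fit = jax.jit(fit_from)(2.0)

# Step 3: scalar parity assertion on the figure of merit.
np.testing.assert_allclose(
    float(jax_fit["figure_of_merit"]), float(numpy_fit["figure_of_merit"]), rtol=1e-4
)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
```

The `rtol=1e-4` mirrors the tolerance mentioned in the caveats. With float64 enabled, the two paths normally agree far more tightly than that.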

Plan

  • Add scripts/jax_likelihood_functions/interferometer/ with the standard __init__.py.
  • Port 8 scripts from autolens (simulator, lp, mge, mge_group, rectangular, rectangular_mge, delaunay, delaunay_mge), swapping Tracer → Galaxies and al.AnalysisInterferometer → ag.AnalysisInterferometer.
  • For each ported fit script: NumPy baseline → jax.jit(analysis.fit_from) round-trip → scalar figure_of_merit parity assertion → print PASS.
  • Append entries to smoke_tests.txt under a new # jax_likelihood_functions/interferometer/ section; mirror imaging's commented-out treatment of delaunay_mge.py if JAX 0.7's pytype_aval_mappings removal still breaks it.
  • Skip lens-specific files: simulator_dspl.py, rectangular_dspl.py, rectangular_sparse.py.
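The import and analysis-class renames are mechanical enough for a regex first pass. Below is a minimal sketch that assumes the autolens sources spell these names exactly as listed. `port_source` and `RENAMES` are hypothetical helpers. The `al.Tracer` model construction is not a text substitution and still has to be rewritten by hand into the `af.Collection(galaxies=...)` shape, so the helper only reports those lines.

```python
import re

# (pattern, replacement) pairs applied in order; word boundaries avoid partial matches.
RENAMES = [
    (r"^import autolens as al$", "import autogalaxy as ag"),
    (r"\bal\.AnalysisInterferometer\b", "ag.AnalysisInterferometer"),
]


def port_source(source: str) -> tuple[str, list[str]]:
    """Apply RENAMES and return (ported source, lines that still need a manual port)."""
    for pattern, replacement in RENAMES:
        source = re.sub(pattern, replacement, source, flags=re.MULTILINE)
    # Anything still mentioning `al.` or `Tracer` is autolens-specific and needs hand edits.
    leftovers = [
        line for line in source.splitlines() if re.search(r"\bal\.|\bTracer\b", line)
    ]
    return source, leftovers
```

For example, `port_source("import autolens as al\nanalysis = al.AnalysisInterferometer(dataset=dataset)\ntracer = al.Tracer(galaxies=galaxies)\n")` rewrites the first two lines and returns `["tracer = al.Tracer(galaxies=galaxies)"]` as the manual to-do list.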

Caveats

  • Some reference scripts on autolens CI are red. interferometer/mge.py and interferometer/rectangular.py have been failing on autolens_workspace_test smoke for ≥1 week (recorded in our complete.md — missing sma.fits fixture on CI, or rtol=1e-4 mismatches). The autogalaxy ports may surface the same issues. If they do, raise tolerances or commit the missing fixture; do not silently skip.
  • Adapt-image-bearing variants (rectangular_mge, delaunay, delaunay_mge): the imaging port deferred these initially due to the AdaptImages-across-JIT bug, then re-ported once the fix shipped. That fix is in (AdaptImages.galaxy_path_list + by-instance/by-path lookup in to_inversion.py), so the interferometer ports of these scripts should work first try.
  • Spawn-off awareness. If any profile / dataset class reachable from FitInterferometer isn't pytree-friendly during the port, stop and open a per-class registration issue. Do not paper over with ad-hoc register_pytree_node calls in the workspace script.
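One way to spot a class that isn't pytree-friendly before `jax.jit` fails is to flatten an instance and check whether JAX treats the whole object as a single opaque leaf. Below is a minimal sketch using `jax.tree_util.tree_leaves`. `is_registered_container` and `UnregisteredFit` are hypothetical. The check only diagnoses the problem: per the caveat, the fix is a per-class library registration, not an ad-hoc one in the workspace script.

```python
import jax
import jax.numpy as jnp


def is_registered_container(obj) -> bool:
    """True if JAX flattens `obj` into child leaves; False if the whole object is one leaf.

    Intended for container-like classes (fits, datasets, galaxies); arrays are
    legitimately single leaves and also return False.
    """
    leaves = jax.tree_util.tree_leaves(obj)
    return not (len(leaves) == 1 and leaves[0] is obj)


class UnregisteredFit:
    """Hypothetical class with no pytree registration."""

    def __init__(self, log_likelihood):
        self.log_likelihood = log_likelihood


print(is_registered_container(UnregisteredFit(jnp.array(-1.0))))  # False: opaque leaf
print(is_registered_container({"log_likelihood": jnp.array(-1.0)}))  # True: dicts are pytrees
```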

Detailed implementation plan

Affected Repositories

  • autogalaxy_workspace_test (primary, only repo touched)

Work Classification

Workspace.

Branch Survey

Repository                  | Current Branch | Dirty?
./autogalaxy_workspace_test | main           | clean

Suggested branch: feature/autogalaxy-wst-jax-lh-interferometer
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-lh-interferometer/ (created by /start_workspace)

Implementation Steps

  1. Create scripts/jax_likelihood_functions/interferometer/__init__.py (empty package marker).
  2. Port simulator.py first — the auto-simulate pattern (should_simulate(...)) in fit scripts depends on it. Replace al.Tracer.from_galaxies(...) with the autogalaxy equivalent (ag.Galaxies(...), single-galaxy redshift-0.5 setup matching the imaging port's simulator.py).
  3. Port the 7 fit scripts: lp.py, mge.py, mge_group.py, rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py. For each:
    • Replace import autolens as al → import autogalaxy as ag
    • Replace al.Tracer model construction with the af.Collection(galaxies=af.Collection(...)) shape used in the imaging port.
    • Replace al.AnalysisInterferometer → ag.AnalysisInterferometer
    • Three-step pattern: NumPy fit_from → jax.jit(analysis.fit_from) → scalar parity assertion + print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
  4. Append the ported scripts to smoke_tests.txt under a new section. Mirror imaging's commented-out treatment of delaunay_mge.py if JAX 0.7's pytype_aval_mappings removal still breaks it.
  5. Run each ported script standalone with JAX_ENABLE_X64=True python scripts/jax_likelihood_functions/interferometer/<name>.py to verify the PASS line is printed.
  6. Run /smoke_test autogalaxy_test to confirm the full smoke set stays green.
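Step 5 can be looped over the whole folder with a small driver. Below is a minimal sketch that assumes it is run from the `autogalaxy_workspace_test` root. `run_and_check` is a hypothetical helper. It checks only for a zero exit code and the exact PASS line described in this issue, so it complements `/smoke_test` rather than replacing it.

```python
import os
import subprocess
import sys
from pathlib import Path

PASS_LINE = "PASS: jit(fit_from) round-trip matches NumPy scalar."


def run_and_check(script: Path) -> bool:
    """Run one script with JAX_ENABLE_X64=True; True if it exits 0 and prints PASS_LINE."""
    env = {**os.environ, "JAX_ENABLE_X64": "True"}
    result = subprocess.run(
        [sys.executable, str(script)], env=env, capture_output=True, text=True
    )
    return result.returncode == 0 and PASS_LINE in result.stdout.splitlines()


if __name__ == "__main__":
    folder = Path("scripts/jax_likelihood_functions/interferometer")
    for script in sorted(folder.glob("*.py")):
        # Only the fit scripts print the PASS line; skip the package marker and simulator.
        if script.name in {"__init__.py", "simulator.py"}:
            continue
        status = "PASS" if run_and_check(script) else "FAIL"
        print(f"{script.name}: {status}")
```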

Reference scripts (read-only)

/home/jammy/Code/PyAutoLabs/autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/

  • ✓ Port: simulator.py, lp.py, mge.py, mge_group.py, rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py
  • ✗ Skip (lens-specific): simulator_dspl.py, rectangular_dspl.py, rectangular_sparse.py

Key Files

  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/interferometer/__init__.py — new
  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/interferometer/{simulator,lp,mge,mge_group,rectangular,rectangular_mge,delaunay,delaunay_mge}.py — new (8 files)
  • autogalaxy_workspace_test/smoke_tests.txt — append entries
  • autogalaxy_workspace_test/config/build/env_vars.yaml — only if per-path overrides are needed (imaging port didn't need any)
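A minimal sketch of the `smoke_tests.txt` append follows. It assumes the file lists one script path per line in the `jax_likelihood_functions/interferometer/<name>.py` form and uses `#` for section headers and disabled entries. Check the existing imaging section for the real path convention. `smoke_section` and `append_section` are hypothetical helpers. `delaunay_mge` is commented out by default to mirror the imaging treatment, and should be re-enabled if the JAX 0.7 `pytype_aval_mappings` breakage no longer applies.

```python
from pathlib import Path

# Ordered as listed in this issue, simulator first.
SCRIPTS = [
    "simulator", "lp", "mge", "mge_group", "rectangular",
    "rectangular_mge", "delaunay", "delaunay_mge",
]
SECTION = "jax_likelihood_functions/interferometer/"


def smoke_section(disabled=frozenset({"delaunay_mge"})) -> str:
    """Build the new section; scripts named in `disabled` are written commented out."""
    lines = ["", f"# {SECTION}"]
    for name in SCRIPTS:
        entry = f"{SECTION}{name}.py"
        lines.append(f"# {entry}" if name in disabled else entry)
    return "\n".join(lines) + "\n"


def append_section(path="smoke_tests.txt") -> None:
    """Append the section once; a second call is a no-op."""
    path = Path(path)
    if f"# {SECTION}" in path.read_text().splitlines():
        return
    with path.open("a") as handle:
        handle.write(smoke_section())
```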

Original Prompt


Create scripts/jax_likelihood_functions/interferometer/ in @autogalaxy_workspace_test with
autogalaxy ports of every autolens JAX-likelihood interferometer script, excluding *_dspl.py
and rectangular_sparse.py unless it has an autogalaxy analogue (lens-specific; check first).

Scripts to port

From @autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/:

  • simulator.py
  • lp.py
  • mge.py
  • mge_group.py
  • rectangular.py
  • rectangular_mge.py
  • delaunay.py
  • delaunay_mge.py

Skip: rectangular_dspl.py, simulator_dspl.py, and rectangular_sparse.py (confirm with
user if unsure whether the sparse interferometer path has an autogalaxy equivalent).
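The port/skip rule can be checked against the reference folder listing with `fnmatch`. Below is a minimal sketch. `partition_scripts` is hypothetical. `rectangular_sparse.py` is hard-coded as skipped pending the confirmation asked for above, and the example listing is the file set named in this issue rather than a fresh directory scan.

```python
from fnmatch import fnmatch

# Lens-specific scripts (plus the package marker) that are not ported.
# rectangular_sparse.py stays here until it is confirmed to have no autogalaxy analogue.
SKIP_PATTERNS = ["*_dspl.py", "rectangular_sparse.py", "__init__.py"]


def partition_scripts(filenames):
    """Split reference script names into (port, skip) lists, both sorted."""
    port, skip = [], []
    for name in sorted(filenames):
        target = skip if any(fnmatch(name, pattern) for pattern in SKIP_PATTERNS) else port
        target.append(name)
    return port, skip


listing = [
    "simulator.py", "simulator_dspl.py", "lp.py", "mge.py", "mge_group.py",
    "rectangular.py", "rectangular_dspl.py", "rectangular_sparse.py",
    "rectangular_mge.py", "delaunay.py", "delaunay_mge.py",
]
port, skip = partition_scripts(listing)
print(skip)  # ['rectangular_dspl.py', 'rectangular_sparse.py', 'simulator_dspl.py']
```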

Pytree prerequisite — likely blocker

autogalaxy/interferometer/model/analysis.py has no pytree registration method. Compare the
autolens equivalent in @PyAutoLens/autolens/interferometer/model/analysis.py and mirror it on
autogalaxy — register FitInterferometer, DatasetModel, Galaxies.

If the registration is missing, stop and ship a PyAutoGalaxy library PR first (treat as spawn-off
task via /start_dev). Same policy as task 3: do not paper over with in-script registrations.
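For reference, library-side registration of this kind typically relies on JAX's generic pytree mechanism: `jax.tree_util.register_pytree_node_class` (or `register_pytree_node`), which tells JAX which attributes are traced array leaves and which are static metadata. Below is a minimal sketch on a hypothetical `ToyFitInterferometer`. It is not the PyAutoGalaxy implementation of `_register_fit_interferometer_pytrees`, and per the policy above it belongs in the library, not in a workspace script.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyFitInterferometer:
    """Hypothetical fit object: array leaves plus static (non-traced) metadata."""

    def __init__(self, residuals, noise_sigma, name="toy"):
        self.residuals = residuals
        self.noise_sigma = noise_sigma
        self.name = name  # static metadata, kept out of the traced leaves

    @property
    def figure_of_merit(self):
        # Gaussian log-likelihood of complex residuals: -0.5 * sum |r / sigma|^2.
        return -0.5 * jnp.sum(jnp.abs(self.residuals / self.noise_sigma) ** 2)

    def tree_flatten(self):
        # Children are traced by JAX; the aux data is treated as static.
        return (self.residuals, self.noise_sigma), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, name=aux_data)


@jax.jit
def fit_from(amplitude, data, model, noise_sigma):
    # Returning the custom object from a jitted function works only because it is registered.
    return ToyFitInterferometer(data - amplitude * model, noise_sigma)


data = jnp.array([1.0 + 1.0j, 2.0 - 0.5j])
model = jnp.array([0.5 + 0.5j, 1.0 - 0.25j])
fit = fit_from(1.0, data, model, jnp.array([0.1, 0.1]))
```

Because `name` is returned as auxiliary data, it survives the `jax.jit` round-trip unchanged while `residuals` and `noise_sigma` are traced. Here `fit.figure_of_merit` evaluates to approximately -78.125.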

Three-step JAX pattern

Same contract as task 3 — NumPy baseline, JIT round-trip, scalar log-likelihood match. Print
PASS: jit(fit_from) round-trip matches NumPy scalar.

Deliverables

  1. autogalaxy_workspace_test/scripts/jax_likelihood_functions/interferometer/__init__.py
  2. Ported scripts.
  3. Appended to smoke_tests.txt.
  4. Any required PyAutoGalaxy library PRs merged first.

Depends on

Task 3 completing the PyAutoGalaxy AnalysisImaging._register_fit_imaging_pytrees scaffold —
some of its helpers (e.g. DatasetModel registration) will already be in place and should be
reused rather than duplicated.

Umbrella issue

Task 4/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.

Refs umbrella #5 (task 4/9)

Metadata

  • Assignees: No one assigned
  • Labels: No labels
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests