
feat: autolens interferometer JAX visualization coverage #86

@Jammy2211


Overview

The autolens_workspace_test interferometer dataset has a NumPy visualization.py but no visualization_jax.py or modeling_visualization_jit.py; those exist only for imaging. PyAutoLens's AnalysisInterferometer already dispatches via analysis.fit_for_visualization (visualizer.py:96, 209) and has full pytree registration, so the wiring is in place; this task adds the two missing test scripts. It is Phase 1A of PyAutoPrompt/issued/jax_visualization.md (the JAX visualization roadmap).

Plan

  • Add scripts/interferometer/visualization_jax.py. It mirrors scripts/imaging/visualization_jax.py but uses AnalysisInterferometer and the interferometer simulator, and it must include enable_pytrees() + register_model(model) from the start (lesson from PR #85, "fix: register_model in visualization_jax.py to actually exercise JIT path").
  • Add scripts/interferometer/modeling_visualization_jit.py — mirrors the imaging analogue's two-part shape (caching probe + live Nautilus run with iterations_per_quick_update, asserting the JIT cache fires and subplot_fit.png lands).
  • Reuse the existing simulator under scripts/jax_likelihood_functions/interferometer/.
  • Update config/build/env_vars.yaml to add imaging/visualization_jax-style overrides for the interferometer scripts (unset PYAUTO_DISABLE_JAX etc.).
  • Verify both scripts run locally with JAX enabled — they should print PILOT SUCCEEDED and produce the expected PNGs.
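The caching probe planned for modeling_visualization_jit.py reduces to timing two consecutive fit_for_visualization calls and checking that the compiled function was built only once. The sketch below is a minimal illustration under stated assumptions: StandInAnalysis is a hypothetical class, not PyAutoLens's AnalysisInterferometer, and time.sleep stands in for jax.jit tracing and compilation. Only the attribute name _jitted_fit_from comes from this plan.

```python
import time


class StandInAnalysis:
    """Hypothetical stand-in, NOT PyAutoLens's AnalysisInterferometer.

    Mimics the pattern under test: the compiled fit function is built lazily on
    first use and cached on the instance as _jitted_fit_from.
    """

    def __init__(self, compile_seconds=0.2):
        self.compile_seconds = compile_seconds
        self.compile_count = 0
        self._jitted_fit_from = None

    def _compile(self):
        # Stand-in for jax.jit tracing + XLA compilation: slow, and should run once.
        self.compile_count += 1
        time.sleep(self.compile_seconds)
        return lambda instance: instance  # stand-in for the compiled fit_from

    def fit_for_visualization(self, instance):
        if self._jitted_fit_from is None:
            self._jitted_fit_from = self._compile()
        return self._jitted_fit_from(instance)


def caching_probe(analysis, instance):
    """Time two consecutive fit_for_visualization calls; return (first, second) seconds."""
    timings = []
    for _ in range(2):
        start = time.perf_counter()
        analysis.fit_for_visualization(instance)
        timings.append(time.perf_counter() - start)
    return timings[0], timings[1]


if __name__ == "__main__":
    analysis = StandInAnalysis()
    first, second = caching_probe(analysis, instance=3.0)
    print(f"first call {first:.3f}s, second call {second:.6f}s")
    # The cache fired if compilation ran exactly once and the repeat call was much faster.
    assert analysis.compile_count == 1
    assert second < first / 10, "second call was not significantly faster"
    print("jit cache probe passed")
```

With real JAX there is no compile counter to inspect, so the timing ratio carries the assertion on its own. Because JAX dispatches work asynchronously, the result should be forced with block_until_ready() before each timer stops; otherwise both timings can look artificially small.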
Detailed implementation plan

Affected Repositories

  • autolens_workspace_test (primary, only repo)

Work Classification

Workspace

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_workspace_test | main | README.md (unrelated automated version bump; not in scope for this task) |

Suggested branch: feature/autolens-interferometer-jax-viz
Worktree root: ~/Code/PyAutoLabs-wt/autolens-interferometer-jax-viz/ (created later by /start_workspace)

Implementation Steps

  1. scripts/interferometer/visualization_jax.py: mirror the structure of scripts/imaging/visualization_jax.py (as updated by PR #85), with these differences:

    • Use al.AnalysisInterferometer instead of al.AnalysisImaging.
    • Use the interferometer dataset path: dataset/interferometer/jax_test.
    • Auto-simulate via subprocess.run([sys.executable, "scripts/jax_likelihood_functions/interferometer/simulator.py"], check=True) if missing.
    • Build the model with the parametric MGE pattern from scripts/jax_likelihood_functions/interferometer/mge.py (lens + source MGE).
    • Critical (lesson from PR #85): import enable_pytrees and register_model from autofit.jax.pytrees, call enable_pytrees() at module level, and call register_model(model) after building the model. Without these, jax.jit(fit_from) cannot trace the ModelInstance.
    • Critical: no try/except wrapper. Call VisualizerInterferometer.visualize directly so any failure surfaces loudly, then assert that subplot_fit.png (or fit.png, whichever the interferometer plotter produces; confirm against the imaging analogue's pattern) was written.
    • Reuse config_source/visualize/plots.yaml from the existing interferometer/visualization.py so the visualization output is bounded.
  2. scripts/interferometer/modeling_visualization_jit.py — mirror scripts/imaging/modeling_visualization_jit.py:

    • Part 1: caching probe — call analysis.fit_for_visualization(instance) twice and assert the second call is significantly faster than the first (_jitted_fit_from is cached on the analysis instance).
    • Part 2: live Nautilus run with iterations_per_quick_update=500, n_like_max=1500, n_live=50, asserting that fit.png files land under the output search root.
    • Part 2 uses the same enable_pytrees() + register_model(model) setup as Part 1.
    • Use MGE linear light profiles (matches the imaging analogue's Part 2) so linear_light_profile_intensity_dict is exercised on the interferometer side.
  3. config/build/env_vars.yaml — add overrides analogous to the autolens imaging entries:

    • pattern: "interferometer/visualization_jax" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_FAST_PLOTS (mirrors the new imaging/visualization_jax entry from PR #85).
    • pattern: "interferometer/modeling_visualization_jit" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS (mirrors the existing imaging/modeling_visualization_jit entry).
  4. Verification — run both with JAX enabled inside the worktree:

    cd $WT_ROOT/autolens_workspace_test
    JAX_ENABLE_X64=True python scripts/interferometer/visualization_jax.py
    JAX_ENABLE_X64=True python scripts/interferometer/modeling_visualization_jit.py

    Both must print PILOT SUCCEEDED (or the jit-cache pass message) and produce the expected PNGs.
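The auto-simulate step in item 1 amounts to a small guard that runs the simulator with the current interpreter only when the dataset is absent. This is a sketch, not the workspace's code: ensure_dataset is a hypothetical helper, and the marker filename data.fits is an assumption that should be matched to whatever the interferometer simulator actually writes into dataset/interferometer/jax_test.

```python
import subprocess
import sys
from pathlib import Path


def ensure_dataset(dataset_path, simulator_script, marker_file="data.fits"):
    """Run simulator_script with the current interpreter if the dataset is missing.

    marker_file is an assumed filename; match it to what the simulator writes.
    Returns True if the simulator ran, False if the dataset already existed.
    """
    marker = Path(dataset_path) / marker_file
    if marker.exists():
        return False
    # sys.executable reuses the current interpreter and its installed packages;
    # check=True raises CalledProcessError on a non-zero exit, so a broken
    # simulator fails loudly instead of being silently skipped.
    subprocess.run([sys.executable, str(simulator_script)], check=True)
    if not marker.exists():
        raise FileNotFoundError(f"simulator finished but {marker} was not written")
    return True


# Usage in visualization_jax.py, run from the workspace root (marker file assumed):
# ensure_dataset(
#     "dataset/interferometer/jax_test",
#     "scripts/jax_likelihood_functions/interferometer/simulator.py",
# )
```

The post-run existence check is deliberate: a simulator that exits cleanly but writes to an unexpected path would otherwise surface later as a confusing load error inside the visualization script.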

Known risk

autolens_workspace_test interferometer scripts have historically been red on CI because of a gitignored sma.fits fixture (complete.md L1080). The simulator at scripts/jax_likelihood_functions/interferometer/simulator.py may depend on this fixture. If running the simulator from the worktree fails on a missing fixture, mirror the fallback from the autogalaxy port (complete.md L970), which wrote a self-contained simulator using np.random.default_rng(seed=1) for 200 synthetic baselines.
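For a sense of what a fixture-free fallback involves, here is a minimal sketch that draws 200 random baselines from np.random.default_rng(seed=1) and computes visibilities of a Gaussian source by direct Fourier transform. It is an illustration under assumptions, not the autogalaxy port's simulator: the uniform uv coverage, grid size, pixel scale, source width, and noise level are all hypothetical choices, and it returns plain NumPy arrays rather than building an autolens dataset or writing FITS files.

```python
import numpy as np

# Hypothetical constant for this sketch; not taken from the autogalaxy port.
ARCSEC_TO_RAD = np.pi / (180.0 * 3600.0)


def make_gaussian_image(grid_size, pixel_scale, sigma_arcsec):
    """Circular Gaussian surface brightness on a square grid centred on (0, 0) arcsec."""
    coords = (np.arange(grid_size) - (grid_size - 1) / 2.0) * pixel_scale
    y, x = np.meshgrid(coords, coords, indexing="ij")
    image = np.exp(-0.5 * (x**2 + y**2) / sigma_arcsec**2)
    return image, x, y


def simulate_visibilities(
    n_baselines=200,
    seed=1,
    grid_size=32,
    pixel_scale=0.1,
    sigma_arcsec=0.3,
    uv_max=1.0e5,
    noise_sigma=0.01,
):
    """Return (uv_wavelengths, visibilities) for a synthetic, fixture-free dataset."""
    rng = np.random.default_rng(seed=seed)
    # Random baselines, uniform in a square of half-width uv_max (in wavelengths).
    uv = rng.uniform(-uv_max, uv_max, size=(n_baselines, 2))

    image, x, y = make_gaussian_image(grid_size, pixel_scale, sigma_arcsec)
    x_rad = x.ravel() * ARCSEC_TO_RAD
    y_rad = y.ravel() * ARCSEC_TO_RAD

    # Direct Fourier transform: V(u, v) = sum_j I_j exp(-2 pi i (u x_j + v y_j)).
    phase = -2.0 * np.pi * (np.outer(uv[:, 0], x_rad) + np.outer(uv[:, 1], y_rad))
    visibilities = (image.ravel()[np.newaxis, :] * np.exp(1j * phase)).sum(axis=1)

    # Complex Gaussian noise, drawn from the same seeded generator for reproducibility.
    noise = noise_sigma * (
        rng.standard_normal(n_baselines) + 1j * rng.standard_normal(n_baselines)
    )
    return uv, visibilities + noise


if __name__ == "__main__":
    uv, vis = simulate_visibilities()
    print(uv.shape, vis.shape, vis.dtype)
```

Seeding a single Generator and drawing both baselines and noise from it keeps reruns bit-for-bit identical, which is what lets CI compare PNGs and likelihoods across runs without a committed fixture.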

Key Files

  • scripts/interferometer/visualization_jax.py (NEW)
  • scripts/interferometer/modeling_visualization_jit.py (NEW)
  • config/build/env_vars.yaml (EDIT — add 2 override entries)
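The two env_vars.yaml overrides from step 3 pair a script-path pattern with a list of variables to unset. The sketch below shows one plausible way such entries could be applied when launching a script; it assumes plain substring matching on the path, and env_for_script and the in-memory OVERRIDES list are hypothetical, so the real file's schema and the build runner's matching rules should be checked before relying on it.

```python
# Hypothetical in-memory form of the two new env_vars.yaml entries; the real
# file's schema and matching rules may differ.
OVERRIDES = [
    {
        "pattern": "interferometer/visualization_jax",
        "unset": ["PYAUTO_DISABLE_JAX", "PYAUTO_SMALL_DATASETS", "PYAUTO_FAST_PLOTS"],
    },
    {
        "pattern": "interferometer/modeling_visualization_jit",
        "unset": [
            "PYAUTO_DISABLE_JAX",
            "PYAUTO_SMALL_DATASETS",
            "PYAUTO_TEST_MODE",
            "PYAUTO_FAST_PLOTS",
        ],
    },
]


def env_for_script(script_path, base_env, overrides=OVERRIDES):
    """Return a copy of base_env with every matching override's variables removed.

    Assumes a plain substring match of each pattern against the script path.
    """
    env = dict(base_env)  # never mutate the caller's environment
    normalized = str(script_path).replace("\\", "/")  # tolerate Windows separators
    for override in overrides:
        if override["pattern"] in normalized:
            for name in override["unset"]:
                env.pop(name, None)
    return env
```

A runner would then launch each script with something like subprocess.run([sys.executable, path], env=env_for_script(path, os.environ), check=True). Note that the existing NumPy scripts/interferometer/visualization.py does not contain either pattern, so it keeps the default CI variables.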

Reference patterns

  • scripts/imaging/visualization_jax.py: the post-PR #85 pattern with enable_pytrees + register_model
  • scripts/imaging/modeling_visualization_jit.py — caching-probe + live-Nautilus pattern
  • scripts/jax_likelihood_functions/interferometer/mge.py — interferometer MGE model setup
  • PyAutoLens autolens/interferometer/model/visualizer.py:96, 209 — already dispatches via fit_for_visualization
  • complete.md L970 — autogalaxy port self-contained simulator pattern (fallback if sma.fits is missing)

Original Prompt


(elided — full text in PyAutoPrompt/issued/jax_viz_interferometer_coverage.md)
