feat: jax_likelihood_functions/multi/ port #18

@Jammy2211

Description

Overview

Port the JAX-likelihood multi-dataset scripts from autolens_workspace_test to autogalaxy_workspace_test, mirroring the imaging port (#8/PR #9) and the interferometer port shipping today (#16/PR #17). The library prerequisites — AnalysisImaging and AnalysisInterferometer pytree registration — are both in place. This is task 5/9 of the autogalaxy_workspace_test epic (#5).

Note on what "multi" means here. The original prompt described these as "imaging dataset + interferometer dataset". The actual autolens reference scripts use multi-band imaging (g and r filters, both AnalysisImaging) combined via af.FactorGraphModel. The autogalaxy port mirrors the same structure — two ag.AnalysisImaging factors, one per band, with a single galaxy whose bulge.ell_comps is per-band (option B) and everything else shared.
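The option-B layout described above can be illustrated with a minimal Python sketch. It does not use PyAutoFit: the helper names and every dictionary key other than bulge.ell_comps_0 and bulge.ell_comps_1 are hypothetical, chosen only to show which values are shared across bands and which are per band.

```python
from typing import Dict, Tuple

BANDS = ("g", "r")


def per_band_parameters(
    shared: Dict[str, float],
    ell_comps_by_band: Dict[str, Tuple[float, float]],
) -> Dict[str, Dict[str, float]]:
    """Return one full parameter dict per band: shared values plus that band's ell_comps."""
    models = {}
    for band in BANDS:
        params = dict(shared)  # copy so the bands do not alias one another
        ell_0, ell_1 = ell_comps_by_band[band]
        params["bulge.ell_comps_0"] = ell_0
        params["bulge.ell_comps_1"] = ell_1
        models[band] = params
    return models


def free_parameter_count(n_shared: int, n_bands: int = len(BANDS)) -> int:
    """Shared parameters are counted once; each band adds its own two ell_comps."""
    return n_shared + 2 * n_bands


# Hypothetical values, for illustration only.
models = per_band_parameters(
    shared={"bulge.effective_radius": 0.6},
    ell_comps_by_band={"g": (0.10, 0.00), "r": (0.05, 0.02)},
)
```

With five shared parameters and two bands, free_parameter_count(5) gives nine free parameters in this toy accounting; the real count depends on the profiles used in each script.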

Path A is structurally different from imaging/interferometer. FactorGraphModel exposes no fit_from; it sums each child factor's log_likelihood_function instead. So Path A jit-wraps a parameter-vector entry point that mirrors what fitness._vmap does internally: instance_from_vector → log_likelihood_function. The PASS line reflects this: "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."
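The shape of that entry point can be sketched without PyAutoFit or JAX. In the toy below, ToyFactor, ToyFactorGraph and make_entry_point are hypothetical stand-ins, not library classes; only the composition (vector → instance → summed log likelihood) is taken from the description above. In the real scripts the returned function is what would be passed to jax.jit.

```python
import math
from typing import Dict, List, Sequence


class ToyFactor:
    """Hypothetical stand-in for one per-band analysis factor (not ag.AnalysisImaging)."""

    def __init__(self, data: Sequence[float], sigma: float):
        self.data = list(data)
        self.sigma = sigma

    def log_likelihood_function(self, instance: Dict[str, float]) -> float:
        # Gaussian log likelihood of a constant-intensity model; a toy, not a real fit.
        chi_squared = sum(((d - instance["intensity"]) / self.sigma) ** 2 for d in self.data)
        noise_norm = len(self.data) * math.log(2.0 * math.pi * self.sigma ** 2)
        return -0.5 * (chi_squared + noise_norm)


class ToyFactorGraph:
    """Hypothetical combined model: no fit_from, only a summed likelihood."""

    def __init__(self, factors: Sequence[ToyFactor]):
        self.factors = list(factors)

    def instance_from_vector(self, vector: Sequence[float]) -> List[Dict[str, float]]:
        # One free parameter per factor in this toy; real models have many more.
        return [{"intensity": value} for value in vector]

    def log_likelihood_function(self, instances: Sequence[Dict[str, float]]) -> float:
        # Sum each child factor's log likelihood.
        return sum(f.log_likelihood_function(i) for f, i in zip(self.factors, instances))


def make_entry_point(graph: ToyFactorGraph):
    """Return the parameter-vector -> scalar function that Path A would hand to jax.jit."""

    def log_likelihood_from_vector(vector: Sequence[float]) -> float:
        instances = graph.instance_from_vector(vector)
        return graph.log_likelihood_function(instances)

    return log_likelihood_from_vector


g_factor = ToyFactor(data=[1.0, 1.2], sigma=0.1)
r_factor = ToyFactor(data=[2.0, 2.1], sigma=0.2)
entry_point = make_entry_point(ToyFactorGraph([g_factor, r_factor]))
total = entry_point([1.1, 2.05])
```

Keeping the entry point a pure function of the parameter vector is what makes it a natural candidate for jit-wrapping: everything else is closed over once, outside the traced call.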

Plan

  • Add scripts/jax_likelihood_functions/multi/ with the standard __init__.py.
  • Port 8 scripts from autolens (simulator, lp, mge, mge_group, rectangular, rectangular_mge, delaunay, delaunay_mge), swapping the lens/source pair for a single autogalaxy galaxy. No *_dspl.py to skip (none exist in the autolens multi/ reference).
  • For each ported fit script: NumPy baseline → jax.jit(parameter-vector → factor_graph.log_likelihood_function) round-trip → scalar parity assertion → print PASS.
  • Append entries to smoke_tests.txt under a new # jax_likelihood_functions/multi/ section.
  • Match the imaging/interferometer ports' tolerance pattern: lp/mge/mge_group at rtol=1e-4, adapt-regularization variants at rtol=1e-2 if NumPy/JAX float-ordering drift surfaces.
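The tolerance rule and the parity assertion in the plan can be sketched as follows. This is a stdlib-only illustration: choose_rtol and assert_scalar_parity are hypothetical helper names, and the ported scripts may well use a NumPy assertion instead of math.isclose.

```python
import math

TIGHT_RTOL = 1e-4  # lp / mge / mge_group, per the plan
LOOSE_RTOL = 1e-2  # adapt-regularization variants, only if drift is observed
PASS_LINE = "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."


def choose_rtol(uses_adapt_regularization: bool, drift_observed: bool) -> float:
    """Loosen the tolerance only when both conditions named in the plan hold."""
    if uses_adapt_regularization and drift_observed:
        return LOOSE_RTOL
    return TIGHT_RTOL


def assert_scalar_parity(numpy_value: float, jit_value: float, rtol: float) -> None:
    """Raise if the two scalars disagree beyond rtol, otherwise print the PASS line."""
    # math.isclose measures the difference relative to the larger magnitude,
    # which is close to, but not identical to, numpy.testing.assert_allclose.
    if not math.isclose(jit_value, numpy_value, rel_tol=rtol):
        raise AssertionError(
            f"log likelihood mismatch: numpy={numpy_value!r}, jit={jit_value!r}, rtol={rtol}"
        )
    print(PASS_LINE)
```

Starting every script at the tight tolerance and loosening only on observed drift keeps a genuine regression from hiding behind a tolerance that was relaxed pre-emptively.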

Caveats

  • delaunay_mge.py may or may not trip JAX 0.7's pytype_aval_mappings removal. The imaging counterpart is disabled; the interferometer version works fine. Try it; mirror whichever neighbour's treatment fits.
  • Spawn-off awareness. If any class reachable from factor_graph.log_likelihood_function isn't pytree-friendly, stop and open a per-class library issue. Do NOT add ad-hoc register_pytree_node calls in the workspace script.
  • No new fixture dependencies. Mirror the autogalaxy interferometer port: the simulator generates everything inline (Sersic+Exponential galaxy, per-band noise seeds, no external fixtures).

Detailed implementation plan

Affected Repositories

  • autogalaxy_workspace_test (primary, only repo touched)

Work Classification

Workspace.

Branch Survey

Repository                     Current Branch   Dirty?
./autogalaxy_workspace_test    main             clean

Suggested branch: feature/autogalaxy-wst-jax-lh-multi
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-lh-multi/ (created by /start_workspace)

Implementation Steps

  1. Create scripts/jax_likelihood_functions/multi/__init__.py (empty).
  2. Port simulator.py first — the auto-simulate pattern in fit scripts depends on it. Single galaxy (Sersic bulge + Exponential disk), redshift=0.5. For each waveband (g, r): different intensity to give chromatic variation, distinct noise_seed. Output to dataset/multi/jax_test/{g,r}_{data,psf,noise_map}.fits and a shared galaxies.json. PSF is ag.Convolver.from_gaussian.
  3. Port the 7 fit scripts (lp, mge, mge_group, rectangular, rectangular_mge, delaunay, delaunay_mge). For each:
    • Replace import autolens as al → import autogalaxy as ag
    • Drop the lens model (no mass profile, no shear in autogalaxy)
    • Replace source with a single galaxy model
    • Keep the per-band model.copy() + af.GaussianPrior on galaxy.bulge.ell_comps_{0,1} pattern (option B)
    • Build factor_graph = af.FactorGraphModel(*analysis_factor_list, use_jax=True) with ag.AnalysisImaging(dataset) factors
    • Path A: analysis_np_list (use_jax=False) → NumPy factor_graph_np.log_likelihood_function(instance_np) → analysis_jit_list (use_jax=True) → jax.jit(parameters → instance_from_vector → log_likelihood_function) → scalar parity → print("PASS: jit(log_likelihood_function) round-trip matches NumPy scalar.")
    • Drop the autolens-specific EXPECTED_VMAP_LOG_LIKELIHOOD hardcoded baseline — that won't match autogalaxy.
  4. Append the ported scripts to smoke_tests.txt. Mirror the imaging port's commented-out treatment of delaunay_mge.py only if JAX 0.7 actually trips on the multi version (test it; if it works, enable it as in the interferometer port).
  5. Run each ported script standalone with JAX_ENABLE_X64=True python scripts/jax_likelihood_functions/multi/<name>.py and verify the PASS line.
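Step 5 can be scripted rather than run by hand. The sketch below is one possible runner, not part of the workspace: x64_env, has_pass_line and run_script are hypothetical helpers, and it assumes it is executed from the repository root so that the relative script path resolves.

```python
import os
import subprocess
import sys
from pathlib import Path
from typing import Dict, Optional

PASS_LINE = "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."
SCRIPT_DIR = Path("scripts/jax_likelihood_functions/multi")
FIT_SCRIPTS = [
    "lp", "mge", "mge_group",
    "rectangular", "rectangular_mge",
    "delaunay", "delaunay_mge",
]


def x64_env(base: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy the environment and switch on 64-bit JAX, as in the command in step 5."""
    env = dict(os.environ if base is None else base)
    env["JAX_ENABLE_X64"] = "True"
    return env


def has_pass_line(stdout: str) -> bool:
    """True if any output line is exactly the expected PASS line."""
    return any(line.strip() == PASS_LINE for line in stdout.splitlines())


def run_script(name: str) -> bool:
    """Run one ported script standalone and report whether it exited cleanly and passed."""
    result = subprocess.run(
        [sys.executable, str(SCRIPT_DIR / f"{name}.py")],
        env=x64_env(),
        capture_output=True,
        text=True,
    )
    return result.returncode == 0 and has_pass_line(result.stdout)
```

A loop such as `failed = [n for n in FIT_SCRIPTS if not run_script(n)]` would then list the scripts that still need attention; simulator.py is left out of FIT_SCRIPTS because it prints no PASS line.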

Reference scripts (read-only)

/home/jammy/Code/PyAutoLabs/autolens_workspace_test/scripts/jax_likelihood_functions/multi/

  • ✓ Port: simulator.py, lp.py, mge.py, mge_group.py, rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py
  • No *_dspl.py files in autolens multi/ to skip.

Key Files

  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/__init__.py — new
  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/{simulator,lp,mge,mge_group,rectangular,rectangular_mge,delaunay,delaunay_mge}.py — new (8 files)
  • autogalaxy_workspace_test/smoke_tests.txt — append entries
  • autogalaxy_workspace_test/config/build/env_vars.yaml — only if per-path overrides are needed (existing jax_likelihood_functions/ substring override should cover the new path)
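The smoke_tests.txt append can be sketched as a small text generator. The entry format is an assumption here (one script path per line, relative to scripts/, with "#" marking both the section header and a disabled entry); check the existing file and adjust before using anything like this. The helper name smoke_test_section is hypothetical.

```python
from typing import Iterable

SECTION_HEADER = "# jax_likelihood_functions/multi/"
SCRIPTS = [
    "simulator", "lp", "mge", "mge_group",
    "rectangular", "rectangular_mge",
    "delaunay", "delaunay_mge",
]


def smoke_test_section(disabled: Iterable[str] = ()) -> str:
    """Build the text block to append, commenting out any disabled scripts."""
    disabled_set = set(disabled)
    lines = [SECTION_HEADER]
    for name in SCRIPTS:
        entry = f"jax_likelihood_functions/multi/{name}.py"  # assumed path format
        lines.append(f"# {entry}" if name in disabled_set else entry)
    return "\n".join(lines) + "\n"


# Example: the variant to use only if delaunay_mge.py does trip on JAX 0.7.
section = smoke_test_section(disabled=["delaunay_mge"])
```

Calling smoke_test_section() with no arguments gives the fully enabled variant, which matches the interferometer port's treatment.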

Original Prompt

Create scripts/jax_likelihood_functions/multi/ in @autogalaxy_workspace_test with autogalaxy
ports of every autolens multi-dataset JAX-likelihood script, excluding *_dspl.py.

Scripts to port

From @autolens_workspace_test/scripts/jax_likelihood_functions/multi/:

  • simulator.py
  • lp.py
  • mge.py
  • mge_group.py
  • rectangular.py
  • rectangular_mge.py
  • delaunay.py
  • delaunay_mge.py

Skip: any *_dspl.py.

Context

The multi/ scripts exercise af.FactorGraphModel / multi-dataset joint fits. The autolens
versions combine an imaging dataset and an interferometer dataset with tied lens-galaxy params.
Autogalaxy versions should tie galaxy params across datasets (no lens/source split).

Pytree prerequisite

Both AnalysisImaging and AnalysisInterferometer on autogalaxy need pytree-registered
fit_from. If tasks 3 and 4 have landed first, this should follow naturally. If any multi-dataset
registration is missing (e.g. the factor-graph combining analysis), spawn a library task.

Three-step JAX pattern

Same as tasks 3 and 4.

Deliverables

  1. autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/__init__.py
  2. Ported scripts.
  3. Appended to smoke_tests.txt.

Depends on

Tasks 3 and 4 (imaging + interferometer pytree registration scaffolding in place).

Umbrella issue

Task 5/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.

Refs umbrella #5 (task 5/9)
