Overview
Port the JAX-likelihood multi-dataset scripts from autolens_workspace_test to autogalaxy_workspace_test, mirroring the imaging port (#8/PR #9) and the interferometer port shipping today (#16/PR #17). The library prerequisites — AnalysisImaging and AnalysisInterferometer pytree registration — are both in place. This is task 5/9 of the autogalaxy_workspace_test epic (#5).
Note on what "multi" means here. The original prompt described these as "imaging dataset + interferometer dataset". The actual autolens reference scripts use multi-band imaging (g and r filters, both AnalysisImaging) combined via af.FactorGraphModel. The autogalaxy port mirrors the same structure — two ag.AnalysisImaging factors, one per band, with a single galaxy whose bulge.ell_comps is per-band (option B) and everything else shared.
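To make option B concrete, here is a toy sketch in plain NumPy (not the autofit API) of how one flat parameter vector could map onto two per-band instances. Every parameter is shared except the bulge's two ell_comps, which get one pair per band. The parameter names, their ordering and the `instances_from_vector` helper are hypothetical. autofit's real `instance_from_vector` derives the ordering from the model's priors.

```python
import numpy as np

# Hypothetical parameter names; the real galaxy model has more (centres, radii, ...).
SHARED = ["bulge.intensity", "bulge.sersic_index", "disk.intensity"]
PER_BAND = ["bulge.ell_comps_0", "bulge.ell_comps_1"]
BANDS = ["g", "r"]


def parameter_names():
    """Flat vector layout: shared parameters first, then one ell_comps pair per band."""
    names = list(SHARED)
    for band in BANDS:
        names += [f"{band}:{name}" for name in PER_BAND]
    return names


def instances_from_vector(vector):
    """Split a flat vector into one parameter dict per band (option B)."""
    vector = np.asarray(vector, dtype=float)
    n_shared = len(SHARED)
    shared = dict(zip(SHARED, vector[:n_shared]))
    instances = {}
    for i, band in enumerate(BANDS):
        start = n_shared + i * len(PER_BAND)
        per_band = dict(zip(PER_BAND, vector[start:start + len(PER_BAND)]))
        # Shared values are identical in every band; only the ell_comps pair differs.
        instances[band] = {**shared, **per_band}
    return instances


vector = np.arange(len(parameter_names()), dtype=float)
instances = instances_from_vector(vector)
print(instances["g"]["bulge.intensity"], instances["r"]["bulge.intensity"])
print(instances["g"]["bulge.ell_comps_0"], instances["r"]["bulge.ell_comps_0"])
```

With `vector = [0, 1, ..., 6]`, both bands read `bulge.intensity = 0.0`, while `bulge.ell_comps_0` is 3.0 in g and 5.0 in r. Two bands therefore add only four free parameters on top of the shared ones.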
Path A is structurally different from imaging/interferometer. FactorGraphModel exposes no fit_from — it sums each child factor's log_likelihood_function. So Path A jit-wraps a parameter-vector entry point that mirrors what fitness._vmap does internally: instance_from_vector → log_likelihood_function. The PASS line reflects this: "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."
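The Path A flow (parameter vector → instance_from_vector → summed per-factor log_likelihood_function) can be sketched with NumPy stand-ins. `ToyFactor`, `ToyFactorGraph` and `make_fitness` are hypothetical classes that only mimic the behaviour described above. They are not autofit's FactorGraphModel or fitness internals. In the real scripts the vector entry point is the function handed to `jax.jit`. NumPy is used here only to keep the sketch dependency-light.

```python
import numpy as np


def gaussian_log_likelihood(data, noise_map, model_data):
    """-0.5 * (chi^2 + noise normalization) for independent Gaussian pixel noise."""
    chi_squared = np.sum(((data - model_data) / noise_map) ** 2)
    noise_normalization = np.sum(np.log(2.0 * np.pi * noise_map**2))
    return -0.5 * (chi_squared + noise_normalization)


class ToyFactor:
    """Hypothetical stand-in for one per-band AnalysisImaging factor."""

    def __init__(self, data, noise_map, band):
        self.data, self.noise_map, self.band = data, noise_map, band

    def log_likelihood_function(self, instance):
        # Toy "model image": a flat level from the shared plus per-band parameter.
        level = instance["shared"] + instance[self.band]
        model_data = np.full_like(self.data, level)
        return gaussian_log_likelihood(self.data, self.noise_map, model_data)


class ToyFactorGraph:
    """Mimics the behaviour described above: the graph sums its children's likelihoods."""

    def __init__(self, factors):
        self.factors = factors

    def instance_from_vector(self, vector):
        # Hypothetical layout: [shared, g-only, r-only].
        return {"shared": vector[0], "g": vector[1], "r": vector[2]}

    def log_likelihood_function(self, instance):
        return sum(f.log_likelihood_function(instance) for f in self.factors)


def make_fitness(graph):
    """Parameter-vector entry point; the real scripts wrap the equivalent in jax.jit."""

    def fitness(vector):
        return graph.log_likelihood_function(graph.instance_from_vector(vector))

    return fitness
```

Because the graph has no fit_from, parity is checked on this single scalar rather than on a per-dataset fit object.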
Plan
- Add scripts/jax_likelihood_functions/multi/ with the standard __init__.py.
- Port 8 scripts from autolens (simulator, lp, mge, mge_group, rectangular, rectangular_mge, delaunay, delaunay_mge), swapping the lens/source pair for a single autogalaxy galaxy. No *_dspl.py to skip (none exist in the autolens multi/ reference).
- For each ported fit script: NumPy baseline → jax.jit(parameter-vector → factor_graph.log_likelihood_function) round-trip → scalar parity assertion → print PASS.
- Append entries to smoke_tests.txt under a new # jax_likelihood_functions/multi/ section.
- Match the imaging/interferometer ports' tolerance pattern: lp/mge/mge_group at rtol=1e-4, adapt-regularization variants at rtol=1e-2 if NumPy/JAX float-ordering drift surfaces.
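The tolerance pattern can be sketched as a small parity helper. `check_parity` and the `adapt_regularization` flag are hypothetical names, and the plan does not list which scripts count as adapt-regularization variants. The sketch uses `math.isclose`, which scales the tolerance by the larger magnitude. The real scripts may use `np.isclose` or `np.testing.assert_allclose` instead, and those weight the two arguments differently.

```python
import math

RTOL_DEFAULT = 1e-4  # lp, mge, mge_group
RTOL_ADAPT = 1e-2    # adapt-regularization variants, only if float-ordering drift surfaces

PASS_LINE = "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."


def check_parity(log_likelihood_np, log_likelihood_jax, adapt_regularization=False):
    """Assert NumPy/JAX scalar parity, print the PASS line, and return the rtol used."""
    rtol = RTOL_ADAPT if adapt_regularization else RTOL_DEFAULT
    if not math.isclose(float(log_likelihood_jax), float(log_likelihood_np), rel_tol=rtol):
        raise AssertionError(
            f"NumPy {log_likelihood_np!r} vs JAX {log_likelihood_jax!r} differ beyond rtol={rtol}"
        )
    print(PASS_LINE)
    return rtol
```

Keeping the loose tolerance behind an explicit flag means a script only gets rtol=1e-2 once someone has actually observed drift, which matches the "if ... drift surfaces" condition in the plan.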
Caveats
- delaunay_mge.py may or may not trip JAX 0.7's pytype_aval_mappings removal. The imaging counterpart is disabled; the interferometer version works fine. Try it, and mirror whichever neighbour's treatment fits.
- Spawn-off awareness. If any class reachable from factor_graph.log_likelihood_function isn't pytree-friendly, stop and open a per-class library issue. Do NOT add ad-hoc register_pytree_node calls in the workspace script.
- No new fixture dependencies. Mirror the autogalaxy interferometer port: the simulator generates everything inline (Sersic+Exponential galaxy, per-band noise seeds, no external fixtures).
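A rough NumPy sketch of "everything inline" looks like this: a Sersic bulge plus an exponential disk evaluated on a pixel grid, with a per-band intensity for chromatic variation and a per-band noise seed so each band is reproducible. The intensities, radii, grid size and the `simulate_band` helper are hypothetical. The sketch also leaves out the PSF convolution and the FITS/galaxies.json output that the real simulator performs through autogalaxy.

```python
import numpy as np

# Hypothetical per-band settings: distinct intensity (chromatic variation) and noise seed.
BANDS = {"g": {"intensity": 1.0, "noise_seed": 1}, "r": {"intensity": 0.6, "noise_seed": 2}}


def sersic(radius, intensity, effective_radius, sersic_index):
    """Toy circular Sersic profile; b_n uses the simple 2n - 1/3 approximation."""
    b_n = 2.0 * sersic_index - 1.0 / 3.0
    return intensity * np.exp(-b_n * ((radius / effective_radius) ** (1.0 / sersic_index) - 1.0))


def simulate_band(intensity, noise_seed, shape=(21, 21), pixel_scale=0.1, sky_sigma=0.05):
    """Return (data, noise_map) for one band, generated without any external fixtures."""
    y, x = np.indices(shape, dtype=float)
    cy, cx = (shape[0] - 1) / 2.0, (shape[1] - 1) / 2.0
    radius = np.hypot(y - cy, x - cx) * pixel_scale
    bulge = sersic(radius, intensity, effective_radius=0.6, sersic_index=4.0)
    # An exponential disk is a Sersic profile with n = 1.
    disk = sersic(radius, 0.5 * intensity, effective_radius=1.2, sersic_index=1.0)
    rng = np.random.default_rng(noise_seed)
    noise_map = np.full(shape, sky_sigma)
    data = bulge + disk + rng.normal(0.0, sky_sigma, size=shape)
    return data, noise_map


def simulate_all():
    return {band: simulate_band(**settings) for band, settings in BANDS.items()}
```

Fixing the seed per band keeps reruns bit-identical, so the NumPy baseline in each fit script stays stable between runs.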
Detailed implementation plan
Affected Repositories
autogalaxy_workspace_test (primary, only repo touched)
Work Classification
Workspace.
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autogalaxy_workspace_test | main | clean |
Suggested branch: feature/autogalaxy-wst-jax-lh-multi
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-lh-multi/ (created by /start_workspace)
Implementation Steps
- Create scripts/jax_likelihood_functions/multi/__init__.py (empty).
- Port simulator.py first; the auto-simulate pattern in the fit scripts depends on it. Single galaxy (Sersic bulge + Exponential disk), redshift=0.5. For each waveband (g, r): a different intensity to give chromatic variation, and a distinct noise_seed. Output to dataset/multi/jax_test/{g,r}_{data,psf,noise_map}.fits and a shared galaxies.json. PSF is ag.Convolver.from_gaussian.
- Port the 7 fit scripts (lp, mge, mge_group, rectangular, rectangular_mge, delaunay, delaunay_mge). For each:
  - Replace import autolens as al → import autogalaxy as ag
  - Drop the lens model (no mass profile, no shear in autogalaxy)
  - Replace source with a single galaxy model
  - Keep the per-band model.copy() + af.GaussianPrior on galaxy.bulge.ell_comps_{0,1} pattern (option B)
  - Build factor_graph = af.FactorGraphModel(*analysis_factor_list, use_jax=True) with ag.AnalysisImaging(dataset) factors
  - Path A: analysis_np_list (use_jax=False) → NumPy factor_graph_np.log_likelihood_function(instance_np) → analysis_jit_list (use_jax=True) → jax.jit(parameters → instance_from_vector → log_likelihood_function) → scalar parity → print("PASS: jit(log_likelihood_function) round-trip matches NumPy scalar.")
  - Drop the autolens-specific EXPECTED_VMAP_LOG_LIKELIHOOD hardcoded baseline, since it won't match autogalaxy.
- Append the ported scripts to smoke_tests.txt. Mirror the imaging port's commented-out treatment of delaunay_mge.py only if JAX 0.7 actually trips on the multi version (test it; if it works, enable it as the interferometer port does).
- Run each ported script standalone with JAX_ENABLE_X64=True python scripts/jax_likelihood_functions/multi/<name>.py and verify the PASS line.
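The last step can optionally be scripted. Below is a sketch of a hypothetical helper that is not part of the plan or of autogalaxy_workspace_test. It builds the same command as `JAX_ENABLE_X64=True python scripts/jax_likelihood_functions/multi/<name>.py` and checks stdout for the exact PASS line. It assumes it is run from the repository root, and that each fit script auto-simulates its dataset when the dataset is missing, as the plan describes.

```python
import os
import subprocess
import sys
from pathlib import Path

MULTI_DIR = Path("scripts/jax_likelihood_functions/multi")
PASS_LINE = "PASS: jit(log_likelihood_function) round-trip matches NumPy scalar."
FIT_SCRIPTS = ["lp", "mge", "mge_group", "rectangular", "rectangular_mge", "delaunay", "delaunay_mge"]


def build_command(name):
    """Argv and environment equivalent to `JAX_ENABLE_X64=True python .../<name>.py`."""
    env = {**os.environ, "JAX_ENABLE_X64": "True"}
    return [sys.executable, str(MULTI_DIR / f"{name}.py")], env


def passed(stdout):
    """True if any stdout line is exactly the PASS line."""
    return any(line.strip() == PASS_LINE for line in stdout.splitlines())


def run_all(names=FIT_SCRIPTS):
    """Run each fit script standalone and record whether it exited cleanly and printed PASS."""
    results = {}
    for name in names:
        argv, env = build_command(name)
        proc = subprocess.run(argv, env=env, capture_output=True, text=True)
        results[name] = proc.returncode == 0 and passed(proc.stdout)
    return results
```

Matching the whole line, rather than searching for the substring "PASS", avoids false positives from unrelated log output.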
Reference scripts (read-only)
/home/jammy/Code/PyAutoLabs/autolens_workspace_test/scripts/jax_likelihood_functions/multi/
- ✓ Port: simulator.py, lp.py, mge.py, mge_group.py, rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py
- No *_dspl.py files in autolens multi/ to skip.
Key Files
autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/__init__.py — new
autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/{simulator,lp,mge,mge_group,rectangular,rectangular_mge,delaunay,delaunay_mge}.py — new (8 files)
autogalaxy_workspace_test/smoke_tests.txt — append entries
autogalaxy_workspace_test/config/build/env_vars.yaml — only if per-path overrides are needed (existing jax_likelihood_functions/ substring override should cover the new path)
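Below is a sketch of how the new smoke_tests.txt section could be generated, and of why the existing `jax_likelihood_functions/` substring override is expected to cover it. It assumes smoke_tests.txt holds one script path per line relative to scripts/, with `#` marking comments. That format, along with the `smoke_test_section` and `covered_by_override` helpers, is an assumption for illustration, so check it against the file's existing entries.

```python
PORTED = ["simulator", "lp", "mge", "mge_group", "rectangular", "rectangular_mge", "delaunay", "delaunay_mge"]
SECTION_HEADER = "# jax_likelihood_functions/multi/"


def smoke_test_section(delaunay_mge_trips_jax07):
    """Build the section text; comment out delaunay_mge only if JAX 0.7 actually trips on it."""
    lines = [SECTION_HEADER]
    for name in PORTED:
        entry = f"jax_likelihood_functions/multi/{name}.py"  # assumed entry format
        if name == "delaunay_mge" and delaunay_mge_trips_jax07:
            entry = "# " + entry  # mirrors the imaging port's commented-out treatment
        lines.append(entry)
    return "\n".join(lines) + "\n"


def covered_by_override(entry, substring="jax_likelihood_functions/"):
    """A substring-keyed override applies to any entry whose path contains the substring."""
    return substring in entry


print(smoke_test_section(delaunay_mge_trips_jax07=False))
```

Since every new entry contains `jax_likelihood_functions/`, a per-path override in env_vars.yaml should only be needed if the multi/ scripts require settings that differ from the single-dataset ports.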
Original Prompt
Create scripts/jax_likelihood_functions/multi/ in @autogalaxy_workspace_test with autogalaxy
ports of every autolens multi-dataset JAX-likelihood script, excluding *_dspl.py.
Scripts to port
From @autolens_workspace_test/scripts/jax_likelihood_functions/multi/:
simulator.py
lp.py
mge.py
mge_group.py
rectangular.py
rectangular_mge.py
delaunay.py
delaunay_mge.py
Skip: any *_dspl.py.
Context
The multi/ scripts exercise af.FactorGraphModel / multi-dataset joint fits. The autolens
versions combine an imaging dataset and an interferometer dataset with tied lens-galaxy params.
Autogalaxy versions should tie galaxy params across datasets (no lens/source split).
Pytree prerequisite
Both AnalysisImaging and AnalysisInterferometer on autogalaxy need pytree-registered
fit_from. If tasks 3 and 4 have landed first, this should follow naturally. If any multi-dataset
registration is missing (e.g. the factor-graph combining analysis), spawn a library task.
Three-step JAX pattern
Same as tasks 3 and 4.
Deliverables
- autogalaxy_workspace_test/scripts/jax_likelihood_functions/multi/__init__.py
- Ported scripts.
- Appended to smoke_tests.txt.
Depends on
Tasks 3 and 4 (imaging + interferometer pytree registration scaffolding in place).
Umbrella issue
Task 5/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.
Refs umbrella #5 (task 5/9)