Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions scripts/ellipse/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
End-to-end test: jit-cached visualization during a real Nautilus ellipse fit.
=============================================================================

Single-galaxy autogalaxy ellipse port of the autolens
``scripts/imaging/modeling_visualization_jit.py`` end-to-end test.

This test runs in two parts:

Part 1 — **Caching probe.** Uses a parametric single-``Ellipse`` model.
Calls ``analysis.fit_for_visualization(instance)`` twice and asserts the
second call is much faster than the first (confirming the compiled
function is cached on the analysis instance, not recompiled per
visualization).

Part 2 — **Live Nautilus quick-update.** Runs a real (short) Nautilus
fit with the same ellipse model. Asserts that ``fit_ellipse.png`` files
land on disk, proving the JIT-cached fit_for_visualization fires
correctly during the live search callback.

This script deliberately opts in with
``AnalysisEllipse(use_jax=True, use_jax_for_visualization=True)``.
Default ellipse model-fit scripts elsewhere in the workspace leave both
flags at ``False`` and are therefore untouched by this change.
"""

import shutil
import subprocess
import sys
import time
from os import path
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()


"""
__Dataset__

Reuse the ``jax_test`` imaging dataset (auto-simulated on first run).
"""
dataset_path = path.join("dataset", "imaging", "jax_test")

if not path.exists(path.join(dataset_path, "data.fits")):
subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"],
check=True,
)

dataset_unmasked = ag.Imaging.from_fits(
data_path=path.join(dataset_path, "data.fits"),
psf_path=path.join(dataset_path, "psf.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
pixel_scales=0.2,
)

mask_radius = 3.0
mask = ag.Mask2D.circular(
shape_native=dataset_unmasked.shape_native,
pixel_scales=dataset_unmasked.pixel_scales,
radius=mask_radius,
)
dataset = dataset_unmasked.apply_mask(mask=mask)


"""
============================================================================
Part 1 — Caching probe
============================================================================

Model: single parametric ``Ellipse`` with tight priors so the
prior-median instance lands inside the mask.
"""
print("\n" + "=" * 72)
print("Part 1: Ellipse caching probe")
print("=" * 72)

ellipse_mge = af.Model(ag.Ellipse)
ellipse_mge.centre.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
ellipse_mge.centre.centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
ellipse_mge.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.2)
ellipse_mge.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=-0.05, upper_limit=0.1)
ellipse_mge.major_axis = 1.0

model_mge = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_mge))

register_model(model_mge)

analysis_mge = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

instance_mge = model_mge.instance_from_prior_medians()

t0 = time.perf_counter()
fit_1 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_1.log_likelihood)
t1 = time.perf_counter()
compile_time = t1 - t0
print(f"First call (compile + run): {compile_time:.3f}s")
print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}")
assert isinstance(
fit_1.log_likelihood, jnp.ndarray
), f"expected jax.Array, got {type(fit_1.log_likelihood)}"

t0 = time.perf_counter()
fit_2 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_2.log_likelihood)
t1 = time.perf_counter()
cached_time = t1 - t0
print(f"Second call (cached): {cached_time:.3f}s")
print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x")

assert cached_time < compile_time * 0.5, (
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
assert (
analysis_mge._jitted_fit_from is not None
), "expected _jitted_fit_from to be cached on the analysis instance after first call"
print("PASS: Ellipse jit-cached fit_for_visualization works and is reused.")


"""
__Visualization Sanity__

Phase D.2.b.ii — autogalaxy ellipse variant (no Tracer / no lensing).
``FitEllipseSummed`` doesn't expose a top-level ``model_data`` (the
model field lives per-ellipse on the component ``FitEllipse``s), so the
Sanity check is restricted to ``figure_of_merit`` — the aggregate
log-likelihood the search consumes. Catches JAX-trace mismatches that
would leave the cosmetic plot OK but the underlying FoM nan/inf.

Runs on the cached ``fit_2`` from Part 1 so the warm JIT path is
exercised (compile already paid above).
"""

_fom = float(fit_2.figure_of_merit)
assert np.isfinite(_fom), (
f"figure_of_merit = {_fom} — chi² nan/inf, fit collapsed"
)
print(
f" PASS Visualization Sanity (autogalaxy ellipse): "
f"figure_of_merit = {_fom:.4f}"
)


"""
============================================================================
Part 2 — Live Nautilus quick-update with single Ellipse
============================================================================

Same parametric ``Ellipse`` model. The live search fires quick-update
visualization every ``iterations_per_quick_update`` calls; we verify
``fit_ellipse.png`` lands on disk.
"""
print("\n" + "=" * 72)
print("Part 2: Live Nautilus with ellipse + jit-visualization")
print("=" * 72)

ellipse_2 = af.Model(ag.Ellipse)
ellipse_2.centre.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
ellipse_2.centre.centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
ellipse_2.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.2)
ellipse_2.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=-0.05, upper_limit=0.1)
ellipse_2.major_axis = 1.0

model_mge2 = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_2))

register_model(model_mge2)

analysis_mge2 = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
)

output_root = Path("scripts") / "ellipse" / "images" / "modeling_visualization_jit"
if output_root.exists():
shutil.rmtree(output_root)
output_root.mkdir(parents=True)

search = af.Nautilus(
path_prefix=str(output_root),
name="ellipse_jit",
n_live=50,
n_like_max=1500,
iterations_per_quick_update=500,
number_of_cores=1,
)

print("Running Nautilus ...")
result = search.fit(model=model_mge2, analysis=analysis_mge2)

# The Nautilus output goes to output/<path_prefix>/<name>/<hash>/image/.
# The quick-update visualizer writes fit_ellipse.png to that image
# folder during each quick update.
output_search_root = Path("output") / output_root / "ellipse_jit"
produced_pngs = list(output_search_root.rglob("fit_ellipse.png"))
print(f"fit_ellipse.png files produced: {len(produced_pngs)}")
for p in produced_pngs:
print(f" {p}")
assert len(produced_pngs) > 0, (
f"no fit_ellipse.png produced under {output_search_root} — "
"quick-update visualization did not fire"
)

# Note: _jitted_fit_from is built on the worker process Nautilus forks for the
# search loop, not the parent's analysis_mge2 instance — so we don't assert it
# post-search. Part 1 above already verifies the cache is set on the calling
# process.

print(
"\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates "
"for ellipse, fit_ellipse.png written."
)
166 changes: 166 additions & 0 deletions scripts/ellipse/visualization_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Visualization JAX Pilot: Ellipse Analysis (autogalaxy)
======================================================

Tests that ``VisualizerEllipse.visualize`` with
``use_jax_for_visualization=True`` dispatches through the JIT-cached
``fit_for_visualization`` path that the parent ``af.Analysis`` already
provides. ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its
parent, so ``use_jax_for_visualization=True`` flows through to the
PyAutoFit-level dispatch without a library-side change.

Scope
-----
- Single ``af.Model(ag.Ellipse)`` model — no multipoles. Multipoles are
exercised by the non-JAX ``ellipse/visualization.py`` already.
- Calls ``VisualizerEllipse.visualize`` only (not ``visualize_before_fit``).
- Reuses the ``dataset/imaging/jax_test`` dataset that the
``jax_likelihood_functions`` scripts produce.
- ``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True``
routes ``Visualizer*.visualize`` through ``analysis.fit_for_visualization``.
"""

import shutil
import subprocess
import sys
from os import path
from pathlib import Path
from types import SimpleNamespace

from autoconf import conf

conf.instance.push(
new_path=path.join(path.dirname(path.realpath(__file__)), "config"),
output_path=path.join(path.dirname(path.realpath(__file__)), "images"),
)

import autofit as af
import autogalaxy as ag
from autofit.jax.pytrees import enable_pytrees, register_model
from autogalaxy.ellipse.model.visualizer import VisualizerEllipse

enable_pytrees()


"""
__Dataset__

Reuse the ``jax_test`` imaging dataset (auto-simulated on first run).
"""
dataset_path = path.join("dataset", "imaging", "jax_test")

if not path.exists(path.join(dataset_path, "data.fits")):
subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"],
check=True,
)

dataset_unmasked = ag.Imaging.from_fits(
data_path=path.join(dataset_path, "data.fits"),
psf_path=path.join(dataset_path, "psf.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
pixel_scales=0.2,
)

mask = ag.Mask2D.circular(
shape_native=dataset_unmasked.shape_native,
pixel_scales=dataset_unmasked.pixel_scales,
radius=3.0,
)
dataset = dataset_unmasked.apply_mask(mask=mask)


"""
__Model__

A single fixed ``Ellipse`` so ``instance_from_prior_medians()`` produces
a deterministic instance with the major-axis well inside the mask.
"""
ellipse = af.Model(ag.Ellipse)
ellipse.centre.centre_0 = 0.0
ellipse.centre.centre_1 = 0.0
ellipse.ell_comps.ell_comps_0 = 0.1
ellipse.ell_comps.ell_comps_1 = 0.05
ellipse.major_axis = 1.0

model = af.Collection(ellipses=af.Collection(ellipse_0=ellipse))

register_model(model)


"""
__Analysis__

``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True``
tells the visualizer to dispatch through the JIT-cached
``fit_for_visualization`` helper on the parent ``af.Analysis``.
``AnalysisEllipse.__init__`` accepts ``**kwargs`` and forwards them to
``super().__init__``, so no AnalysisEllipse signature change is needed.
"""
analysis = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
use_jax_for_visualization=True,
title_prefix="JAX_PILOT",
)


"""
__Paths__
"""
image_path = Path("scripts") / "ellipse" / "images" / "visualization_jax"
if image_path.exists():
shutil.rmtree(image_path)
image_path.mkdir(parents=True)
output_path = image_path / "output"
output_path.mkdir(parents=True)
paths = SimpleNamespace(image_path=image_path, output_path=output_path)


"""
__Run visualize on the eager-JAX fit__
"""
instance = model.instance_from_prior_medians()

print("Running VisualizerEllipse.visualize with use_jax_for_visualization=True ...")
VisualizerEllipse.visualize(
analysis=analysis,
paths=paths,
instance=instance,
during_analysis=False,
)

# `fit_ellipse.png` is one of the artifacts that the non-JAX
# `ellipse/visualization.py` script asserts on; check the same one here.
assert (image_path / "fit_ellipse.png").exists(), (
"fit_ellipse.png was not produced by the JAX-backed visualizer"
)
print("PILOT SUCCEEDED — JAX-backed ellipse visualization produced fit_ellipse.png.")


"""
__Visualization Sanity__

Phase D.2.b.ii — autogalaxy ellipse variant (no Tracer / no lensing
latents). ``FitEllipseSummed`` is the fit object; ``figure_of_merit`` is
the aggregate log-likelihood the search would use. Asserts FoM is finite
on the prior-median instance, catching JAX-trace mismatches that would
leave the cosmetic plot OK but the underlying log-likelihood nan/inf.

Note: ``FitEllipseSummed`` doesn't expose a single ``model_data`` array
the way ``FitImaging`` does — the model field lives per-ellipse on the
component ``FitEllipse`` objects. The figure_of_merit check covers the
end-to-end fit; per-component model_data assertion would just duplicate
the FoM signal with extra fragility.
"""
import numpy as _sanity_np

_fit_for_vis = analysis.fit_from(instance=instance)
_fom = float(_fit_for_vis.figure_of_merit)
assert _sanity_np.isfinite(_fom), (
f"figure_of_merit = {_fom} — chi² nan/inf, fit collapsed"
)
print(
f" PASS Visualization Sanity (autogalaxy ellipse): "
f"figure_of_merit = {_fom:.4f}"
)
Loading