diff --git a/scripts/ellipse/modeling_visualization_jit.py b/scripts/ellipse/modeling_visualization_jit.py new file mode 100644 index 0000000..58131b0 --- /dev/null +++ b/scripts/ellipse/modeling_visualization_jit.py @@ -0,0 +1,226 @@ +""" +End-to-end test: jit-cached visualization during a real Nautilus ellipse fit. +============================================================================= + +Single-galaxy autogalaxy ellipse port of the autolens +``scripts/imaging/modeling_visualization_jit.py`` end-to-end test. + +This test runs in two parts: + +Part 1 — **Caching probe.** Uses a parametric single-``Ellipse`` model. +Calls ``analysis.fit_for_visualization(instance)`` twice and asserts the +second call is much faster than the first (confirming the compiled +function is cached on the analysis instance, not recompiled per +visualization). + +Part 2 — **Live Nautilus quick-update.** Runs a real (short) Nautilus +fit with the same ellipse model. Asserts that ``fit_ellipse.png`` files +land on disk, proving the JIT-cached fit_for_visualization fires +correctly during the live search callback. + +This script deliberately opts in with +``AnalysisEllipse(use_jax=True, use_jax_for_visualization=True)``. +Default ellipse model-fit scripts elsewhere in the workspace leave both +flags at ``False`` and are therefore untouched by this change. +""" + +import shutil +import subprocess +import sys +import time +from os import path +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ + +Reuse the ``jax_test`` imaging dataset (auto-simulated on first run). +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset_unmasked = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 +mask = ag.Mask2D.circular( + shape_native=dataset_unmasked.shape_native, + pixel_scales=dataset_unmasked.pixel_scales, + radius=mask_radius, +) +dataset = dataset_unmasked.apply_mask(mask=mask) + + +""" +============================================================================ +Part 1 — Caching probe +============================================================================ + +Model: single parametric ``Ellipse`` with tight priors so the +prior-median instance lands inside the mask. +""" +print("\n" + "=" * 72) +print("Part 1: Ellipse caching probe") +print("=" * 72) + +ellipse_mge = af.Model(ag.Ellipse) +ellipse_mge.centre.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +ellipse_mge.centre.centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +ellipse_mge.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.2) +ellipse_mge.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=-0.05, upper_limit=0.1) +ellipse_mge.major_axis = 1.0 + +model_mge = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_mge)) + +register_model(model_mge) + +analysis_mge = ag.AnalysisEllipse( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +instance_mge = model_mge.instance_from_prior_medians() + +t0 = time.perf_counter() +fit_1 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_1.log_likelihood) +t1 = time.perf_counter() +compile_time = t1 - t0 +print(f"First call (compile + run): {compile_time:.3f}s") +print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}") +assert isinstance( + fit_1.log_likelihood, jnp.ndarray +), f"expected jax.Array, got {type(fit_1.log_likelihood)}" + +t0 = time.perf_counter() +fit_2 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_2.log_likelihood) +t1 = time.perf_counter() +cached_time = t1 - t0 +print(f"Second call (cached): {cached_time:.3f}s") +print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x") + +assert cached_time < compile_time * 0.5, ( + f"Cached call ({cached_time:.3f}s) not faster than compile " + f"({compile_time:.3f}s) — JIT cache is not being hit." +) +assert ( + analysis_mge._jitted_fit_from is not None +), "expected _jitted_fit_from to be cached on the analysis instance after first call" +print("PASS: Ellipse jit-cached fit_for_visualization works and is reused.") + + +""" +__Visualization Sanity__ + +Phase D.2.b.ii — autogalaxy ellipse variant (no Tracer / no lensing). +``FitEllipseSummed`` doesn't expose a top-level ``model_data`` (the +model field lives per-ellipse on the component ``FitEllipse``s), so the +Sanity check is restricted to ``figure_of_merit`` — the aggregate +log-likelihood the search consumes. Catches JAX-trace mismatches that +would leave the cosmetic plot OK but the underlying FoM nan/inf. + +Runs on the cached ``fit_2`` from Part 1 so the warm JIT path is +exercised (compile already paid above). +""" + +_fom = float(fit_2.figure_of_merit) +assert np.isfinite(_fom), ( + f"figure_of_merit = {_fom} — chi² nan/inf, fit collapsed" +) +print( + f" PASS Visualization Sanity (autogalaxy ellipse): " + f"figure_of_merit = {_fom:.4f}" +) + + +""" +============================================================================ +Part 2 — Live Nautilus quick-update with single Ellipse +============================================================================ + +Same parametric ``Ellipse`` model. The live search fires quick-update +visualization every ``iterations_per_quick_update`` calls; we verify +``fit_ellipse.png`` lands on disk. +""" +print("\n" + "=" * 72) +print("Part 2: Live Nautilus with ellipse + jit-visualization") +print("=" * 72) + +ellipse_2 = af.Model(ag.Ellipse) +ellipse_2.centre.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +ellipse_2.centre.centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +ellipse_2.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.2) +ellipse_2.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=-0.05, upper_limit=0.1) +ellipse_2.major_axis = 1.0 + +model_mge2 = af.Collection(ellipses=af.Collection(ellipse_0=ellipse_2)) + +register_model(model_mge2) + +analysis_mge2 = ag.AnalysisEllipse( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +output_root = Path("scripts") / "ellipse" / "images" / "modeling_visualization_jit" +if output_root.exists(): + shutil.rmtree(output_root) +output_root.mkdir(parents=True) + +search = af.Nautilus( + path_prefix=str(output_root), + name="ellipse_jit", + n_live=50, + n_like_max=1500, + iterations_per_quick_update=500, + number_of_cores=1, +) + +print("Running Nautilus ...") +result = search.fit(model=model_mge2, analysis=analysis_mge2) + +# The Nautilus output goes to output////image/. +# The quick-update visualizer writes fit_ellipse.png to that image +# folder during each quick update. +output_search_root = Path("output") / output_root / "ellipse_jit" +produced_pngs = list(output_search_root.rglob("fit_ellipse.png")) +print(f"fit_ellipse.png files produced: {len(produced_pngs)}") +for p in produced_pngs: + print(f" {p}") +assert len(produced_pngs) > 0, ( + f"no fit_ellipse.png produced under {output_search_root} — " + "quick-update visualization did not fire" +) + +# Note: _jitted_fit_from is built on the worker process Nautilus forks for the +# search loop, not the parent's analysis_mge2 instance — so we don't assert it +# post-search. Part 1 above already verifies the cache is set on the calling +# process. + +print( + "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " + "for ellipse, fit_ellipse.png written." +) diff --git a/scripts/ellipse/visualization_jax.py b/scripts/ellipse/visualization_jax.py new file mode 100644 index 0000000..e470b39 --- /dev/null +++ b/scripts/ellipse/visualization_jax.py @@ -0,0 +1,166 @@ +""" +Visualization JAX Pilot: Ellipse Analysis (autogalaxy) +====================================================== + +Tests that ``VisualizerEllipse.visualize`` with +``use_jax_for_visualization=True`` dispatches through the JIT-cached +``fit_for_visualization`` path that the parent ``af.Analysis`` already +provides. ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its +parent, so ``use_jax_for_visualization=True`` flows through to the +PyAutoFit-level dispatch without a library-side change. + +Scope +----- +- Single ``af.Model(ag.Ellipse)`` model — no multipoles. Multipoles are + exercised by the non-JAX ``ellipse/visualization.py`` already. +- Calls ``VisualizerEllipse.visualize`` only (not ``visualize_before_fit``). +- Reuses the ``dataset/imaging/jax_test`` dataset that the + ``jax_likelihood_functions`` scripts produce. +- ``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True`` + routes ``Visualizer*.visualize`` through ``analysis.fit_for_visualization``. +""" + +import shutil +import subprocess +import sys +from os import path +from pathlib import Path +from types import SimpleNamespace + +from autoconf import conf + +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + +import autofit as af +import autogalaxy as ag +from autofit.jax.pytrees import enable_pytrees, register_model +from autogalaxy.ellipse.model.visualizer import VisualizerEllipse + +enable_pytrees() + + +""" +__Dataset__ + +Reuse the ``jax_test`` imaging dataset (auto-simulated on first run). +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset_unmasked = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask = ag.Mask2D.circular( + shape_native=dataset_unmasked.shape_native, + pixel_scales=dataset_unmasked.pixel_scales, + radius=3.0, +) +dataset = dataset_unmasked.apply_mask(mask=mask) + + +""" +__Model__ + +A single fixed ``Ellipse`` so ``instance_from_prior_medians()`` produces +a deterministic instance with the major-axis well inside the mask. +""" +ellipse = af.Model(ag.Ellipse) +ellipse.centre.centre_0 = 0.0 +ellipse.centre.centre_1 = 0.0 +ellipse.ell_comps.ell_comps_0 = 0.1 +ellipse.ell_comps.ell_comps_1 = 0.05 +ellipse.major_axis = 1.0 + +model = af.Collection(ellipses=af.Collection(ellipse_0=ellipse)) + +register_model(model) + + +""" +__Analysis__ + +``use_jax=True`` turns on the JAX path; ``use_jax_for_visualization=True`` +tells the visualizer to dispatch through the JIT-cached +``fit_for_visualization`` helper on the parent ``af.Analysis``. +``AnalysisEllipse.__init__`` accepts ``**kwargs`` and forwards them to +``super().__init__``, so no AnalysisEllipse signature change is needed. +""" +analysis = ag.AnalysisEllipse( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, + title_prefix="JAX_PILOT", +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "ellipse" / "images" / "visualization_jax" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Run visualize on the eager-JAX fit__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerEllipse.visualize with use_jax_for_visualization=True ...") +VisualizerEllipse.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, +) + +# `fit_ellipse.png` is one of the artifacts that the non-JAX +# `ellipse/visualization.py` script asserts on; check the same one here. +assert (image_path / "fit_ellipse.png").exists(), ( + "fit_ellipse.png was not produced by the JAX-backed visualizer" +) +print("PILOT SUCCEEDED — JAX-backed ellipse visualization produced fit_ellipse.png.") + + +""" +__Visualization Sanity__ + +Phase D.2.b.ii — autogalaxy ellipse variant (no Tracer / no lensing +latents). ``FitEllipseSummed`` is the fit object; ``figure_of_merit`` is +the aggregate log-likelihood the search would use. Asserts FoM is finite +on the prior-median instance, catching JAX-trace mismatches that would +leave the cosmetic plot OK but the underlying log-likelihood nan/inf. + +Note: ``FitEllipseSummed`` doesn't expose a single ``model_data`` array +the way ``FitImaging`` does — the model field lives per-ellipse on the +component ``FitEllipse`` objects. The figure_of_merit check covers the +end-to-end fit; per-component model_data assertion would just duplicate +the FoM signal with extra fragility. +""" +import numpy as _sanity_np + +_fit_for_vis = analysis.fit_from(instance=instance) +_fom = float(_fit_for_vis.figure_of_merit) +assert _sanity_np.isfinite(_fom), ( + f"figure_of_merit = {_fom} — chi² nan/inf, fit collapsed" +) +print( + f" PASS Visualization Sanity (autogalaxy ellipse): " + f"figure_of_merit = {_fom:.4f}" +)