From 798af343807c3c1326801909a16a7cb7aff3cec6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 26 Apr 2026 13:07:59 +0100 Subject: [PATCH 1/4] feat: add scripts/imaging/ port from autolens_workspace_test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds four imaging integration test scripts mirroring the autolens versions but targeting the single-galaxy autogalaxy API: - model_fit.py — end-to-end Nautilus fit on a parametric Sersic galaxy - visualization.py — VisualizerImaging assertions for parametric, rectangular and Delaunay variants (dataset / adapt_images / fit / galaxies / inversion) - visualization_jax.py — eager-JAX VisualizerImaging.visualize on an MGE galaxy - modeling_visualization_jit.py — JIT caching probe + live Nautilus quick-update with linear-MGE light profiles Reuses the dataset/imaging/jax_test/ dataset already produced by scripts/jax_likelihood_functions/imaging/simulator.py (auto-simulated on first run). Adds two minimal config overrides for the per-galaxy visualization runs. smoke_tests.txt: appends the two reliably-fast scripts (model_fit.py and visualization.py). The JAX scripts are intentionally left out of the curated smoke list until they pass on CI. Closes task 9/9 of the autogalaxy_workspace_test coverage epic (PyAutoLabs/autogalaxy_workspace_test#5). Depends on PyAutoLabs/PyAutoGalaxy#367 for use_jax_for_visualization kwarg threading through ag.AnalysisImaging. --- scripts/imaging/__init__.py | 0 scripts/imaging/config/visualize/plots.yaml | 31 ++ .../config_source/visualize/plots.yaml | 31 ++ scripts/imaging/model_fit.py | 110 ++++++ scripts/imaging/modeling_visualization_jit.py | 231 ++++++++++++ scripts/imaging/visualization.py | 339 ++++++++++++++++++ scripts/imaging/visualization_jax.py | 136 +++++++ smoke_tests.txt | 2 + 8 files changed, 880 insertions(+) create mode 100644 scripts/imaging/__init__.py create mode 100644 scripts/imaging/config/visualize/plots.yaml create mode 100644 scripts/imaging/config_source/visualize/plots.yaml create mode 100644 scripts/imaging/model_fit.py create mode 100644 scripts/imaging/modeling_visualization_jit.py create mode 100644 scripts/imaging/visualization.py create mode 100644 scripts/imaging/visualization_jax.py diff --git a/scripts/imaging/__init__.py b/scripts/imaging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/imaging/config/visualize/plots.yaml b/scripts/imaging/config/visualize/plots.yaml new file mode 100644 index 0000000..64d4d21 --- /dev/null +++ b/scripts/imaging/config/visualize/plots.yaml @@ -0,0 +1,31 @@ +subplot_format: [png] +fits_are_zoomed: false + +dataset: + subplot_dataset: true + fits_dataset: true + +fit: + subplot_fit: true + subplot_fit_log10: false + subplot_of_galaxies: false + subplot_galaxy_images: false + fits_fit: false + fits_galaxy_images: false + fits_model_galaxy_images: false + +fit_imaging: {} + +galaxies: + subplot_galaxies: false + subplot_galaxy_images: false + fits_galaxy_images: false + +inversion: + subplot_inversion: false + subplot_mappings: false + csv_reconstruction: false + +adapt: + subplot_adapt_images: true + fits_adapt_images: true diff --git a/scripts/imaging/config_source/visualize/plots.yaml b/scripts/imaging/config_source/visualize/plots.yaml new file mode 100644 index 0000000..dd00742 --- /dev/null +++ b/scripts/imaging/config_source/visualize/plots.yaml @@ -0,0 +1,31 @@ +subplot_format: [png] +fits_are_zoomed: false + +dataset: + subplot_dataset: false + fits_dataset: false + +fit: + subplot_fit: true + subplot_fit_log10: false + subplot_of_galaxies: false + subplot_galaxy_images: false + fits_fit: false + fits_galaxy_images: false + fits_model_galaxy_images: false + +fit_imaging: {} + +galaxies: + subplot_galaxies: true + subplot_galaxy_images: false + fits_galaxy_images: false + +inversion: + subplot_inversion: true + subplot_mappings: false + csv_reconstruction: false + +adapt: + subplot_adapt_images: false + fits_adapt_images: false diff --git a/scripts/imaging/model_fit.py b/scripts/imaging/model_fit.py new file mode 100644 index 0000000..671aaf8 --- /dev/null +++ b/scripts/imaging/model_fit.py @@ -0,0 +1,110 @@ +""" +Modeling: Single Galaxy Sersic Fit +================================== + +End-to-end imaging model-fit on the autogalaxy single-galaxy ``jax_test`` dataset. +Exercises ``AnalysisImaging`` -> ``FitImaging`` with a Nautilus search. + +Galaxy: a single ``Sersic`` bulge — no lens / mass / source split (this is autogalaxy, +not autolens). +""" + +import os +from os import path + +import autofit as af +import autogalaxy as ag +import autogalaxy.plot as aplt + + +""" +__Dataset__ + +Reuse the ``jax_test`` dataset already used by ``scripts/jax_likelihood_functions/imaging``. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +aplt.plot_array(array=dataset.data) + + +""" +__Mask__ +""" +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +aplt.plot_array(array=dataset.data) + + +""" +__Model__ + +Single galaxy with a parametric ``Sersic`` bulge. +""" +bulge = af.Model(ag.lp.Sersic) +bulge.centre.centre_0 = 0.0 +bulge.centre.centre_1 = 0.0 + +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + + +""" +__Search__ +""" +search = af.Nautilus( + path_prefix=path.join("build", "model_fit", "imaging"), + n_live=50, + n_like_max=300, + number_of_cores=2, +) + + +""" +__Analysis__ +""" +analysis = ag.AnalysisImaging(dataset=dataset) + + +""" +__Model-Fit__ +""" +result = search.fit(model=model, analysis=analysis) + + +""" +__Result__ +""" +print(result.max_log_likelihood_instance) + +aplt.subplot_galaxies( + galaxies=result.max_log_likelihood_galaxies, grid=result.grids.lp +) + +aplt.subplot_fit_imaging(fit=result.max_log_likelihood_fit) + +aplt.corner_cornerpy(samples=result.samples) diff --git a/scripts/imaging/modeling_visualization_jit.py b/scripts/imaging/modeling_visualization_jit.py new file mode 100644 index 0000000..7666953 --- /dev/null +++ b/scripts/imaging/modeling_visualization_jit.py @@ -0,0 +1,231 @@ +""" +End-to-end test: jit-cached visualization during a real Nautilus model-fit. +========================================================================== + +Single-galaxy autogalaxy port of the autolens +``scripts/imaging/modeling_visualization_jit.py`` end-to-end test. + +This test runs in two parts: + +Part 1 — **MGE caching probe.** Uses an MGE galaxy model (Basis of +``ag.lp_linear.Gaussian`` profiles). Calls +``analysis.fit_for_visualization(instance)`` twice and asserts the second +call is much faster than the first (confirming the compiled function is +cached on the analysis instance, not recompiled per visualization). + +Part 2 — **Live Nautilus quick-update with linear light profiles.** Runs a +real (short) Nautilus fit with the same MGE galaxy. With autogalaxy's +``LightProfileLinear`` pytree handling, the +``linear_light_profile_intensity_dict`` lookup survives the JAX pytree +round-trip and no ``KeyError`` is raised. Asserts that ``fit.png`` files +land on disk, proving the JIT-cached fit_for_visualization fires correctly +during the live search callback. + +This script deliberately opts in with +``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default +model-fit scripts elsewhere in the workspace leave both flags at ``False`` +and are therefore untouched by this change. +""" + +import shutil +import time +from os import path +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ + +Re-use the ``jax_test`` dataset that the jax_likelihood_functions scripts rely on. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.5 +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) +dataset = dataset.apply_mask(mask=mask) +dataset = dataset.apply_over_sampling(over_sample_size_lp=4) + + +""" +============================================================================ +Part 1 — MGE caching probe +============================================================================ + +Model: MGE parametric galaxy (Basis of ``ag.lp_linear.Gaussian``). +""" +print("\n" + "=" * 72) +print("Part 1: MGE caching probe") +print("=" * 72) + +total_gaussians = 3 +log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + +centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list = af.Collection( + af.Model(ag.lp_linear.Gaussian) for _ in range(total_gaussians) +) +for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = centre_0 + gaussian.centre.centre_1 = centre_1 + gaussian.ell_comps = gaussian_list[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list[i] + +bulge_mge = af.Model(ag.lp_basis.Basis, profile_list=list(gaussian_list)) + +galaxy_mge = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge_mge) + +model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge)) + +register_model(model_mge) + +analysis_mge = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +instance_mge = model_mge.instance_from_prior_medians() + +t0 = time.perf_counter() +fit_1 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_1.log_likelihood) +t1 = time.perf_counter() +compile_time = t1 - t0 +print(f"First call (compile + run): {compile_time:.3f}s") +print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}") +assert isinstance(fit_1.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit_1.log_likelihood)}" +) + +t0 = time.perf_counter() +fit_2 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_2.log_likelihood) +t1 = time.perf_counter() +cached_time = t1 - t0 +print(f"Second call (cached): {cached_time:.3f}s") +print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x") + +assert cached_time < compile_time * 0.5, ( + f"Cached call ({cached_time:.3f}s) not faster than compile " + f"({compile_time:.3f}s) — JIT cache is not being hit." +) +assert analysis_mge._jitted_fit_from is not None, ( + "expected _jitted_fit_from to be cached on the analysis instance after first call" +) +print("PASS: MGE jit-cached fit_for_visualization works and is reused.") + + +""" +============================================================================ +Part 2 — Live Nautilus quick-update with linear light profiles +============================================================================ + +Model: MGE parametric galaxy (Basis of ``ag.lp_linear.Gaussian``). The +``linear_light_profile_intensity_dict`` lookup is exercised during +visualization. The live search fires quick-update visualization every +``iterations_per_quick_update`` calls; we verify ``fit.png`` lands on disk. +""" +print("\n" + "=" * 72) +print("Part 2: Live Nautilus with linear MGE profiles + jit-visualization") +print("=" * 72) + +total_gaussians2 = 3 +log10_sigma_list2 = np.linspace(-2, np.log10(mask_radius), total_gaussians2) + +centre_0_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list2 = af.Collection( + af.Model(ag.lp_linear.Gaussian) for _ in range(total_gaussians2) +) +for i, gaussian in enumerate(gaussian_list2): + gaussian.centre.centre_0 = centre_0_2 + gaussian.centre.centre_1 = centre_1_2 + gaussian.ell_comps = gaussian_list2[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list2[i] + +bulge_mge2 = af.Model(ag.lp_basis.Basis, profile_list=list(gaussian_list2)) + +galaxy_mge2 = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge_mge2) + +model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2)) + +register_model(model_mge2) + +analysis_mge2 = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit" +if output_root.exists(): + shutil.rmtree(output_root) +output_root.mkdir(parents=True) + +search = af.Nautilus( + path_prefix=str(output_root), + name="mge_linear", + n_live=50, + n_like_max=1500, + iterations_per_quick_update=500, + number_of_cores=1, +) + +print("Running Nautilus ...") +result = search.fit(model=model_mge2, analysis=analysis_mge2) + +# The Nautilus output goes to output////image/ +# The quick-update visualizer writes fit.png (via subplot_fit function) +# to that image folder during each quick update. +output_search_root = Path("output") / output_root / "mge_linear" +produced_pngs = list(output_search_root.rglob("fit.png")) +print(f"fit.png files produced: {len(produced_pngs)}") +for p in produced_pngs: + print(f" {p}") +assert len(produced_pngs) > 0, ( + f"no fit.png produced under {output_search_root} — " + "quick-update visualization did not fire" +) + +# Note: _jitted_fit_from is built on the worker process Nautilus forks for the search +# loop, not the parent's analysis_mge2 instance — so we don't assert it post-search. +# Part 1 above already verifies the cache is set on the calling process. + +print( + "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " + "with linear MGE profiles, fit.png written, no KeyError from " + "linear_light_profile_intensity_dict lookup." +) diff --git a/scripts/imaging/visualization.py b/scripts/imaging/visualization.py new file mode 100644 index 0000000..379f5f3 --- /dev/null +++ b/scripts/imaging/visualization.py @@ -0,0 +1,339 @@ +""" +Visualization: Imaging Analysis +================================ + +Tests that ``VisualizerImaging.visualize_before_fit`` and ``visualize`` output all expected +files to disk and that each output has the correct FITS HDU structure. + +Dataset: single-galaxy ``jax_test`` imaging (Sersic bulge + Exponential disk simulator). + +Structure +--------- +1. ``visualize_before_fit`` runs once with a parametric source (fastest) and writes all + before-fit outputs (dataset.png/.fits, adapt_images.png/.fits) to the main + ``visualization/`` folder. + +2. ``visualize`` runs once per galaxy type, each writing into its own subfolder: + visualization/parametric/ — Sersic light-profile galaxy + visualization/rectangular/ — RectangularAdaptImage pixelization + visualization/delaunay/ — Delaunay pixelization + + Each subfolder contains only the per-run comparison plots: + fit.png, galaxies.png (all three types) + inversion_0_0.png (rectangular and delaunay only) + + A minimal ``config_source/visualize/plots.yaml`` (pushed before these runs) limits + output to just those files so the per-source runs stay fast. + +Expected outputs are derived directly from the source code of: + - autogalaxy/imaging/model/visualizer.py (VisualizerImaging) + - autogalaxy/imaging/model/plotter.py (PlotterImaging) + - autogalaxy/analysis/plotter.py (Plotter: galaxies, inversion) + - autogalaxy/imaging/plot/fit_imaging_plots.py +""" + +import shutil +import time +from os import path +from pathlib import Path +from types import SimpleNamespace + +from autoconf import conf + +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + +import numpy as np +from astropy.io import fits as astropy_fits + +import autofit as af +import autogalaxy as ag +from autogalaxy.imaging.model.visualizer import VisualizerImaging + + +""" +__Dataset__ + +Reuse the ``jax_test`` dataset from ``scripts/jax_likelihood_functions/imaging``. +""" +pixel_scale = 0.2 + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=pixel_scale, + over_sample_size_lp=2, + over_sample_size_pixelization=2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + + +""" +__Galaxy Models__ + +Three single-galaxy configurations, exercising the three plotter code paths +(parametric / rectangular pixelization / Delaunay pixelization). +""" + +# --- Parametric (Sersic) --- +galaxy_bulge = af.Model(ag.lp.Sersic) +galaxy_bulge.centre.centre_0 = 0.0 +galaxy_bulge.centre.centre_1 = 0.0 +galaxy_bulge.intensity = 1.0 +galaxy_bulge.effective_radius = 0.6 +galaxy_bulge.sersic_index = 3.0 +galaxy_parametric = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge) +model_parametric = af.Collection(galaxies=af.Collection(galaxy=galaxy_parametric)) + +# --- Rectangular pixelization --- +mesh_rect = ag.mesh.RectangularAdaptImage(shape=(22, 22)) +reg_rect = ag.reg.Constant(coefficient=1.0) +pix_rect = ag.Pixelization(mesh=mesh_rect, regularization=reg_rect) +galaxy_rectangular = af.Model(ag.Galaxy, redshift=0.5, pixelization=pix_rect) +model_rectangular = af.Collection( + galaxies=af.Collection(galaxy=galaxy_rectangular) +) + +# --- Delaunay pixelization --- +image_mesh = ag.image_mesh.Overlay(shape=(22, 22)) +image_plane_mesh_grid = image_mesh.image_plane_mesh_grid_from(mask=dataset.mask) + +mesh_del = ag.mesh.Delaunay(pixels=image_plane_mesh_grid.shape[0], zeroed_pixels=0) +reg_del = ag.reg.ConstantSplit(coefficient=1.0) +pix_del = ag.Pixelization(mesh=mesh_del, regularization=reg_del) +galaxy_delaunay = af.Model(ag.Galaxy, redshift=0.5, pixelization=pix_del) +model_delaunay = af.Collection(galaxies=af.Collection(galaxy=galaxy_delaunay)) + + +""" +__Adapt Images__ + +Used to test that adapt_images.png/.fits are written by visualize_before_fit. +""" +adapt_images = ag.AdaptImages( + galaxy_name_image_dict={ + "('galaxies', 'galaxy')": dataset.data, + }, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'galaxy')": image_plane_mesh_grid + }, +) + + +""" +__Analysis__ + +A single analysis object is shared across all three galaxy runs: it holds the +dataset only; the galaxy type is determined by the instance passed to visualize. +""" +analysis = ag.AnalysisImaging( + dataset=dataset, + adapt_images=adapt_images, + use_jax=True, + title_prefix="TEST", +) + + +""" +__Paths__ + +Minimal paths stub: VisualizerImaging only needs image_path and output_path. +Clean the output directory on each run so assertions reflect this run only. +""" + +image_path = Path("scripts") / "imaging" / "images" / "visualization" + +if image_path.exists(): + shutil.rmtree(image_path) + +image_path.mkdir(parents=True) + +output_path = image_path / "output" +output_path.mkdir(parents=True) + +paths = SimpleNamespace( + image_path=image_path, + output_path=output_path, +) + + +""" +__Visualize Before Fit__ + +Uses the parametric galaxy (fastest) for all before-fit outputs. + +Calls PlotterImaging.imaging() -> dataset.png, dataset.fits + Plotter.adapt_images() -> adapt_images.png, adapt_images.fits +""" + +print("Running visualize_before_fit (parametric galaxy)...") + +_t0 = time.perf_counter() +VisualizerImaging.visualize_before_fit( + analysis=analysis, + paths=paths, + model=model_parametric, +) +print(f"visualize_before_fit complete in {time.perf_counter() - _t0:.2f}s") + + +""" +__Assertions: visualize_before_fit__ +""" + +# ---- dataset.fits ---- +# Source: PlotterImaging.imaging() -> hdu_list_for_output_from with ext_name_list: +# ["mask", "data", "noise_map", "psf", "over_sample_size_lp", "over_sample_size_pixelization"] +# HDU 0 is PrimaryHDU (first value), HDUs 1-5 are ImageHDU. + +assert (image_path / "dataset.png").exists(), "dataset.png missing" +print("dataset.png OK") + +with astropy_fits.open(image_path / "dataset.fits") as hdul: + assert len(hdul) == 6, f"dataset.fits: expected 6 HDUs, got {len(hdul)}" + assert hdul[0].name == "MASK" + assert hdul[1].name == "DATA" + assert hdul[2].name == "NOISE_MAP" + assert hdul[3].name == "PSF" + assert hdul[4].name == "OVER_SAMPLE_SIZE_LP" + assert hdul[5].name == "OVER_SAMPLE_SIZE_PIXELIZATION" + assert hdul[1].data.ndim == 2, "DATA HDU should be 2D" +print("dataset.fits OK") + +# ---- adapt_images.fits ---- +# Source: Plotter.adapt_images() -> hdu_list_for_output_from with ext_name_list: +# ["mask", "('galaxies', 'galaxy')"] +# HDU 0 = MASK (Primary), HDU 1 = galaxy key (uppercased). + +assert (image_path / "adapt_images.png").exists(), "adapt_images.png missing" +print("adapt_images.png OK") + +with astropy_fits.open(image_path / "adapt_images.fits") as hdul: + assert len(hdul) == 2, f"adapt_images.fits: expected 2 HDUs, got {len(hdul)}" + assert hdul[0].name == "MASK" +print("adapt_images.fits OK") + + +""" +__Push Minimal Config for Per-Source Runs__ + +Override the all-true config with a minimal one that only enables: + fit.subplot_fit, galaxies.subplot_galaxies, inversion.subplot_inversion. +All other toggles are explicitly set to false so no extra files are written. +""" +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config_source"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + + +""" +__Per-Source Visualization__ + +For each galaxy type, visualize is run in a dedicated subfolder. +Only fit.png and galaxies.png are generated for all three; rectangular and delaunay +also produce inversion_0_0.png. + +Calls (governed by config_source/visualize/plots.yaml): + fit.subplot_fit -> fit.png + galaxies.subplot_galaxies -> galaxies.png + inversion.subplot_inversion -> inversion_0_0.png (pixelized galaxies only) +""" + +source_runs = [ + ("parametric", model_parametric, False), + ("rectangular", model_rectangular, True), + ("delaunay", model_delaunay, True), +] + +for source_name, model, has_inversion in source_runs: + print(f"\nRunning visualize for galaxy: {source_name}...") + + sub_path = image_path / source_name + sub_path.mkdir(parents=True) + sub_output = sub_path / "output" + sub_output.mkdir(parents=True) + sub_paths = SimpleNamespace(image_path=sub_path, output_path=sub_output) + + instance = model.instance_from_prior_medians() + + _t0 = time.perf_counter() + VisualizerImaging.visualize( + analysis=analysis, + paths=sub_paths, + instance=instance, + during_analysis=False, + ) + print(f" visualize complete for {source_name} in {time.perf_counter() - _t0:.2f}s") + + assert (sub_path / "fit.png").exists(), f"{source_name}/fit.png missing" + print(f" {source_name}/fit.png OK") + assert (sub_path / "galaxies.png").exists(), f"{source_name}/galaxies.png missing" + print(f" {source_name}/galaxies.png OK") + if has_inversion: + assert ( + sub_path / "inversion_0_0.png" + ).exists(), f"{source_name}/inversion_0_0.png missing" + print(f" {source_name}/inversion_0_0.png OK") + + +""" +__RGB Visualization__ + +Tests that ``plot_array`` correctly handles ``Array2DRGB`` inputs: no colormap, +no norm, no colorbar — the image is rendered via plain ``imshow`` as an RGB image. +""" + +print("\nRunning RGB visualization test...") + +import autogalaxy.plot as aplt + +rgb_values = np.stack( + [dataset.data.native, dataset.data.native, dataset.data.native], axis=-1 +) +rgb_values = np.clip(rgb_values, 0, None) + +rgb_values_uint8 = ( + (rgb_values / rgb_values.max() * 255).astype(np.uint8) + if rgb_values.max() > 0 + else np.zeros_like(rgb_values, dtype=np.uint8) +) + +rgb_array = ag.Array2DRGB(values=rgb_values_uint8, mask=dataset.mask) + +aplt.plot_array( + array=rgb_array, + title="RGB Test", + output_path=image_path, + output_filename="rgb_array", + output_format="png", +) + +assert (image_path / "rgb_array.png").exists(), "rgb_array.png missing" +print("rgb_array.png OK") + + +print("All visualization assertions passed.") diff --git a/scripts/imaging/visualization_jax.py b/scripts/imaging/visualization_jax.py new file mode 100644 index 0000000..f90dead --- /dev/null +++ b/scripts/imaging/visualization_jax.py @@ -0,0 +1,136 @@ +""" +Visualization JAX Pilot: Imaging Analysis (autogalaxy) +======================================================= + +Single-galaxy autogalaxy port of the autolens ``visualization_jax.py`` pilot +(https://github.com/PyAutoLabs/PyAutoFit/issues/1227). + +Goal +---- +Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind +``use_jax_for_visualization`` on ``Analysis``. A parametric MGE galaxy is used +deliberately (simplest case — no pixelization, no inversion). + +This is **Path C**: ``fit_from`` runs on the eager JAX path +(``use_jax=True`` makes ``_xp`` be ``jnp``) and returns a ``FitImaging`` backed +by ``jax.Array`` objects. Matplotlib-bound plotters materialise arrays to NumPy +at the boundary. No ``jax.jit`` is applied to ``fit_from`` — the full-JIT path +(Path A) is exercised by ``modeling_visualization_jit.py``. + +Scope +----- +- Parametric MGE galaxy only. +- Calls ``VisualizerImaging.visualize`` only (not ``visualize_before_fit``). +- Re-uses the ``jax_test`` dataset from ``jax_likelihood_functions/imaging``. +- Reuses ``config_source/visualize/plots.yaml`` from ``visualization.py`` so + only ``fit.png`` and ``galaxies.png`` are attempted. +""" + +import shutil +import traceback +from os import path +from pathlib import Path +from types import SimpleNamespace + +from autoconf import conf + +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config_source"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + +import autofit as af +import autogalaxy as ag +from autogalaxy.imaging.model.visualizer import VisualizerImaging + + +""" +__Dataset__ +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) +dataset = dataset.apply_mask(mask=mask) + + +""" +__Model__ + +MGE parametric galaxy (matches the MGE pattern in +``jax_likelihood_functions/imaging/mge.py``). +""" +galaxy_bulge = ag.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True +) +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge) +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + + +""" +__Analysis__ + +``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` +tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` +via the ``Analysis.fit_for_visualization`` helper. +""" +analysis = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, + title_prefix="JAX_PILOT", +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "imaging" / "images" / "visualization_jax" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Run visualize on the eager-JAX fit__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...") +try: + VisualizerImaging.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, + ) + assert (image_path / "fit.png").exists(), "fit.png was not produced" + print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/galaxies.png.") +except Exception: + print("PILOT FAILED — traceback below:") + print("=" * 72) + traceback.print_exc() + print("=" * 72) diff --git a/smoke_tests.txt b/smoke_tests.txt index dc0464e..6306d80 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -6,3 +6,5 @@ jax_likelihood_functions/imaging/lp.py jax_likelihood_functions/imaging/mge.py jax_likelihood_functions/imaging/mge_group.py jax_likelihood_functions/imaging/rectangular.py +imaging/model_fit.py +imaging/visualization.py From 8e200451b2dcb86e1e63c43a4f423227fd0d3e04 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 26 Apr 2026 13:11:45 +0100 Subject: [PATCH 2/4] fix: disable scripts/imaging/visualization.py in smoke list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The smoke runner sets PYAUTO_FAST_PLOTS=1, which skips savefig calls. visualization.py asserts that dataset.png/fit.png/etc. land on disk — those assertions fail under fast-plots mode by design. Mirrors the same disable in autolens_workspace_test/smoke_tests.txt (rhayes777/PyAutoFit#1179). --- smoke_tests.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smoke_tests.txt b/smoke_tests.txt index 6306d80..38e3658 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -7,4 +7,4 @@ jax_likelihood_functions/imaging/mge.py jax_likelihood_functions/imaging/mge_group.py jax_likelihood_functions/imaging/rectangular.py imaging/model_fit.py -imaging/visualization.py +# imaging/visualization.py # disabled: bypass mode mkdir race condition / fast-plots skips savefig (rhayes777/PyAutoFit#1179) From d2fe80963c4aab782b9e9a5f70403959cadc6644 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 26 Apr 2026 13:16:13 +0100 Subject: [PATCH 3/4] fix: clear PYAUTO_FAST_PLOTS in visualization.py + re-enable in smoke list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit visualization.py asserts that PNG/FITS files land on disk. The smoke runner sets PYAUTO_FAST_PLOTS=1, which skips savefig — so the script would fail those assertions under smoke despite passing on direct runs. Pop PYAUTO_FAST_PLOTS from os.environ at the top of the script (before any plotting code is imported). Each smoke script runs in its own subprocess, so this only affects this script's process — other scripts in the smoke list still see PYAUTO_FAST_PLOTS=1. With that fix the script passes under the smoke profile, so re-enable imaging/visualization.py in smoke_tests.txt. --- scripts/imaging/visualization.py | 7 +++++++ smoke_tests.txt | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/imaging/visualization.py b/scripts/imaging/visualization.py index 379f5f3..9d7fba3 100644 --- a/scripts/imaging/visualization.py +++ b/scripts/imaging/visualization.py @@ -32,12 +32,19 @@ - autogalaxy/imaging/plot/fit_imaging_plots.py """ +import os import shutil import time from os import path from pathlib import Path from types import SimpleNamespace +# This script asserts that subplot PNG / FITS files land on disk. The smoke +# runner sets PYAUTO_FAST_PLOTS=1 to skip savefig for speed, which would cause +# every assertion below to fail. Clear it before any plotting code is imported +# so this script behaves the same under the smoke runner as a direct run. +os.environ.pop("PYAUTO_FAST_PLOTS", None) + from autoconf import conf conf.instance.push( diff --git a/smoke_tests.txt b/smoke_tests.txt index 38e3658..6306d80 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -7,4 +7,4 @@ jax_likelihood_functions/imaging/mge.py jax_likelihood_functions/imaging/mge_group.py jax_likelihood_functions/imaging/rectangular.py imaging/model_fit.py -# imaging/visualization.py # disabled: bypass mode mkdir race condition / fast-plots skips savefig (rhayes777/PyAutoFit#1179) +imaging/visualization.py From 8bacc918dfb6667a0a06c7b32ac8ab96f0486b2d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 26 Apr 2026 13:39:03 +0100 Subject: [PATCH 4/4] fix: unset PYAUTO_FAST_PLOTS for imaging/visualization in smoke runs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the fast-plots disable from an os.environ.pop in the script to the proper place: config/build/env_vars.yaml. The smoke runner (PyAutoBuild's env_config.py) reads this file's defaults + overrides to construct each script's subprocess env, so the YAML is the right contract for runtime flags rather than mutating os.environ inside the script. The override matches the existing precedent for jax_likelihood_functions/. The pattern imaging/visualization is a substring match against the extension-stripped path, so it covers both visualization.py and visualization_jax.py — both rely on PNG output for their assertions. --- config/build/env_vars.yaml | 4 ++++ scripts/imaging/visualization.py | 7 ------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/config/build/env_vars.yaml b/config/build/env_vars.yaml index 2513680..af727a2 100644 --- a/config/build/env_vars.yaml +++ b/config/build/env_vars.yaml @@ -21,3 +21,7 @@ overrides: # JAX likelihood functions test JIT compilation — need JAX enabled and full-size datasets - pattern: "jax_likelihood_functions/" unset: [PYAUTO_WORKSPACE_SMALL_DATASETS, PYAUTO_DISABLE_JAX] + # imaging/visualization.py asserts subplot PNG / FITS files land on disk — + # PYAUTO_FAST_PLOTS skips savefig and would break those assertions. + - pattern: "imaging/visualization" + unset: [PYAUTO_FAST_PLOTS] diff --git a/scripts/imaging/visualization.py b/scripts/imaging/visualization.py index 9d7fba3..379f5f3 100644 --- a/scripts/imaging/visualization.py +++ b/scripts/imaging/visualization.py @@ -32,19 +32,12 @@ - autogalaxy/imaging/plot/fit_imaging_plots.py """ -import os import shutil import time from os import path from pathlib import Path from types import SimpleNamespace -# This script asserts that subplot PNG / FITS files land on disk. The smoke -# runner sets PYAUTO_FAST_PLOTS=1 to skip savefig for speed, which would cause -# every assertion below to fail. Clear it before any plotting code is imported -# so this script behaves the same under the smoke runner as a direct run. -os.environ.pop("PYAUTO_FAST_PLOTS", None) - from autoconf import conf conf.instance.push(