diff --git a/config/build/env_vars.yaml b/config/build/env_vars.yaml index 2513680..af727a2 100644 --- a/config/build/env_vars.yaml +++ b/config/build/env_vars.yaml @@ -21,3 +21,7 @@ overrides: # JAX likelihood functions test JIT compilation — need JAX enabled and full-size datasets - pattern: "jax_likelihood_functions/" unset: [PYAUTO_WORKSPACE_SMALL_DATASETS, PYAUTO_DISABLE_JAX] + # imaging/visualization.py asserts subplot PNG / FITS files land on disk — + # PYAUTO_FAST_PLOTS skips savefig and would break those assertions. + - pattern: "imaging/visualization" + unset: [PYAUTO_FAST_PLOTS] diff --git a/scripts/imaging/__init__.py b/scripts/imaging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/imaging/config/visualize/plots.yaml b/scripts/imaging/config/visualize/plots.yaml new file mode 100644 index 0000000..64d4d21 --- /dev/null +++ b/scripts/imaging/config/visualize/plots.yaml @@ -0,0 +1,31 @@ +subplot_format: [png] +fits_are_zoomed: false + +dataset: + subplot_dataset: true + fits_dataset: true + +fit: + subplot_fit: true + subplot_fit_log10: false + subplot_of_galaxies: false + subplot_galaxy_images: false + fits_fit: false + fits_galaxy_images: false + fits_model_galaxy_images: false + +fit_imaging: {} + +galaxies: + subplot_galaxies: false + subplot_galaxy_images: false + fits_galaxy_images: false + +inversion: + subplot_inversion: false + subplot_mappings: false + csv_reconstruction: false + +adapt: + subplot_adapt_images: true + fits_adapt_images: true diff --git a/scripts/imaging/config_source/visualize/plots.yaml b/scripts/imaging/config_source/visualize/plots.yaml new file mode 100644 index 0000000..dd00742 --- /dev/null +++ b/scripts/imaging/config_source/visualize/plots.yaml @@ -0,0 +1,31 @@ +subplot_format: [png] +fits_are_zoomed: false + +dataset: + subplot_dataset: false + fits_dataset: false + +fit: + subplot_fit: true + subplot_fit_log10: false + subplot_of_galaxies: false + subplot_galaxy_images: false + fits_fit: false + fits_galaxy_images: false + fits_model_galaxy_images: false + +fit_imaging: {} + +galaxies: + subplot_galaxies: true + subplot_galaxy_images: false + fits_galaxy_images: false + +inversion: + subplot_inversion: true + subplot_mappings: false + csv_reconstruction: false + +adapt: + subplot_adapt_images: false + fits_adapt_images: false diff --git a/scripts/imaging/model_fit.py b/scripts/imaging/model_fit.py new file mode 100644 index 0000000..671aaf8 --- /dev/null +++ b/scripts/imaging/model_fit.py @@ -0,0 +1,110 @@ +""" +Modeling: Single Galaxy Sersic Fit +================================== + +End-to-end imaging model-fit on the autogalaxy single-galaxy ``jax_test`` dataset. +Exercises ``AnalysisImaging`` -> ``FitImaging`` with a Nautilus search. + +Galaxy: a single ``Sersic`` bulge — no lens / mass / source split (this is autogalaxy, +not autolens). +""" + +import os +from os import path + +import autofit as af +import autogalaxy as ag +import autogalaxy.plot as aplt + + +""" +__Dataset__ + +Reuse the ``jax_test`` dataset already used by ``scripts/jax_likelihood_functions/imaging``. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +aplt.plot_array(array=dataset.data) + + +""" +__Mask__ +""" +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +aplt.plot_array(array=dataset.data) + + +""" +__Model__ + +Single galaxy with a parametric ``Sersic`` bulge. +""" +bulge = af.Model(ag.lp.Sersic) +bulge.centre.centre_0 = 0.0 +bulge.centre.centre_1 = 0.0 + +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + + +""" +__Search__ +""" +search = af.Nautilus( + path_prefix=path.join("build", "model_fit", "imaging"), + n_live=50, + n_like_max=300, + number_of_cores=2, +) + + +""" +__Analysis__ +""" +analysis = ag.AnalysisImaging(dataset=dataset) + + +""" +__Model-Fit__ +""" +result = search.fit(model=model, analysis=analysis) + + +""" +__Result__ +""" +print(result.max_log_likelihood_instance) + +aplt.subplot_galaxies( + galaxies=result.max_log_likelihood_galaxies, grid=result.grids.lp +) + +aplt.subplot_fit_imaging(fit=result.max_log_likelihood_fit) + +aplt.corner_cornerpy(samples=result.samples) diff --git a/scripts/imaging/modeling_visualization_jit.py b/scripts/imaging/modeling_visualization_jit.py new file mode 100644 index 0000000..7666953 --- /dev/null +++ b/scripts/imaging/modeling_visualization_jit.py @@ -0,0 +1,231 @@ +""" +End-to-end test: jit-cached visualization during a real Nautilus model-fit. +========================================================================== + +Single-galaxy autogalaxy port of the autolens +``scripts/imaging/modeling_visualization_jit.py`` end-to-end test. + +This test runs in two parts: + +Part 1 — **MGE caching probe.** Uses an MGE galaxy model (Basis of +``ag.lp_linear.Gaussian`` profiles). Calls +``analysis.fit_for_visualization(instance)`` twice and asserts the second +call is much faster than the first (confirming the compiled function is +cached on the analysis instance, not recompiled per visualization). + +Part 2 — **Live Nautilus quick-update with linear light profiles.** Runs a +real (short) Nautilus fit with the same MGE galaxy. With autogalaxy's +``LightProfileLinear`` pytree handling, the +``linear_light_profile_intensity_dict`` lookup survives the JAX pytree +round-trip and no ``KeyError`` is raised. Asserts that ``fit.png`` files +land on disk, proving the JIT-cached fit_for_visualization fires correctly +during the live search callback. + +This script deliberately opts in with +``AnalysisImaging(use_jax=True, use_jax_for_visualization=True)``. Default +model-fit scripts elsewhere in the workspace leave both flags at ``False`` +and are therefore untouched by this change. +""" + +import shutil +import time +from os import path +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ + +Re-use the ``jax_test`` dataset that the jax_likelihood_functions scripts rely on. +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.5 +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) +dataset = dataset.apply_mask(mask=mask) +dataset = dataset.apply_over_sampling(over_sample_size_lp=4) + + +""" +============================================================================ +Part 1 — MGE caching probe +============================================================================ + +Model: MGE parametric galaxy (Basis of ``ag.lp_linear.Gaussian``). +""" +print("\n" + "=" * 72) +print("Part 1: MGE caching probe") +print("=" * 72) + +total_gaussians = 3 +log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + +centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list = af.Collection( + af.Model(ag.lp_linear.Gaussian) for _ in range(total_gaussians) +) +for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = centre_0 + gaussian.centre.centre_1 = centre_1 + gaussian.ell_comps = gaussian_list[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list[i] + +bulge_mge = af.Model(ag.lp_basis.Basis, profile_list=list(gaussian_list)) + +galaxy_mge = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge_mge) + +model_mge = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge)) + +register_model(model_mge) + +analysis_mge = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +instance_mge = model_mge.instance_from_prior_medians() + +t0 = time.perf_counter() +fit_1 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_1.log_likelihood) +t1 = time.perf_counter() +compile_time = t1 - t0 +print(f"First call (compile + run): {compile_time:.3f}s") +print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}") +assert isinstance(fit_1.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit_1.log_likelihood)}" +) + +t0 = time.perf_counter() +fit_2 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_2.log_likelihood) +t1 = time.perf_counter() +cached_time = t1 - t0 +print(f"Second call (cached): {cached_time:.3f}s") +print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x") + +assert cached_time < compile_time * 0.5, ( + f"Cached call ({cached_time:.3f}s) not faster than compile " + f"({compile_time:.3f}s) — JIT cache is not being hit." +) +assert analysis_mge._jitted_fit_from is not None, ( + "expected _jitted_fit_from to be cached on the analysis instance after first call" +) +print("PASS: MGE jit-cached fit_for_visualization works and is reused.") + + +""" +============================================================================ +Part 2 — Live Nautilus quick-update with linear light profiles +============================================================================ + +Model: MGE parametric galaxy (Basis of ``ag.lp_linear.Gaussian``). The +``linear_light_profile_intensity_dict`` lookup is exercised during +visualization. The live search fires quick-update visualization every +``iterations_per_quick_update`` calls; we verify ``fit.png`` lands on disk. +""" +print("\n" + "=" * 72) +print("Part 2: Live Nautilus with linear MGE profiles + jit-visualization") +print("=" * 72) + +total_gaussians2 = 3 +log10_sigma_list2 = np.linspace(-2, np.log10(mask_radius), total_gaussians2) + +centre_0_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list2 = af.Collection( + af.Model(ag.lp_linear.Gaussian) for _ in range(total_gaussians2) +) +for i, gaussian in enumerate(gaussian_list2): + gaussian.centre.centre_0 = centre_0_2 + gaussian.centre.centre_1 = centre_1_2 + gaussian.ell_comps = gaussian_list2[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list2[i] + +bulge_mge2 = af.Model(ag.lp_basis.Basis, profile_list=list(gaussian_list2)) + +galaxy_mge2 = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge_mge2) + +model_mge2 = af.Collection(galaxies=af.Collection(galaxy=galaxy_mge2)) + +register_model(model_mge2) + +analysis_mge2 = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, +) + +output_root = Path("scripts") / "imaging" / "images" / "modeling_visualization_jit" +if output_root.exists(): + shutil.rmtree(output_root) +output_root.mkdir(parents=True) + +search = af.Nautilus( + path_prefix=str(output_root), + name="mge_linear", + n_live=50, + n_like_max=1500, + iterations_per_quick_update=500, + number_of_cores=1, +) + +print("Running Nautilus ...") +result = search.fit(model=model_mge2, analysis=analysis_mge2) + +# The Nautilus output goes to output////image/ +# The quick-update visualizer writes fit.png (via subplot_fit function) +# to that image folder during each quick update. +output_search_root = Path("output") / output_root / "mge_linear" +produced_pngs = list(output_search_root.rglob("fit.png")) +print(f"fit.png files produced: {len(produced_pngs)}") +for p in produced_pngs: + print(f" {p}") +assert len(produced_pngs) > 0, ( + f"no fit.png produced under {output_search_root} — " + "quick-update visualization did not fire" +) + +# Note: _jitted_fit_from is built on the worker process Nautilus forks for the search +# loop, not the parent's analysis_mge2 instance — so we don't assert it post-search. +# Part 1 above already verifies the cache is set on the calling process. + +print( + "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " + "with linear MGE profiles, fit.png written, no KeyError from " + "linear_light_profile_intensity_dict lookup." +) diff --git a/scripts/imaging/visualization.py b/scripts/imaging/visualization.py new file mode 100644 index 0000000..379f5f3 --- /dev/null +++ b/scripts/imaging/visualization.py @@ -0,0 +1,339 @@ +""" +Visualization: Imaging Analysis +================================ + +Tests that ``VisualizerImaging.visualize_before_fit`` and ``visualize`` output all expected +files to disk and that each output has the correct FITS HDU structure. + +Dataset: single-galaxy ``jax_test`` imaging (Sersic bulge + Exponential disk simulator). + +Structure +--------- +1. ``visualize_before_fit`` runs once with a parametric source (fastest) and writes all + before-fit outputs (dataset.png/.fits, adapt_images.png/.fits) to the main + ``visualization/`` folder. + +2. ``visualize`` runs once per galaxy type, each writing into its own subfolder: + visualization/parametric/ — Sersic light-profile galaxy + visualization/rectangular/ — RectangularAdaptImage pixelization + visualization/delaunay/ — Delaunay pixelization + + Each subfolder contains only the per-run comparison plots: + fit.png, galaxies.png (all three types) + inversion_0_0.png (rectangular and delaunay only) + + A minimal ``config_source/visualize/plots.yaml`` (pushed before these runs) limits + output to just those files so the per-source runs stay fast. + +Expected outputs are derived directly from the source code of: + - autogalaxy/imaging/model/visualizer.py (VisualizerImaging) + - autogalaxy/imaging/model/plotter.py (PlotterImaging) + - autogalaxy/analysis/plotter.py (Plotter: galaxies, inversion) + - autogalaxy/imaging/plot/fit_imaging_plots.py +""" + +import shutil +import time +from os import path +from pathlib import Path +from types import SimpleNamespace + +from autoconf import conf + +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + +import numpy as np +from astropy.io import fits as astropy_fits + +import autofit as af +import autogalaxy as ag +from autogalaxy.imaging.model.visualizer import VisualizerImaging + + +""" +__Dataset__ + +Reuse the ``jax_test`` dataset from ``scripts/jax_likelihood_functions/imaging``. +""" +pixel_scale = 0.2 + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=pixel_scale, + over_sample_size_lp=2, + over_sample_size_pixelization=2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + + +""" +__Galaxy Models__ + +Three single-galaxy configurations, exercising the three plotter code paths +(parametric / rectangular pixelization / Delaunay pixelization). +""" + +# --- Parametric (Sersic) --- +galaxy_bulge = af.Model(ag.lp.Sersic) +galaxy_bulge.centre.centre_0 = 0.0 +galaxy_bulge.centre.centre_1 = 0.0 +galaxy_bulge.intensity = 1.0 +galaxy_bulge.effective_radius = 0.6 +galaxy_bulge.sersic_index = 3.0 +galaxy_parametric = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge) +model_parametric = af.Collection(galaxies=af.Collection(galaxy=galaxy_parametric)) + +# --- Rectangular pixelization --- +mesh_rect = ag.mesh.RectangularAdaptImage(shape=(22, 22)) +reg_rect = ag.reg.Constant(coefficient=1.0) +pix_rect = ag.Pixelization(mesh=mesh_rect, regularization=reg_rect) +galaxy_rectangular = af.Model(ag.Galaxy, redshift=0.5, pixelization=pix_rect) +model_rectangular = af.Collection( + galaxies=af.Collection(galaxy=galaxy_rectangular) +) + +# --- Delaunay pixelization --- +image_mesh = ag.image_mesh.Overlay(shape=(22, 22)) +image_plane_mesh_grid = image_mesh.image_plane_mesh_grid_from(mask=dataset.mask) + +mesh_del = ag.mesh.Delaunay(pixels=image_plane_mesh_grid.shape[0], zeroed_pixels=0) +reg_del = ag.reg.ConstantSplit(coefficient=1.0) +pix_del = ag.Pixelization(mesh=mesh_del, regularization=reg_del) +galaxy_delaunay = af.Model(ag.Galaxy, redshift=0.5, pixelization=pix_del) +model_delaunay = af.Collection(galaxies=af.Collection(galaxy=galaxy_delaunay)) + + +""" +__Adapt Images__ + +Used to test that adapt_images.png/.fits are written by visualize_before_fit. +""" +adapt_images = ag.AdaptImages( + galaxy_name_image_dict={ + "('galaxies', 'galaxy')": dataset.data, + }, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'galaxy')": image_plane_mesh_grid + }, +) + + +""" +__Analysis__ + +A single analysis object is shared across all three galaxy runs: it holds the +dataset only; the galaxy type is determined by the instance passed to visualize. +""" +analysis = ag.AnalysisImaging( + dataset=dataset, + adapt_images=adapt_images, + use_jax=True, + title_prefix="TEST", +) + + +""" +__Paths__ + +Minimal paths stub: VisualizerImaging only needs image_path and output_path. +Clean the output directory on each run so assertions reflect this run only. +""" + +image_path = Path("scripts") / "imaging" / "images" / "visualization" + +if image_path.exists(): + shutil.rmtree(image_path) + +image_path.mkdir(parents=True) + +output_path = image_path / "output" +output_path.mkdir(parents=True) + +paths = SimpleNamespace( + image_path=image_path, + output_path=output_path, +) + + +""" +__Visualize Before Fit__ + +Uses the parametric galaxy (fastest) for all before-fit outputs. + +Calls PlotterImaging.imaging() -> dataset.png, dataset.fits + Plotter.adapt_images() -> adapt_images.png, adapt_images.fits +""" + +print("Running visualize_before_fit (parametric galaxy)...") + +_t0 = time.perf_counter() +VisualizerImaging.visualize_before_fit( + analysis=analysis, + paths=paths, + model=model_parametric, +) +print(f"visualize_before_fit complete in {time.perf_counter() - _t0:.2f}s") + + +""" +__Assertions: visualize_before_fit__ +""" + +# ---- dataset.fits ---- +# Source: PlotterImaging.imaging() -> hdu_list_for_output_from with ext_name_list: +# ["mask", "data", "noise_map", "psf", "over_sample_size_lp", "over_sample_size_pixelization"] +# HDU 0 is PrimaryHDU (first value), HDUs 1-5 are ImageHDU. + +assert (image_path / "dataset.png").exists(), "dataset.png missing" +print("dataset.png OK") + +with astropy_fits.open(image_path / "dataset.fits") as hdul: + assert len(hdul) == 6, f"dataset.fits: expected 6 HDUs, got {len(hdul)}" + assert hdul[0].name == "MASK" + assert hdul[1].name == "DATA" + assert hdul[2].name == "NOISE_MAP" + assert hdul[3].name == "PSF" + assert hdul[4].name == "OVER_SAMPLE_SIZE_LP" + assert hdul[5].name == "OVER_SAMPLE_SIZE_PIXELIZATION" + assert hdul[1].data.ndim == 2, "DATA HDU should be 2D" +print("dataset.fits OK") + +# ---- adapt_images.fits ---- +# Source: Plotter.adapt_images() -> hdu_list_for_output_from with ext_name_list: +# ["mask", "('galaxies', 'galaxy')"] +# HDU 0 = MASK (Primary), HDU 1 = galaxy key (uppercased). + +assert (image_path / "adapt_images.png").exists(), "adapt_images.png missing" +print("adapt_images.png OK") + +with astropy_fits.open(image_path / "adapt_images.fits") as hdul: + assert len(hdul) == 2, f"adapt_images.fits: expected 2 HDUs, got {len(hdul)}" + assert hdul[0].name == "MASK" +print("adapt_images.fits OK") + + +""" +__Push Minimal Config for Per-Source Runs__ + +Override the all-true config with a minimal one that only enables: + fit.subplot_fit, galaxies.subplot_galaxies, inversion.subplot_inversion. +All other toggles are explicitly set to false so no extra files are written. +""" +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config_source"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + + +""" +__Per-Source Visualization__ + +For each galaxy type, visualize is run in a dedicated subfolder. +Only fit.png and galaxies.png are generated for all three; rectangular and delaunay +also produce inversion_0_0.png. + +Calls (governed by config_source/visualize/plots.yaml): + fit.subplot_fit -> fit.png + galaxies.subplot_galaxies -> galaxies.png + inversion.subplot_inversion -> inversion_0_0.png (pixelized galaxies only) +""" + +source_runs = [ + ("parametric", model_parametric, False), + ("rectangular", model_rectangular, True), + ("delaunay", model_delaunay, True), +] + +for source_name, model, has_inversion in source_runs: + print(f"\nRunning visualize for galaxy: {source_name}...") + + sub_path = image_path / source_name + sub_path.mkdir(parents=True) + sub_output = sub_path / "output" + sub_output.mkdir(parents=True) + sub_paths = SimpleNamespace(image_path=sub_path, output_path=sub_output) + + instance = model.instance_from_prior_medians() + + _t0 = time.perf_counter() + VisualizerImaging.visualize( + analysis=analysis, + paths=sub_paths, + instance=instance, + during_analysis=False, + ) + print(f" visualize complete for {source_name} in {time.perf_counter() - _t0:.2f}s") + + assert (sub_path / "fit.png").exists(), f"{source_name}/fit.png missing" + print(f" {source_name}/fit.png OK") + assert (sub_path / "galaxies.png").exists(), f"{source_name}/galaxies.png missing" + print(f" {source_name}/galaxies.png OK") + if has_inversion: + assert ( + sub_path / "inversion_0_0.png" + ).exists(), f"{source_name}/inversion_0_0.png missing" + print(f" {source_name}/inversion_0_0.png OK") + + +""" +__RGB Visualization__ + +Tests that ``plot_array`` correctly handles ``Array2DRGB`` inputs: no colormap, +no norm, no colorbar — the image is rendered via plain ``imshow`` as an RGB image. +""" + +print("\nRunning RGB visualization test...") + +import autogalaxy.plot as aplt + +rgb_values = np.stack( + [dataset.data.native, dataset.data.native, dataset.data.native], axis=-1 +) +rgb_values = np.clip(rgb_values, 0, None) + +rgb_values_uint8 = ( + (rgb_values / rgb_values.max() * 255).astype(np.uint8) + if rgb_values.max() > 0 + else np.zeros_like(rgb_values, dtype=np.uint8) +) + +rgb_array = ag.Array2DRGB(values=rgb_values_uint8, mask=dataset.mask) + +aplt.plot_array( + array=rgb_array, + title="RGB Test", + output_path=image_path, + output_filename="rgb_array", + output_format="png", +) + +assert (image_path / "rgb_array.png").exists(), "rgb_array.png missing" +print("rgb_array.png OK") + + +print("All visualization assertions passed.") diff --git a/scripts/imaging/visualization_jax.py b/scripts/imaging/visualization_jax.py new file mode 100644 index 0000000..f90dead --- /dev/null +++ b/scripts/imaging/visualization_jax.py @@ -0,0 +1,136 @@ +""" +Visualization JAX Pilot: Imaging Analysis (autogalaxy) +======================================================= + +Single-galaxy autogalaxy port of the autolens ``visualization_jax.py`` pilot +(https://github.com/PyAutoLabs/PyAutoFit/issues/1227). + +Goal +---- +Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind +``use_jax_for_visualization`` on ``Analysis``. A parametric MGE galaxy is used +deliberately (simplest case — no pixelization, no inversion). + +This is **Path C**: ``fit_from`` runs on the eager JAX path +(``use_jax=True`` makes ``_xp`` be ``jnp``) and returns a ``FitImaging`` backed +by ``jax.Array`` objects. Matplotlib-bound plotters materialise arrays to NumPy +at the boundary. No ``jax.jit`` is applied to ``fit_from`` — the full-JIT path +(Path A) is exercised by ``modeling_visualization_jit.py``. + +Scope +----- +- Parametric MGE galaxy only. +- Calls ``VisualizerImaging.visualize`` only (not ``visualize_before_fit``). +- Re-uses the ``jax_test`` dataset from ``jax_likelihood_functions/imaging``. +- Reuses ``config_source/visualize/plots.yaml`` from ``visualization.py`` so + only ``fit.png`` and ``galaxies.png`` are attempted. +""" + +import shutil +import traceback +from os import path +from pathlib import Path +from types import SimpleNamespace + +from autoconf import conf + +conf.instance.push( + new_path=path.join(path.dirname(path.realpath(__file__)), "config_source"), + output_path=path.join(path.dirname(path.realpath(__file__)), "images"), +) + +import autofit as af +import autogalaxy as ag +from autogalaxy.imaging.model.visualizer import VisualizerImaging + + +""" +__Dataset__ +""" +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) +dataset = dataset.apply_mask(mask=mask) + + +""" +__Model__ + +MGE parametric galaxy (matches the MGE pattern in +``jax_likelihood_functions/imaging/mge.py``). +""" +galaxy_bulge = ag.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True +) +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=galaxy_bulge) +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + + +""" +__Analysis__ + +``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` +tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` +via the ``Analysis.fit_for_visualization`` helper. +""" +analysis = ag.AnalysisImaging( + dataset=dataset, + use_jax=True, + use_jax_for_visualization=True, + title_prefix="JAX_PILOT", +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "imaging" / "images" / "visualization_jax" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Run visualize on the eager-JAX fit__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...") +try: + VisualizerImaging.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, + ) + assert (image_path / "fit.png").exists(), "fit.png was not produced" + print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/galaxies.png.") +except Exception: + print("PILOT FAILED — traceback below:") + print("=" * 72) + traceback.print_exc() + print("=" * 72) diff --git a/smoke_tests.txt b/smoke_tests.txt index dc0464e..6306d80 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -6,3 +6,5 @@ jax_likelihood_functions/imaging/lp.py jax_likelihood_functions/imaging/mge.py jax_likelihood_functions/imaging/mge_group.py jax_likelihood_functions/imaging/rectangular.py +imaging/model_fit.py +imaging/visualization.py