From a261bf86ca6725fe08c9f3b24a611204803e5c1f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 22 Apr 2026 14:37:38 +0100 Subject: [PATCH] feat: add jax_likelihood_functions/imaging/ scripts (4/8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports the first wave of autolens_workspace_test jax_likelihood_functions/imaging/ scripts to autogalaxy_workspace_test on a single-plane galaxy model (no lens/ source split, no ray-tracing, no mass profiles): - lp.py — parametric Sersic - mge.py — MGE basis - mge_group.py — MGE + extra_galaxies - rectangular.py — non-adapt RectangularUniform + reg.Constant pixelization - simulator.py — generates the shared dataset/imaging/jax_test/ fits All four exercise the three-step JAX contract (NumPy baseline → fitness._vmap batch → jax.jit(analysis.fit_from) scalar round-trip) enabled by the _register_fit_imaging_pytrees scaffold landed in PyAutoGalaxy PR #364. Deferred to a follow-up library task due to an AdaptImages post-unflatten Galaxy-identity bug on the autogalaxy side (see admin_jammy/prompt/autogalaxy/ adapt_images_pytree_fix.md): - rectangular_mge.py — RectangularAdaptImage + reg.Adapt - delaunay.py — Delaunay requires image_plane_mesh_grid via adapt_images - delaunay_mge.py — same Also gitignores dataset/ (matches autolens_workspace_test) — scripts auto- regenerate via subprocess on first run if the fits are missing. Part of PyAutoLabs/autogalaxy_workspace_test#8 (epic #5, task 3/9). Co-Authored-By: Claude Opus 4.7 --- .gitignore | 1 + scripts/jax_likelihood_functions/__init__.py | 0 .../imaging/__init__.py | 0 .../jax_likelihood_functions/imaging/lp.py | 129 +++++++++++ .../jax_likelihood_functions/imaging/mge.py | 142 ++++++++++++ .../imaging/mge_group.py | 206 ++++++++++++++++++ .../imaging/rectangular.py | 130 +++++++++++ .../imaging/simulator.py | 77 +++++++ smoke_tests.txt | 4 + 9 files changed, 689 insertions(+) create mode 100644 scripts/jax_likelihood_functions/__init__.py create mode 100644 scripts/jax_likelihood_functions/imaging/__init__.py create mode 100644 scripts/jax_likelihood_functions/imaging/lp.py create mode 100644 scripts/jax_likelihood_functions/imaging/mge.py create mode 100644 scripts/jax_likelihood_functions/imaging/mge_group.py create mode 100644 scripts/jax_likelihood_functions/imaging/rectangular.py create mode 100644 scripts/jax_likelihood_functions/imaging/simulator.py diff --git a/.gitignore b/.gitignore index b3dc17b..1a68155 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ output/ *.log .pytest_cache/ failed/ +dataset/ diff --git a/scripts/jax_likelihood_functions/__init__.py b/scripts/jax_likelihood_functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/jax_likelihood_functions/imaging/__init__.py b/scripts/jax_likelihood_functions/imaging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/jax_likelihood_functions/imaging/lp.py b/scripts/jax_likelihood_functions/imaging/lp.py new file mode 100644 index 0000000..bcde593 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/lp.py @@ -0,0 +1,129 @@ +""" +JAX Likelihood: Parametric Light Profile +======================================== + +Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an +autogalaxy model composed of a linear Sersic bulge. Two paths are exercised: + +1. ``fitness._vmap`` batch evaluation (tests ``jax.vmap`` + ``jax.jit`` on the + autofit ``Fitness`` wrapper). +2. ``jax.jit(analysis.fit_from)`` round-trip, which relies on the pytree + registration added to ``AnalysisImaging._register_fit_imaging_pytrees`` — + this path exercises the full ``FitImaging`` return value flattening. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=3.0, +) + +dataset = dataset.apply_mask(mask=mask) +dataset = dataset.apply_over_sampling(over_sample_size_lp=1) + +""" +__Model__ + +Single galaxy with a linear Sersic bulge — no lens/source split, no mass profile. +""" +bulge = af.Model(ag.lp.Sersic) + +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + +print(model.info) + +analysis = ag.AnalysisImaging(dataset=dataset) + +""" +__vmap Path__ + +Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter +vectors. This tests that the full likelihood pipeline JIT-compiles end to end. +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ + +Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` +with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. This +is the part unblocked by ``_register_fit_imaging_pytrees``. +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/mge.py b/scripts/jax_likelihood_functions/imaging/mge.py new file mode 100644 index 0000000..2646308 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/mge.py @@ -0,0 +1,142 @@ +""" +JAX Likelihood: MGE Basis Light Profile +======================================== + +Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an +autogalaxy model composed of a Multi-Gaussian Expansion (MGE) linear basis. +Two paths are exercised: + +1. ``fitness._vmap`` batch evaluation (tests ``jax.vmap`` + ``jax.jit`` on the + autofit ``Fitness`` wrapper). +2. ``jax.jit(analysis.fit_from)`` round-trip, which relies on the pytree + registration added to ``AnalysisImaging._register_fit_imaging_pytrees`` — + this path exercises the full ``FitImaging`` return value flattening. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +over_sample_size = ag.util.over_sample.over_sample_size_via_radial_bins_from( + grid=dataset.grid, + sub_size_list=[4, 2, 1], + radial_list=[0.3, 0.6], + centre_list=[(0.0, 0.0)], +) + +dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) + +""" +__Model__ + +Single galaxy with an MGE linear basis light profile — no lens/source split, +no mass profile. +""" +bulge = ag.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True +) + +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + +print(model.info) + +analysis = ag.AnalysisImaging(dataset=dataset) + +""" +__vmap Path__ + +Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter +vectors. This tests that the full likelihood pipeline JIT-compiles end to end. +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ + +Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` +with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/mge_group.py b/scripts/jax_likelihood_functions/imaging/mge_group.py new file mode 100644 index 0000000..c941f74 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/mge_group.py @@ -0,0 +1,206 @@ +""" +JAX Likelihood: MGE Basis Light Profile with Extra Galaxies +============================================================ + +Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an +autogalaxy model composed of a Multi-Gaussian Expansion (MGE) linear basis +on the primary galaxy plus extra galaxies (also with MGE bases). +Two paths are exercised: + +1. ``fitness._vmap`` batch evaluation (tests ``jax.vmap`` + ``jax.jit`` on the + autofit ``Fitness`` wrapper). +2. ``jax.jit(analysis.fit_from)`` round-trip, which relies on the pytree + registration added to ``AnalysisImaging._register_fit_imaging_pytrees`` — + this path exercises the full ``FitImaging`` return value flattening. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +""" +__Group Centres__ +""" +centre_list = [(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (0.0, 3.0), (0.0, 4.0)] + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, pixel_scales=dataset.pixel_scales, radius=4.0 +) + +dataset = dataset.apply_mask(mask=mask) + +over_sample_size = ag.util.over_sample.over_sample_size_via_radial_bins_from( + grid=dataset.grid, + sub_size_list=[4, 2, 1], + radial_list=[0.3, 0.6], + centre_list=[(0.0, 0.0)] + centre_list, +) + +dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size) + +""" +__Model__ + +Single primary galaxy with an MGE linear basis, plus extra galaxies (each with +a spherical MGE linear basis). No mass profiles, no lens/source split. +""" +total_gaussians = 30 +gaussian_per_basis = 2 + +log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + +centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +bulge_gaussian_list = [] + +for j in range(gaussian_per_basis): + gaussian_list = af.Collection( + af.Model(ag.lp_linear.Gaussian) for _ in range(total_gaussians) + ) + + for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = centre_0 + gaussian.centre.centre_1 = centre_1 + gaussian.ell_comps = gaussian_list[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list[i] + + bulge_gaussian_list += gaussian_list + +bulge = af.Model( + ag.lp_basis.Basis, + profile_list=bulge_gaussian_list, +) + +galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +# Extra Galaxies: + +extra_galaxies_list = [] + +for extra_galaxy_centre in centre_list: + total_gaussians = 10 + + log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + + extra_galaxy_gaussian_list = [] + + gaussian_list = af.Collection( + af.Model(ag.lp_linear.GaussianSph) for _ in range(total_gaussians) + ) + + for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = extra_galaxy_centre[0] + gaussian.centre.centre_1 = extra_galaxy_centre[1] + gaussian.sigma = 10 ** log10_sigma_list[i] + + extra_galaxy_gaussian_list += gaussian_list + + extra_galaxy_bulge = af.Model( + ag.lp_basis.Basis, profile_list=extra_galaxy_gaussian_list + ) + + extra_galaxy = af.Model( + ag.Galaxy, redshift=0.5, bulge=extra_galaxy_bulge + ) + + extra_galaxies_list.append(extra_galaxy) + +extra_galaxies = af.Collection(extra_galaxies_list) + +# Overall Model: + +model = af.Collection( + galaxies=af.Collection(galaxy=galaxy), extra_galaxies=extra_galaxies +) + +analysis = ag.AnalysisImaging(dataset=dataset) + +""" +__vmap Path__ + +Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter +vectors. This tests that the full likelihood pipeline JIT-compiles end to end. +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ + +Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging`` +with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/rectangular.py b/scripts/jax_likelihood_functions/imaging/rectangular.py new file mode 100644 index 0000000..64879b2 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/rectangular.py @@ -0,0 +1,130 @@ +""" +JAX Likelihood: Rectangular Pixelization +========================================= + +Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an +autogalaxy model using a non-adapt rectangular pixelization mesh. Two paths +are exercised: + +1. ``fitness._vmap`` batch evaluation. +2. ``jax.jit(analysis.fit_from)`` scalar round-trip — relies on + ``AnalysisImaging._register_fit_imaging_pytrees``. + +Note: this port uses ``ag.mesh.RectangularUniform`` + ``ag.reg.Constant`` (no +adapt images). The adapt-image variant (``RectangularAdaptImage`` + +``ag.reg.Adapt``) hits a post-unflatten Galaxy-identity mismatch in +``AdaptImages.galaxy_image_dict`` that the autogalaxy library does not yet +resolve across the JIT boundary — a separate library fix is required there. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=3.0, +) + +dataset = dataset.apply_mask(mask=mask) +dataset = dataset.apply_over_sampling(over_sample_size_lp=1) + +""" +__Model__ + +Single galaxy with a rectangular pixelization. No lens/source split, no mass +profile, no adapt images. +""" +mesh = ag.mesh.RectangularUniform(shape=(28, 28)) +regularization = ag.reg.Constant(coefficient=1.0) +pixelization = ag.Pixelization(mesh=mesh, regularization=regularization) + +galaxy = af.Model(ag.Galaxy, redshift=0.5, pixelization=pixelization) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + +print(model.info) + +analysis = ag.AnalysisImaging(dataset=dataset) + +""" +__vmap Path__ +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 3 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/simulator.py b/scripts/jax_likelihood_functions/imaging/simulator.py new file mode 100644 index 0000000..ff1fc1f --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/simulator.py @@ -0,0 +1,77 @@ +""" +Simulator: JAX Imaging Test Dataset +=================================== + +Simulates the `Imaging` dataset consumed by every script in +``scripts/jax_likelihood_functions/imaging/``. + +A single galaxy with a Sersic bulge + Exponential disk is imaged at HST-like +resolution and signal-to-noise. No lens / mass / source plane — this is a +single-plane autogalaxy dataset designed to exercise the JAX likelihood path on +parametric light profiles, MGE bases, and pixelization sources. + +Output files (under ``dataset/imaging/jax_test/``): + +- ``data.fits`` — the simulated noisy image +- ``psf.fits`` — the Gaussian PSF kernel used during simulation +- ``noise_map.fits`` — per-pixel 1-sigma noise map +- ``galaxies.json`` — the exact ``Galaxies`` used, for reproducibility +""" + +from pathlib import Path + +import autogalaxy as ag +import autogalaxy.plot as aplt + + +dataset_path = Path("dataset", "imaging", "jax_test") + +grid = ag.Grid2D.uniform(shape_native=(180, 180), pixel_scales=0.2) + +psf = ag.Convolver.from_gaussian( + shape_native=(21, 21), sigma=0.2, pixel_scales=grid.pixel_scales, normalize=True +) + +simulator = ag.SimulatorImaging( + exposure_time=2000.0, + psf=psf, + background_sky_level=1.0, + add_poisson_noise_to_data=True, + noise_seed=1, +) + +galaxy = ag.Galaxy( + redshift=0.5, + bulge=ag.lp.Sersic( + centre=(0.0, 0.0), + ell_comps=ag.convert.ell_comps_from(axis_ratio=0.9, angle=45.0), + intensity=4.0, + effective_radius=0.6, + sersic_index=3.0, + ), + disk=ag.lp.Exponential( + centre=(0.0, 0.0), + ell_comps=ag.convert.ell_comps_from(axis_ratio=0.7, angle=30.0), + intensity=2.0, + effective_radius=1.6, + ), +) + +galaxies = ag.Galaxies(galaxies=[galaxy]) + +dataset = simulator.via_galaxies_from(galaxies=galaxies, grid=grid) + +aplt.fits_imaging( + dataset=dataset, + data_path=dataset_path / "data.fits", + psf_path=dataset_path / "psf.fits", + noise_map_path=dataset_path / "noise_map.fits", + overwrite=True, +) + +ag.output_to_json( + obj=galaxies, + file_path=Path(dataset_path, "galaxies.json"), +) + +print("Dataset written to", dataset_path) diff --git a/smoke_tests.txt b/smoke_tests.txt index 8601af1..dc0464e 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -2,3 +2,7 @@ aggregator/galaxies.py aggregator/fit_imaging.py aggregator/fit_interferometer.py aggregator/ellipse.py +jax_likelihood_functions/imaging/lp.py +jax_likelihood_functions/imaging/mge.py +jax_likelihood_functions/imaging/mge_group.py +jax_likelihood_functions/imaging/rectangular.py