diff --git a/scripts/jax_likelihood_functions/imaging/delaunay.py b/scripts/jax_likelihood_functions/imaging/delaunay.py new file mode 100644 index 0000000..945c03c --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/delaunay.py @@ -0,0 +1,187 @@ +""" +JAX Likelihood: Delaunay Adapt-Image Pixelization +================================================== + +Single-galaxy autogalaxy model using a ``Delaunay`` mesh with a Hilbert +image-mesh (which seeds source-pixel centres in the image plane via an adapt +image) and ``AdaptSplit`` regularization. + +This exercises the second post-unflatten lookup site — +``GalaxiesToInversion.image_plane_mesh_grid_list`` — which previously fell +back to the single-mesh-grid value when the by-instance lookup missed. + +Two paths are exercised: + +1. ``fitness._vmap`` batch evaluation. +2. ``jax.jit(analysis.fit_from)`` scalar round-trip — relies on + ``AnalysisImaging._register_fit_imaging_pytrees`` and on + ``AdaptImages.image_plane_mesh_grid_for_galaxy`` resolving fresh-Galaxy + lookups via the path-tuple list. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) +dataset = dataset.apply_over_sampling( + over_sample_size_lp=4, + over_sample_size_pixelization=4, +) + +""" +__JAX & Preloads__ + +JAX requires static-shaped arrays. ``pixels`` and ``edge_pixels_total`` fix the +total source-pixel count up front. The image-plane mesh grid is built in +NumPy via the Hilbert image-mesh and circle-edge augmentation, then passed +in via ``galaxy_name_image_plane_mesh_grid_dict``. +""" +pixels = 750 +edge_pixels_total = 30 + +galaxy_name_image_dict = { + "('galaxies', 'galaxy')": dataset.data, +} + +image_mesh = ag.image_mesh.Hilbert(pixels=pixels, weight_power=3.5, weight_floor=0.01) + +image_plane_mesh_grid = image_mesh.image_plane_mesh_grid_from( + mask=dataset.mask, adapt_data=galaxy_name_image_dict["('galaxies', 'galaxy')"] +) + +image_plane_mesh_grid = ag.image_mesh.append_with_circle_edge_points( + image_plane_mesh_grid=image_plane_mesh_grid, + centre=mask.mask_centre, + radius=mask_radius + mask.pixel_scale / 2.0, + n_points=edge_pixels_total, +) + +total_mapper_pixels = image_plane_mesh_grid.shape[0] + +adapt_images = ag.AdaptImages( + galaxy_name_image_dict=galaxy_name_image_dict, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'galaxy')": image_plane_mesh_grid + }, +) + +""" +__Model__ + +Single galaxy with a Delaunay pixelization seeded by the Hilbert image-mesh. +""" +pixelization = af.Model( + ag.Pixelization, + mesh=ag.mesh.Delaunay(pixels=pixels, zeroed_pixels=edge_pixels_total), + regularization=ag.reg.AdaptSplit, +) + +galaxy = af.Model(ag.Galaxy, redshift=0.5, pixelization=pixelization) + +model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) + +print(model.info) + +settings = ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, +) + +analysis = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings +) + +""" +__vmap Path__ +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 3 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=False +) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=True +) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-2 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/delaunay_mge.py b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py new file mode 100644 index 0000000..d325344 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/delaunay_mge.py @@ -0,0 +1,206 @@ +""" +JAX Likelihood: Delaunay + MGE Bulge +===================================== + +Two-galaxy autogalaxy model: a foreground galaxy with an MGE bulge and a +second galaxy with a ``Delaunay`` mesh + ``MaternAdaptKernel`` regularization +seeded by a Hilbert image-mesh. + +Disabled in ``smoke_tests.txt`` to mirror the autolens ``delaunay_mge.py`` +deferral for the JAX 0.7 regression +(``jax.interpreters.xla.pytype_aval_mappings`` removed). The script runs +locally under current JAX — re-enable once the autolens equivalent is +re-enabled, or sooner if the autogalaxy path is confirmed unaffected. + +Two paths are exercised when run directly: + +1. ``fitness._vmap`` batch evaluation. +2. ``jax.jit(analysis.fit_from)`` scalar round-trip. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +over_sample_size = ag.util.over_sample.over_sample_size_via_radial_bins_from( + grid=dataset.grid, + sub_size_list=[4, 2, 1], + radial_list=[0.3, 0.6], + centre_list=[(0.0, 0.0)], +) + +dataset = dataset.apply_over_sampling( + over_sample_size_lp=over_sample_size, + over_sample_size_pixelization=4, +) + +""" +__JAX & Preloads__ + +JAX requires static-shaped arrays. ``pixels`` and ``edge_pixels_total`` fix the +total source-pixel count up front. The image-plane mesh grid for ``galaxy_1`` +is built in NumPy via the Hilbert image-mesh. +""" +pixels = 750 +edge_pixels_total = 30 + +galaxy_name_image_dict = { + "('galaxies', 'galaxy_0')": dataset.data, + "('galaxies', 'galaxy_1')": dataset.data, +} + +image_mesh = ag.image_mesh.Hilbert(pixels=pixels, weight_power=3.5, weight_floor=0.01) + +image_plane_mesh_grid = image_mesh.image_plane_mesh_grid_from( + mask=dataset.mask, + adapt_data=galaxy_name_image_dict["('galaxies', 'galaxy_1')"], +) + +image_plane_mesh_grid = ag.image_mesh.append_with_circle_edge_points( + image_plane_mesh_grid=image_plane_mesh_grid, + centre=mask.mask_centre, + radius=mask_radius + mask.pixel_scale / 2.0, + n_points=edge_pixels_total, +) + +total_mapper_pixels = image_plane_mesh_grid.shape[0] + +adapt_images = ag.AdaptImages( + galaxy_name_image_dict=galaxy_name_image_dict, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'galaxy_1')": image_plane_mesh_grid + }, +) + +""" +__Model__ + +galaxy_0: MGE bulge. +galaxy_1: Delaunay pixelization with MaternAdaptKernel regularization. +""" +bulge = ag.model_util.mge_model_from( + mask_radius=mask_radius, + total_gaussians=20, + centre_prior_is_uniform=True, +) + +galaxy_0 = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +pixelization = af.Model( + ag.Pixelization, + mesh=ag.mesh.Delaunay(pixels=pixels, zeroed_pixels=edge_pixels_total), + regularization=ag.reg.MaternAdaptKernel, +) + +galaxy_1 = af.Model(ag.Galaxy, redshift=0.5, pixelization=pixelization) + +model = af.Collection( + galaxies=af.Collection(galaxy_0=galaxy_0, galaxy_1=galaxy_1) +) + +print(model.info) + +settings = ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, +) + +analysis = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings +) + +""" +__vmap Path__ +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 3 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=False +) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=True +) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-2 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/rectangular.py b/scripts/jax_likelihood_functions/imaging/rectangular.py index 64879b2..2d49c3e 100644 --- a/scripts/jax_likelihood_functions/imaging/rectangular.py +++ b/scripts/jax_likelihood_functions/imaging/rectangular.py @@ -1,20 +1,18 @@ """ -JAX Likelihood: Rectangular Pixelization -========================================= +JAX Likelihood: Rectangular Adapt-Image Pixelization +===================================================== Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an -autogalaxy model using a non-adapt rectangular pixelization mesh. Two paths -are exercised: +autogalaxy model that uses an adapt-image rectangular pixelization +(``RectangularAdaptImage`` + ``Adapt`` regularization). + +Two paths are exercised: 1. ``fitness._vmap`` batch evaluation. 2. ``jax.jit(analysis.fit_from)`` scalar round-trip — relies on - ``AnalysisImaging._register_fit_imaging_pytrees``. - -Note: this port uses ``ag.mesh.RectangularUniform`` + ``ag.reg.Constant`` (no -adapt images). The adapt-image variant (``RectangularAdaptImage`` + -``ag.reg.Adapt``) hits a post-unflatten Galaxy-identity mismatch in -``AdaptImages.galaxy_image_dict`` that the autogalaxy library does not yet -resolve across the JIT boundary — a separate library fix is required there. + ``AnalysisImaging._register_fit_imaging_pytrees`` and on + ``AdaptImages.image_for_galaxy`` resolving fresh-Galaxy lookups via the + path-tuple list across the JIT boundary. """ import time @@ -53,16 +51,32 @@ ) dataset = dataset.apply_mask(mask=mask) -dataset = dataset.apply_over_sampling(over_sample_size_lp=1) +dataset = dataset.apply_over_sampling( + over_sample_size_lp=4, + over_sample_size_pixelization=4, +) + +""" +__Adapt Images__ + +The galaxy is named ``galaxy`` in the model, so the path tuple is +``('galaxies', 'galaxy')``. ``dataset.data`` is used as a stand-in for the +"previous-fit" galaxy image — sufficient to exercise the adapt-image code paths. +""" +galaxy_name_image_dict = { + "('galaxies', 'galaxy')": dataset.data, +} + +adapt_images = ag.AdaptImages(galaxy_name_image_dict=galaxy_name_image_dict) """ __Model__ -Single galaxy with a rectangular pixelization. No lens/source split, no mass -profile, no adapt images. +Single galaxy with an adapt-image rectangular pixelization. The mesh shape is +fixed (28 x 28) per the JAX static-shape requirement. """ -mesh = ag.mesh.RectangularUniform(shape=(28, 28)) -regularization = ag.reg.Constant(coefficient=1.0) +mesh = ag.mesh.RectangularAdaptImage(shape=(28, 28), weight_power=1.0) +regularization = ag.reg.Adapt() pixelization = ag.Pixelization(mesh=mesh, regularization=regularization) galaxy = af.Model(ag.Galaxy, redshift=0.5, pixelization=pixelization) @@ -71,7 +85,15 @@ print(model.info) -analysis = ag.AnalysisImaging(dataset=dataset) +analysis = ag.AnalysisImaging( + dataset=dataset, + adapt_images=adapt_images, + settings=ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, + ), +) """ __vmap Path__ @@ -112,11 +134,29 @@ instance = model.instance_from_prior_medians() -analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False) +analysis_np = ag.AnalysisImaging( + dataset=dataset, + adapt_images=adapt_images, + settings=ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, + ), + use_jax=False, +) fit_np = analysis_np.fit_from(instance=instance) print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) -analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True) +analysis_jit = ag.AnalysisImaging( + dataset=dataset, + adapt_images=adapt_images, + settings=ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, + ), + use_jax=True, +) fit_jit_fn = jax.jit(analysis_jit.fit_from) fit = fit_jit_fn(instance) @@ -125,6 +165,6 @@ f"expected jax.Array, got {type(fit.log_likelihood)}" ) np.testing.assert_allclose( - float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-2 ) print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/imaging/rectangular_mge.py b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py new file mode 100644 index 0000000..0ec9dd8 --- /dev/null +++ b/scripts/jax_likelihood_functions/imaging/rectangular_mge.py @@ -0,0 +1,183 @@ +""" +JAX Likelihood: Rectangular Adapt-Image Pixelization + MGE Bulge +================================================================= + +Two-galaxy autogalaxy model: a foreground galaxy with an MGE bulge and a +second galaxy with an adapt-image rectangular pixelization +(``RectangularAdaptImage`` + ``Constant`` regularization). + +This is the multi-pixelization regression case the path-tuple library fix +was made for: prior to the fix the autolens fallback would silently return +the wrong adapt image when more than one galaxy is present. + +Two paths are exercised: + +1. ``fitness._vmap`` batch evaluation. +2. ``jax.jit(analysis.fit_from)`` scalar round-trip — relies on + ``AnalysisImaging._register_fit_imaging_pytrees`` and on + ``AdaptImages.image_for_galaxy`` resolving fresh-Galaxy lookups via the + path-tuple list. +""" + +import time +from os import path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autogalaxy as ag + + +dataset_path = path.join("dataset", "imaging", "jax_test") + +if not path.exists(path.join(dataset_path, "data.fits")): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"], + check=True, + ) + +dataset = ag.Imaging.from_fits( + data_path=path.join(dataset_path, "data.fits"), + psf_path=path.join(dataset_path, "psf.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + pixel_scales=0.2, +) + +mask_radius = 3.0 + +mask = ag.Mask2D.circular( + shape_native=dataset.shape_native, + pixel_scales=dataset.pixel_scales, + radius=mask_radius, +) + +dataset = dataset.apply_mask(mask=mask) + +over_sample_size = ag.util.over_sample.over_sample_size_via_radial_bins_from( + grid=dataset.grid, + sub_size_list=[4, 2, 1], + radial_list=[0.3, 0.6], + centre_list=[(0.0, 0.0)], +) + +dataset = dataset.apply_over_sampling( + over_sample_size_lp=over_sample_size, + over_sample_size_pixelization=4, +) + +""" +__Adapt Images__ + +The model has two galaxies named ``galaxy_0`` (MGE bulge) and ``galaxy_1`` +(pixelization). ``galaxy_1`` is the only one that needs an adapt image, but +``galaxy_0`` is included in the dict to keep the path list aligned with all +galaxies in the analysis. +""" +galaxy_name_image_dict = { + "('galaxies', 'galaxy_0')": dataset.data, + "('galaxies', 'galaxy_1')": dataset.data, +} + +adapt_images = ag.AdaptImages(galaxy_name_image_dict=galaxy_name_image_dict) + +""" +__Model__ + +galaxy_0: MGE bulge — provides linear light profiles. +galaxy_1: rectangular adapt-image pixelization — exercises the adapt-image +inversion path that was previously broken across the JIT boundary. +""" +bulge = ag.model_util.mge_model_from( + mask_radius=mask_radius, + total_gaussians=20, + centre_prior_is_uniform=True, +) + +galaxy_0 = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge) + +mesh = ag.mesh.RectangularAdaptImage(shape=(28, 28)) +regularization = ag.reg.Constant(coefficient=1.0) +pixelization = ag.Pixelization(mesh=mesh, regularization=regularization) + +galaxy_1 = af.Model(ag.Galaxy, redshift=0.5, pixelization=pixelization) + +model = af.Collection( + galaxies=af.Collection(galaxy_0=galaxy_0, galaxy_1=galaxy_1) +) + +print(model.info) + +settings = ag.Settings( + use_border_relocator=True, + use_positive_only_solver=True, + use_mixed_precision=True, +) + +analysis = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings +) + +""" +__vmap Path__ +""" +from autofit.non_linear.fitness import Fitness + +batch_size = 3 + +fitness = Fitness( + model=model, + analysis=analysis, + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) + +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) + +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) + +""" +__Path A: jit-wrap ``analysis.fit_from``__ +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +instance = model.instance_from_prior_medians() + +analysis_np = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=False +) +fit_np = analysis_np.fit_from(instance=instance) +print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood)) + +analysis_jit = ag.AnalysisImaging( + dataset=dataset, adapt_images=adapt_images, settings=settings, use_jax=True +) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), ( + f"expected jax.Array, got {type(fit.log_likelihood)}" +) +np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-2 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/smoke_tests.txt b/smoke_tests.txt index 6306d80..4326ac0 100644 --- a/smoke_tests.txt +++ b/smoke_tests.txt @@ -6,5 +6,8 @@ jax_likelihood_functions/imaging/lp.py jax_likelihood_functions/imaging/mge.py jax_likelihood_functions/imaging/mge_group.py jax_likelihood_functions/imaging/rectangular.py +jax_likelihood_functions/imaging/rectangular_mge.py +jax_likelihood_functions/imaging/delaunay.py +# jax_likelihood_functions/imaging/delaunay_mge.py # disabled: jax 0.7 removed jax.interpreters.xla.pytype_aval_mappings — see admin_jammy/prompt/build/smoke_workspace_fixes.md imaging/model_fit.py imaging/visualization.py