Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ output/
*.log
.pytest_cache/
failed/
dataset/
Empty file.
Empty file.
129 changes: 129 additions & 0 deletions scripts/jax_likelihood_functions/imaging/lp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
JAX Likelihood: Parametric Light Profile
========================================

Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an
autogalaxy model composed of a linear Sersic bulge. Two paths are exercised:

1. ``fitness._vmap`` batch evaluation (tests ``jax.vmap`` + ``jax.jit`` on the
autofit ``Fitness`` wrapper).
2. ``jax.jit(analysis.fit_from)`` round-trip, which relies on the pytree
registration added to ``AnalysisImaging._register_fit_imaging_pytrees`` —
this path exercises the full ``FitImaging`` return value flattening.
"""

import time
from os import path

import jax
import jax.numpy as jnp
import numpy as np

import autofit as af
import autogalaxy as ag


dataset_path = path.join("dataset", "imaging", "jax_test")

if not path.exists(path.join(dataset_path, "data.fits")):
import subprocess
import sys

subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"],
check=True,
)

dataset = ag.Imaging.from_fits(
data_path=path.join(dataset_path, "data.fits"),
psf_path=path.join(dataset_path, "psf.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
pixel_scales=0.2,
)

mask = ag.Mask2D.circular(
shape_native=dataset.shape_native,
pixel_scales=dataset.pixel_scales,
radius=3.0,
)

dataset = dataset.apply_mask(mask=mask)
dataset = dataset.apply_over_sampling(over_sample_size_lp=1)

"""
__Model__

Single galaxy with a linear Sersic bulge — no lens/source split, no mass profile.
"""
bulge = af.Model(ag.lp.Sersic)

galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge)

model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

print(model.info)

analysis = ag.AnalysisImaging(dataset=dataset)

"""
__vmap Path__

Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter
vectors. This tests that the full likelihood pipeline JIT-compiles end to end.
"""
from autofit.non_linear.fitness import Fitness

batch_size = 50

fitness = Fitness(
model=model,
analysis=analysis,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

parameters = np.zeros((batch_size, model.total_free_parameters))
for i in range(batch_size):
parameters[i, :] = model.physical_values_from_prior_medians
parameters = jnp.array(parameters)

start = time.time()
result = fitness._vmap(parameters)
print(result)
print("JAX Time To VMAP + JIT Function:", time.time() - start)

start = time.time()
result = fitness._vmap(parameters)
print("JAX Time Taken using VMAP:", time.time() - start)
print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size)

"""
__Path A: jit-wrap ``analysis.fit_from``__

Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging``
with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. This
is the part unblocked by ``_register_fit_imaging_pytrees``.
"""
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(model)

instance = model.instance_from_prior_medians()

analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False)
fit_np = analysis_np.fit_from(instance=instance)
print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood))

analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit = fit_jit_fn(instance)

print("JIT fit.log_likelihood:", fit.log_likelihood)
assert isinstance(fit.log_likelihood, jnp.ndarray), (
f"expected jax.Array, got {type(fit.log_likelihood)}"
)
np.testing.assert_allclose(
float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4
)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
142 changes: 142 additions & 0 deletions scripts/jax_likelihood_functions/imaging/mge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
JAX Likelihood: MGE Basis Light Profile
========================================

Verify that JAX can compute the log-likelihood of an ``Imaging`` fit for an
autogalaxy model composed of a Multi-Gaussian Expansion (MGE) linear basis.
Two paths are exercised:

1. ``fitness._vmap`` batch evaluation (tests ``jax.vmap`` + ``jax.jit`` on the
autofit ``Fitness`` wrapper).
2. ``jax.jit(analysis.fit_from)`` round-trip, which relies on the pytree
registration added to ``AnalysisImaging._register_fit_imaging_pytrees`` —
this path exercises the full ``FitImaging`` return value flattening.
"""

import time
from os import path

import jax
import jax.numpy as jnp
import numpy as np

import autofit as af
import autogalaxy as ag


dataset_path = path.join("dataset", "imaging", "jax_test")

if not path.exists(path.join(dataset_path, "data.fits")):
import subprocess
import sys

subprocess.run(
[sys.executable, "scripts/jax_likelihood_functions/imaging/simulator.py"],
check=True,
)

dataset = ag.Imaging.from_fits(
data_path=path.join(dataset_path, "data.fits"),
psf_path=path.join(dataset_path, "psf.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
pixel_scales=0.2,
)

mask_radius = 3.0

mask = ag.Mask2D.circular(
shape_native=dataset.shape_native,
pixel_scales=dataset.pixel_scales,
radius=mask_radius,
)

dataset = dataset.apply_mask(mask=mask)

over_sample_size = ag.util.over_sample.over_sample_size_via_radial_bins_from(
grid=dataset.grid,
sub_size_list=[4, 2, 1],
radial_list=[0.3, 0.6],
centre_list=[(0.0, 0.0)],
)

dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size)

"""
__Model__

Single galaxy with an MGE linear basis light profile — no lens/source split,
no mass profile.
"""
bulge = ag.model_util.mge_model_from(
mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True
)

galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge)

model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

print(model.info)

analysis = ag.AnalysisImaging(dataset=dataset)

"""
__vmap Path__

Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter
vectors. This tests that the full likelihood pipeline JIT-compiles end to end.
"""
from autofit.non_linear.fitness import Fitness

batch_size = 50

fitness = Fitness(
model=model,
analysis=analysis,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

parameters = np.zeros((batch_size, model.total_free_parameters))
for i in range(batch_size):
parameters[i, :] = model.physical_values_from_prior_medians
parameters = jnp.array(parameters)

start = time.time()
result = fitness._vmap(parameters)
print(result)
print("JAX Time To VMAP + JIT Function:", time.time() - start)

start = time.time()
result = fitness._vmap(parameters)
print("JAX Time Taken using VMAP:", time.time() - start)
print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size)

"""
__Path A: jit-wrap ``analysis.fit_from``__

Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitImaging``
with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar.
"""
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(model)

instance = model.instance_from_prior_medians()

analysis_np = ag.AnalysisImaging(dataset=dataset, use_jax=False)
fit_np = analysis_np.fit_from(instance=instance)
print("NumPy fit.log_likelihood:", float(fit_np.log_likelihood))

analysis_jit = ag.AnalysisImaging(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit = fit_jit_fn(instance)

print("JIT fit.log_likelihood:", fit.log_likelihood)
assert isinstance(fit.log_likelihood, jnp.ndarray), (
f"expected jax.Array, got {type(fit.log_likelihood)}"
)
np.testing.assert_allclose(
float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4
)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
Loading
Loading