diff --git a/scripts/ellipse/visualization.py b/scripts/ellipse/visualization.py index a409cf2..66ee703 100644 --- a/scripts/ellipse/visualization.py +++ b/scripts/ellipse/visualization.py @@ -230,7 +230,9 @@ def scaled_multipole_model(m: int, cos_amp: float, sin_amp: float, major_axis: f sub_output.mkdir(parents=True) sub_paths = SimpleNamespace(image_path=sub_path, output_path=sub_output) - analysis = ag.AnalysisEllipse(dataset=dataset, title_prefix=scenario_name.upper()) + analysis = ag.AnalysisEllipse( + dataset=dataset, title_prefix=scenario_name.upper(), use_jax=False + ) instance = model.instance_from_prior_medians() _t0 = time.perf_counter() diff --git a/scripts/jax_likelihood_functions/ellipse/fit.py b/scripts/jax_likelihood_functions/ellipse/fit.py index 2bbd6c8..5f7cc08 100644 --- a/scripts/jax_likelihood_functions/ellipse/fit.py +++ b/scripts/jax_likelihood_functions/ellipse/fit.py @@ -12,6 +12,8 @@ from os import path +import numpy as np + import autofit as af import autogalaxy as ag @@ -65,7 +67,7 @@ """ __Analysis (NumPy Path)__ """ -analysis = ag.AnalysisEllipse(dataset=dataset) # use_jax defaults to False +analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False) instance = model.instance_from_prior_medians() @@ -89,18 +91,59 @@ print(f" total_figure_of_merit= {total_figure_of_merit:.8f}") """ -__TODO(7_analysis_ellipse_jax.md)__ +__vmap Path__ + +Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter +vectors. This exercises the full likelihood pipeline through JIT. +""" +import time +import jax +import jax.numpy as jnp +from autofit.non_linear.fitness import Fitness + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=ag.AnalysisEllipse(dataset=dataset, use_jax=True), + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) -Once `AnalysisEllipse` gains the `use_jax: bool = True` flag and a -`_register_fit_ellipse_pytrees()` helper, this script should additionally: +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) - analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) - fit_jit_fn = jax.jit(analysis_jit.fit_from) - fit_jit = fit_jit_fn(instance) +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) - np.testing.assert_allclose( - float(fit_jit.log_likelihood), - total_log_likelihood, - rtol=1e-4, - ) """ +__JIT fit_from round-trip__ + +Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed`` +with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), \ + f"expected jax.Array, got {type(fit.log_likelihood)}" +np.testing.assert_allclose( + float(fit.log_likelihood), total_log_likelihood, rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.") diff --git a/scripts/jax_likelihood_functions/ellipse/multipoles.py b/scripts/jax_likelihood_functions/ellipse/multipoles.py index e4a8e9d..39e6de2 100644 --- a/scripts/jax_likelihood_functions/ellipse/multipoles.py +++ b/scripts/jax_likelihood_functions/ellipse/multipoles.py @@ -16,6 +16,8 @@ from os import path +import numpy as np + import autofit as af import autogalaxy as ag @@ -77,7 +79,7 @@ """ __Analysis (NumPy Path)__ """ -analysis = ag.AnalysisEllipse(dataset=dataset) # use_jax defaults to False +analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False) instance = model.instance_from_prior_medians() @@ -101,18 +103,59 @@ print(f" total_figure_of_merit= {total_figure_of_merit:.8f}") """ -__TODO(7_analysis_ellipse_jax.md)__ +__vmap Path__ + +Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter +vectors. This exercises the full likelihood pipeline through JIT. +""" +import time +import jax +import jax.numpy as jnp +from autofit.non_linear.fitness import Fitness + +batch_size = 50 + +fitness = Fitness( + model=model, + analysis=ag.AnalysisEllipse(dataset=dataset, use_jax=True), + fom_is_log_likelihood=True, + resample_figure_of_merit=-1.0e99, +) -Once `AnalysisEllipse` gains the `use_jax: bool = True` flag and a -`_register_fit_ellipse_pytrees()` helper, this script should additionally: +parameters = np.zeros((batch_size, model.total_free_parameters)) +for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians +parameters = jnp.array(parameters) - analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) - fit_jit_fn = jax.jit(analysis_jit.fit_from) - fit_jit = fit_jit_fn(instance) +start = time.time() +result = fitness._vmap(parameters) +print(result) +print("JAX Time To VMAP + JIT Function:", time.time() - start) + +start = time.time() +result = fitness._vmap(parameters) +print("JAX Time Taken using VMAP:", time.time() - start) +print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size) - np.testing.assert_allclose( - float(fit_jit.log_likelihood), - total_log_likelihood, - rtol=1e-4, - ) """ +__JIT fit_from round-trip__ + +Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed`` +with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar. +""" +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() +register_model(model) + +analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True) +fit_jit_fn = jax.jit(analysis_jit.fit_from) +fit = fit_jit_fn(instance) + +print("JIT fit.log_likelihood:", fit.log_likelihood) +assert isinstance(fit.log_likelihood, jnp.ndarray), \ + f"expected jax.Array, got {type(fit.log_likelihood)}" +np.testing.assert_allclose( + float(fit.log_likelihood), total_log_likelihood, rtol=1e-4 +) +print("PASS: jit(fit_from) round-trip matches NumPy scalar.")