Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion scripts/ellipse/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def scaled_multipole_model(m: int, cos_amp: float, sin_amp: float, major_axis: f
sub_output.mkdir(parents=True)
sub_paths = SimpleNamespace(image_path=sub_path, output_path=sub_output)

analysis = ag.AnalysisEllipse(dataset=dataset, title_prefix=scenario_name.upper())
analysis = ag.AnalysisEllipse(
dataset=dataset, title_prefix=scenario_name.upper(), use_jax=False
)
instance = model.instance_from_prior_medians()

_t0 = time.perf_counter()
Expand Down
67 changes: 55 additions & 12 deletions scripts/jax_likelihood_functions/ellipse/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from os import path

import numpy as np

import autofit as af
import autogalaxy as ag

Expand Down Expand Up @@ -65,7 +67,7 @@
"""
__Analysis (NumPy Path)__
"""
analysis = ag.AnalysisEllipse(dataset=dataset) # use_jax defaults to False
analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False)

instance = model.instance_from_prior_medians()

Expand All @@ -89,18 +91,59 @@
print(f" total_figure_of_merit= {total_figure_of_merit:.8f}")

"""
__TODO(7_analysis_ellipse_jax.md)__
__vmap Path__

Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter
vectors. This exercises the full likelihood pipeline through JIT.
"""
import time
import jax
import jax.numpy as jnp
from autofit.non_linear.fitness import Fitness

batch_size = 50

fitness = Fitness(
model=model,
analysis=ag.AnalysisEllipse(dataset=dataset, use_jax=True),
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

Once `AnalysisEllipse` gains the `use_jax: bool = True` flag and a
`_register_fit_ellipse_pytrees()` helper, this script should additionally:
parameters = np.zeros((batch_size, model.total_free_parameters))
for i in range(batch_size):
parameters[i, :] = model.physical_values_from_prior_medians
parameters = jnp.array(parameters)

analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit_jit = fit_jit_fn(instance)
start = time.time()
result = fitness._vmap(parameters)
print(result)
print("JAX Time To VMAP + JIT Function:", time.time() - start)

start = time.time()
result = fitness._vmap(parameters)
print("JAX Time Taken using VMAP:", time.time() - start)
print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size)

np.testing.assert_allclose(
float(fit_jit.log_likelihood),
total_log_likelihood,
rtol=1e-4,
)
"""
__JIT fit_from round-trip__

Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed``
with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar.
"""
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(model)

analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit = fit_jit_fn(instance)

print("JIT fit.log_likelihood:", fit.log_likelihood)
assert isinstance(fit.log_likelihood, jnp.ndarray), \
f"expected jax.Array, got {type(fit.log_likelihood)}"
np.testing.assert_allclose(
float(fit.log_likelihood), total_log_likelihood, rtol=1e-4
)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
67 changes: 55 additions & 12 deletions scripts/jax_likelihood_functions/ellipse/multipoles.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from os import path

import numpy as np

import autofit as af
import autogalaxy as ag

Expand Down Expand Up @@ -77,7 +79,7 @@
"""
__Analysis (NumPy Path)__
"""
analysis = ag.AnalysisEllipse(dataset=dataset) # use_jax defaults to False
analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False)

instance = model.instance_from_prior_medians()

Expand All @@ -101,18 +103,59 @@
print(f" total_figure_of_merit= {total_figure_of_merit:.8f}")

"""
__TODO(7_analysis_ellipse_jax.md)__
__vmap Path__

Wrap the autofit ``Fitness`` in ``jax.vmap`` and evaluate a batch of parameter
vectors. This exercises the full likelihood pipeline through JIT.
"""
import time
import jax
import jax.numpy as jnp
from autofit.non_linear.fitness import Fitness

batch_size = 50

fitness = Fitness(
model=model,
analysis=ag.AnalysisEllipse(dataset=dataset, use_jax=True),
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

Once `AnalysisEllipse` gains the `use_jax: bool = True` flag and a
`_register_fit_ellipse_pytrees()` helper, this script should additionally:
parameters = np.zeros((batch_size, model.total_free_parameters))
for i in range(batch_size):
parameters[i, :] = model.physical_values_from_prior_medians
parameters = jnp.array(parameters)

analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit_jit = fit_jit_fn(instance)
start = time.time()
result = fitness._vmap(parameters)
print(result)
print("JAX Time To VMAP + JIT Function:", time.time() - start)

start = time.time()
result = fitness._vmap(parameters)
print("JAX Time Taken using VMAP:", time.time() - start)
print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size)

np.testing.assert_allclose(
float(fit_jit.log_likelihood),
total_log_likelihood,
rtol=1e-4,
)
"""
__JIT fit_from round-trip__

Assert that ``jax.jit(analysis.fit_from)(instance)`` returns a ``FitEllipseSummed``
with a ``jax.Array`` ``log_likelihood`` matching the NumPy-path scalar.
"""
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(model)

analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
fit_jit_fn = jax.jit(analysis_jit.fit_from)
fit = fit_jit_fn(instance)

print("JIT fit.log_likelihood:", fit.log_likelihood)
assert isinstance(fit.log_likelihood, jnp.ndarray), \
f"expected jax.Array, got {type(fit.log_likelihood)}"
np.testing.assert_allclose(
float(fit.log_likelihood), total_log_likelihood, rtol=1e-4
)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
Loading