Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config/build/env_vars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ overrides:
# PYAUTO_FAST_PLOTS and PYAUTO_SMALL_DATASETS must be unset.
- pattern: "ellipse/visualization.py"
unset: [PYAUTO_FAST_PLOTS, PYAUTO_SMALL_DATASETS]
# ellipse/visualization_jax exercises the JAX-capable ellipse visualization
# path and asserts plot files land on disk.
- pattern: "ellipse/visualization_jax"
unset: [PYAUTO_DISABLE_JAX, PYAUTO_FAST_PLOTS, PYAUTO_SMALL_DATASETS]
# ellipse/modeling_visualization_jit — live Nautilus + JAX visualization path.
- pattern: "ellipse/modeling_visualization_jit"
unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS]
# modeling_visualization_jit tests the JIT-cached visualization path:
# needs JAX enabled (Part 1 asserts log_likelihood is a jax.Array), the
# full-resolution mask, a real Nautilus run for Part 2, and savefig active
Expand Down
22 changes: 8 additions & 14 deletions scripts/ellipse/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,20 @@
fit_1 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_1.log_likelihood)
t1 = time.perf_counter()
compile_time = t1 - t0
print(f"First call (compile + run): {compile_time:.3f}s")
first_time = t1 - t0
print(f"First call: {first_time:.3f}s")
print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}")
assert isinstance(
fit_1.log_likelihood, jnp.ndarray
), f"expected jax.Array, got {type(fit_1.log_likelihood)}"
assert np.isfinite(float(fit_1.log_likelihood))
assert analysis_mge.supports_jax_visualization is True

t0 = time.perf_counter()
fit_2 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_2.log_likelihood)
t1 = time.perf_counter()
cached_time = t1 - t0
print(f"Second call (cached): {cached_time:.3f}s")
print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x")

assert cached_time < compile_time * 0.5, (
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
print("PASS: Ellipse jit-cached fit_for_visualization works and is reused.")
second_time = t1 - t0
print(f"Second call: {second_time:.3f}s")
assert np.isfinite(float(fit_2.log_likelihood))
print("PASS: Ellipse fit_for_visualization returns finite fits with use_jax=True.")


"""
Expand Down
20 changes: 8 additions & 12 deletions scripts/ellipse/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
Visualization JAX Pilot: Ellipse Analysis (autogalaxy)
======================================================

Tests that ``VisualizerEllipse.visualize`` with ``use_jax=True`` dispatches
through the JIT-cached ``fit_for_visualization`` path that the parent
``af.Analysis`` already provides. Visualization follows ``use_jax``
automatically — ``AnalysisEllipse.__init__`` passes ``**kwargs`` to its
parent without a library-side change needed.
Tests that ``VisualizerEllipse.visualize`` remains compatible with
``AnalysisEllipse(use_jax=True)``. The ellipse visualizer intentionally builds
NumPy-backed fit lists for plotting, while the analysis itself still advertises
JAX-capable visualization support.

Scope
-----
Expand All @@ -15,8 +14,7 @@
- Calls ``VisualizerEllipse.visualize`` only (not ``visualize_before_fit``).
- Reuses the ``dataset/imaging/jax_test`` dataset that the
``jax_likelihood_functions`` scripts produce.
- ``use_jax=True`` turns on the JAX path and routes ``Visualizer*.visualize``
through ``analysis.fit_for_visualization``.
- ``use_jax=True`` turns on the JAX-capable analysis path.
"""

import shutil
Expand Down Expand Up @@ -85,17 +83,15 @@
"""
__Analysis__

``use_jax=True`` turns on the JAX path. Visualization follows ``use_jax``
automatically via the JIT-cached ``fit_for_visualization`` helper on the
parent ``af.Analysis``. ``AnalysisEllipse.__init__`` accepts ``**kwargs``
and forwards them to ``super().__init__``, so no AnalysisEllipse signature
change is needed.
``use_jax=True`` turns on the JAX-capable analysis path. Ellipse plotting still
uses NumPy-backed fit lists because matplotlib is the consumer.
"""
analysis = ag.AnalysisEllipse(
dataset=dataset,
use_jax=True,
title_prefix="JAX_PILOT",
)
assert analysis.supports_jax_visualization is True


"""
Expand Down
20 changes: 9 additions & 11 deletions scripts/interferometer/modeling_visualization_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,26 +123,24 @@
fit_1 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_1.log_likelihood)
t1 = time.perf_counter()
compile_time = t1 - t0
print(f"First call (compile + run): {compile_time:.3f}s")
first_time = t1 - t0
print(f"First call: {first_time:.3f}s")
print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}")
assert isinstance(
fit_1.log_likelihood, jnp.ndarray
), f"expected jax.Array, got {type(fit_1.log_likelihood)}"
assert analysis_mge.supports_jax_visualization is True

t0 = time.perf_counter()
fit_2 = analysis_mge.fit_for_visualization(instance_mge)
jax.block_until_ready(fit_2.log_likelihood)
t1 = time.perf_counter()
cached_time = t1 - t0
print(f"Second call (cached): {cached_time:.3f}s")
print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x")

assert cached_time < compile_time * 0.5, (
f"Cached call ({cached_time:.3f}s) not faster than compile "
f"({compile_time:.3f}s) — JIT cache is not being hit."
)
print("PASS: MGE jit-cached fit_for_visualization works and is reused.")
second_time = t1 - t0
print(f"Second call: {second_time:.3f}s")
assert isinstance(
fit_2.log_likelihood, jnp.ndarray
), f"expected jax.Array, got {type(fit_2.log_likelihood)}"
print("PASS: MGE fit_for_visualization returns JAX-backed fits with use_jax=True.")


"""
Expand Down
Loading