Skip to content

Commit ab6d5d3

Browse files
authored
Merge pull request #1228 from PyAutoLabs/feature/jax-visualization
feat: add use_jax_for_visualization flag and fit_for_visualization dispatch
2 parents afaf6e6 + f54ff92 commit ab6d5d3

2 files changed

Lines changed: 116 additions & 3 deletions

File tree

autofit/non_linear/analysis/analysis.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,48 @@ class Analysis(ABC):
3333
LATENT_KEYS = []
3434

3535
def __init__(
36-
self, use_jax : bool = False, **kwargs
36+
self,
37+
use_jax: bool = False,
38+
use_jax_for_visualization: bool = False,
39+
**kwargs,
3740
):
3841
import os
3942
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
4043
use_jax = False
44+
use_jax_for_visualization = False
45+
46+
if use_jax_for_visualization and not use_jax:
47+
logger.warning(
48+
"use_jax_for_visualization=True requires use_jax=True; "
49+
"disabling use_jax_for_visualization."
50+
)
51+
use_jax_for_visualization = False
4152

4253
self._use_jax = use_jax
54+
self._use_jax_for_visualization = use_jax_for_visualization
4355
self.kwargs = kwargs
4456

57+
def fit_for_visualization(self, instance):
58+
"""
59+
Build the fit used by the visualizer.
60+
61+
Currently a thin dispatch over ``self.fit_from``: when ``use_jax=True``
62+
the fit is built on the eager JAX path (``self._xp is jnp``) and the
63+
plotter materialises arrays to NumPy at the matplotlib boundary. The
64+
``use_jax_for_visualization`` flag is an explicit opt-in today — it
65+
marks the intent to use JAX for visualization, and is the dispatch
66+
point where full ``jax.jit``-wrapping will plug in once
67+
:class:`autolens.imaging.fit_imaging.FitImaging` and its nested
68+
autoarray types are registered as JAX pytrees (that work is tracked
69+
separately — see ``admin_jammy/prompt/autolens/fit_imaging_pytree.md``).
70+
71+
``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``),
72+
not the base class — this method is only callable on subclasses that
73+
provide it. Downstream visualizers should prefer this over calling
74+
``fit_from`` directly so the JIT seam stays in one place.
75+
"""
76+
return self.fit_from(instance=instance)
77+
4578
def __getattr__(self, item: str):
4679
"""
4780
If a method starts with 'visualize_' then we assume it is associated with
@@ -306,8 +339,15 @@ def supports_background_update(self) -> bool:
306339

307340
@property
308341
def supports_jax_visualization(self) -> bool:
309-
"""Whether the visualizer can work directly with JAX arrays."""
310-
return False
342+
"""
343+
Whether the visualizer can work directly with JAX arrays.
344+
345+
Derived from the ``use_jax_for_visualization`` flag passed at
346+
construction time. Subclasses may override to force a specific
347+
answer (e.g. an Analysis that has been audited to support JAX
348+
visualization unconditionally).
349+
"""
350+
return self._use_jax_for_visualization
311351

312352
def perform_quick_update(self, paths, instance):
313353
raise NotImplementedError
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Tests for the ``use_jax_for_visualization`` flag on ``Analysis``."""
2+
3+
import pytest
4+
5+
import autofit as af
6+
7+
8+
class _FittableAnalysis(af.Analysis):
9+
"""Minimal Analysis subclass with a trivial ``fit_from`` for dispatch tests."""
10+
11+
def __init__(self, **kwargs):
12+
super().__init__(**kwargs)
13+
self.fit_from_calls = 0
14+
15+
def log_likelihood_function(self, instance):
16+
return 0.0
17+
18+
def fit_from(self, instance):
19+
self.fit_from_calls += 1
20+
return ("fit", instance)
21+
22+
23+
def test_default_flag_is_false():
24+
analysis = af.Analysis()
25+
assert analysis._use_jax is False
26+
assert analysis._use_jax_for_visualization is False
27+
assert analysis.supports_jax_visualization is False
28+
29+
30+
def test_flag_requires_use_jax(caplog):
31+
with caplog.at_level("WARNING"):
32+
analysis = af.Analysis(use_jax=False, use_jax_for_visualization=True)
33+
assert analysis._use_jax_for_visualization is False
34+
assert any("requires use_jax=True" in r.message for r in caplog.records)
35+
36+
37+
def test_flag_accepted_when_use_jax_true():
38+
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
39+
assert analysis._use_jax is True
40+
assert analysis._use_jax_for_visualization is True
41+
assert analysis.supports_jax_visualization is True
42+
43+
44+
def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
45+
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
46+
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
47+
assert analysis._use_jax is False
48+
assert analysis._use_jax_for_visualization is False
49+
50+
51+
def test_fit_for_visualization_dispatches_to_fit_from():
52+
analysis = _FittableAnalysis(use_jax=True, use_jax_for_visualization=True)
53+
result = analysis.fit_for_visualization(instance="sentinel")
54+
assert result == ("fit", "sentinel")
55+
assert analysis.fit_from_calls == 1
56+
57+
58+
def test_fit_for_visualization_works_without_flag():
59+
analysis = _FittableAnalysis()
60+
result = analysis.fit_for_visualization(instance="sentinel")
61+
assert result == ("fit", "sentinel")
62+
assert analysis.fit_from_calls == 1
63+
64+
65+
def test_subclass_can_override_supports_jax_visualization():
66+
class ForcedAnalysis(af.Analysis):
67+
@property
68+
def supports_jax_visualization(self):
69+
return True
70+
71+
analysis = ForcedAnalysis()
72+
assert analysis._use_jax_for_visualization is False
73+
assert analysis.supports_jax_visualization is True

0 commit comments

Comments
 (0)