Skip to content

Commit f2193a8

Browse files
Jammy2211Jammy2211
authored andcommitted
jax: gated pytree registration for AbstractNDArray + register_instance_pytree helper
Reintroduces the removed auto-registration of AbstractNDArray subclasses as JAX pytrees, gated so it only runs when an instance is constructed on the JAX path (xp is not np). Each subclass pays the registration cost at most once via a module-level sentinel. Also adds register_instance_pytree(cls, no_flatten=...), a generic helper for non-AbstractNDArray classes (FitImaging, Tracer, DatasetModel, etc.) that flattens __dict__ and carries no_flatten names through aux_data for per-analysis constants like dataset/settings/cosmology. Issue: PyAutoLabs/PyAutoLens#444
1 parent 1897037 commit f2193a8

2 files changed

Lines changed: 136 additions & 0 deletions

File tree

autoarray/abstract_ndarray.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,75 @@ def wrapper(self, other):
6262
return wrapper
6363

6464

65+
_pytree_registered_classes: set = set()
66+
67+
68+
def _register_as_pytree(cls):
69+
"""Register ``cls`` with ``jax.tree_util`` via the lazy autoconf wrapper.
70+
71+
Gated: only called when a subclass instance is constructed on the JAX path
72+
(``xp is not np``). The registration is class-scoped via
73+
``_pytree_registered_classes`` so each subclass pays the cost at most once
74+
regardless of how many instances are created. No-ops if JAX is not installed.
75+
"""
76+
if cls in _pytree_registered_classes:
77+
return
78+
from autoconf.jax_wrapper import register_pytree_node
79+
80+
register_pytree_node(cls, cls.instance_flatten, cls.instance_unflatten)
81+
_pytree_registered_classes.add(cls)
82+
83+
84+
def register_instance_pytree(cls, no_flatten=()):
85+
"""Register any class with ``jax.tree_util`` via ``__dict__`` flattening.
86+
87+
Generic counterpart to :func:`_register_as_pytree` for classes that are
88+
*not* ``AbstractNDArray`` subclasses but still need to round-trip through
89+
``jax.jit`` (e.g. ``FitImaging``, ``Tracer``, ``Imaging``). Attributes are
90+
partitioned using ``no_flatten``:
91+
92+
* Names **not** in ``no_flatten`` ride as pytree children — JAX traces them
93+
and can substitute new values on unflatten (dynamic per fit).
94+
* Names **in** ``no_flatten`` ride as ``aux_data`` — JAX treats them as
95+
opaque Python objects, closing over the original reference across the
96+
JIT boundary. Appropriate for per-analysis constants (dataset, settings,
97+
cosmology, adapt images).
98+
99+
Reconstructs via ``cls.__new__`` + ``setattr`` (side-effect-free — no
100+
``__init__`` re-entry). Idempotent.
101+
"""
102+
if cls in _pytree_registered_classes:
103+
return
104+
from autoconf.jax_wrapper import register_pytree_node
105+
106+
no_flatten_set = frozenset(no_flatten)
107+
108+
def flatten(instance):
109+
dyn: list = []
110+
static: list = []
111+
for key, value in sorted(instance.__dict__.items()):
112+
if key in no_flatten_set:
113+
static.append((key, value))
114+
else:
115+
dyn.append((key, value))
116+
dyn_keys = tuple(k for k, _ in dyn)
117+
dyn_values = tuple(v for _, v in dyn)
118+
static_items = tuple(static)
119+
return dyn_values, (dyn_keys, static_items)
120+
121+
def unflatten(aux, children):
122+
dyn_keys, static_items = aux
123+
new = cls.__new__(cls)
124+
for key, value in zip(dyn_keys, children):
125+
setattr(new, key, value)
126+
for key, value in static_items:
127+
setattr(new, key, value)
128+
return new
129+
130+
register_pytree_node(cls, flatten, unflatten)
131+
_pytree_registered_classes.add(cls)
132+
133+
65134
class AbstractNDArray(ABC):
66135

67136
__no_flatten__ = ()
@@ -76,6 +145,9 @@ def __init__(self, array, xp=np):
76145

77146
self.use_jax = xp is not np
78147

148+
if self.use_jax:
149+
_register_as_pytree(type(self))
150+
79151
@property
80152
def is_transformed(self) -> bool:
81153
return self._is_transformed

test_autoarray/test_jax_pytree.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Tests for gated JAX pytree registration of ``AbstractNDArray`` subclasses.
2+
3+
Follows the three-step pattern from ``autolens_workspace_test/scripts/hessian_jax.py``:
4+
1. NumPy path — confirm autoarray type with ``np.ndarray`` backing, no pytree registration.
5+
2. JAX path outside JIT — same autoarray type with ``jax.Array`` backing; pytree registered.
6+
3. JAX path through ``jax.jit`` — round-trip the instance and assert the output carries
7+
a ``jax.Array`` leaf.
8+
"""
9+
10+
import numpy as np
11+
import numpy.testing as npt
12+
import pytest
13+
14+
jax = pytest.importorskip("jax")
15+
jnp = pytest.importorskip("jax.numpy")
16+
17+
from autoarray.abstract_ndarray import AbstractNDArray, _pytree_registered_classes
18+
19+
20+
class _LeafArray(AbstractNDArray):
21+
"""Minimal concrete ``AbstractNDArray`` with no nested autoarray children.
22+
23+
Isolates the pytree-registration machinery from the larger autoarray
24+
hierarchy: a real ``Array2D`` also carries a ``Mask2D`` and other nested
25+
``AbstractNDArray`` children whose own registration is covered by
26+
follow-up steps in the ``fit-imaging-pytree`` task.
27+
"""
28+
29+
@property
30+
def native(self):
31+
return self
32+
33+
34+
def test_numpy_path_does_not_register_pytree():
35+
_pytree_registered_classes.discard(_LeafArray)
36+
37+
arr = _LeafArray(np.array([1.0, 2.0, 3.0]))
38+
39+
assert isinstance(arr._array, np.ndarray)
40+
assert _LeafArray not in _pytree_registered_classes
41+
42+
43+
def test_jax_path_registers_pytree_once():
44+
_pytree_registered_classes.discard(_LeafArray)
45+
46+
arr_jax = _LeafArray(jnp.array([1.0, 2.0, 3.0]), xp=jnp)
47+
48+
assert isinstance(arr_jax._array, jnp.ndarray)
49+
assert _LeafArray in _pytree_registered_classes
50+
51+
# Second construction on the JAX path is a no-op; class stays registered.
52+
_LeafArray(jnp.array([4.0, 5.0]), xp=jnp)
53+
assert _LeafArray in _pytree_registered_classes
54+
55+
56+
def test_jax_jit_round_trip_returns_wrapper_with_jax_array():
57+
arr_jax = _LeafArray(jnp.array([1.0, 2.0, 3.0]), xp=jnp)
58+
assert _LeafArray in _pytree_registered_classes
59+
60+
result = jax.jit(lambda a: a)(arr_jax)
61+
62+
assert isinstance(result, _LeafArray)
63+
assert isinstance(result._array, jnp.ndarray)
64+
npt.assert_allclose(np.asarray(result._array), np.asarray(arr_jax._array))

0 commit comments

Comments
 (0)