|
| 1 | +"""Regression: cached_property descriptors on AbstractNDArray subclasses |
| 2 | +must be filtered from ``instance_flatten`` so derived caches never reach |
| 3 | +the JAX pytree leaves. |
| 4 | +
|
| 5 | +NumPy-only per the project rule [[feedback_no_jax_in_unit_tests]]: |
| 6 | +exercise the ``instance_flatten`` classmethod directly (which is what |
| 7 | +the JAX pytree path delegates to) and assert composition is correct. |
| 8 | +""" |
| 9 | + |
| 10 | +import functools |
| 11 | + |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +from autoarray.abstract_ndarray import AbstractNDArray |
| 15 | + |
| 16 | + |
| 17 | +class _FakeArray(AbstractNDArray): |
| 18 | + """Minimal AbstractNDArray subclass that adds a ``@cached_property`` |
| 19 | + returning a string. Used to assert the guard filters it from |
| 20 | + ``instance_flatten``.""" |
| 21 | + |
| 22 | + __no_flatten__ = ("use_jax",) |
| 23 | + |
| 24 | + def __init__(self, array): |
| 25 | + # Skip AbstractNDArray.__init__ to avoid the JAX-registration path |
| 26 | + # — we only need the dict-shape for the flatten test. |
| 27 | + self._array = np.asarray(array) |
| 28 | + self._is_transformed = False |
| 29 | + self.use_jax = False |
| 30 | + |
| 31 | + @property |
| 32 | + def native(self): |
| 33 | + # AbstractNDArray declares ``native`` abstract; the body is |
| 34 | + # irrelevant to the flatten path so just echo ``_array``. |
| 35 | + return self._array |
| 36 | + |
| 37 | + @functools.cached_property |
| 38 | + def heavy_summary(self): |
| 39 | + return "a-pretty-printed-summary-of-the-array" |
| 40 | + |
| 41 | + |
| 42 | +def test_instance_flatten_excludes_cached_property_names(): |
| 43 | + """``AbstractNDArray.instance_flatten`` unions the class-level |
| 44 | + ``__no_flatten__`` with the result of |
| 45 | + ``autoconf.tools.decorators.cached_property_names`` so derived |
| 46 | + cached strings stay out of the pytree leaves. |
| 47 | +
|
| 48 | + This pins the structural defense that follows PyAutoFit#1300: the |
| 49 | + leak surfaces today only on the Model side, but the same opt-out |
| 50 | + filter shape on AbstractNDArray descendants would break ``jax.jit`` |
| 51 | + the moment anyone added a ``@cached_property`` returning a |
| 52 | + non-array value to a Fit class.""" |
| 53 | + |
| 54 | + arr = _FakeArray([1.0, 2.0, 3.0]) |
| 55 | + |
| 56 | + # Trigger the cached property: it writes "...summary..." into __dict__. |
| 57 | + _ = arr.heavy_summary |
| 58 | + assert arr.__dict__["heavy_summary"] == "a-pretty-printed-summary-of-the-array" |
| 59 | + |
| 60 | + leaves, keys = _FakeArray.instance_flatten(arr) |
| 61 | + |
| 62 | + # The pre-existing __no_flatten__ exclusion ("use_jax") still applies. |
| 63 | + assert "use_jax" not in keys |
| 64 | + # The new cached_property exclusion fires too. |
| 65 | + assert "heavy_summary" not in keys |
| 66 | + # No string leaves anywhere. |
| 67 | + assert not any(isinstance(leaf, str) for leaf in leaves) |
| 68 | + |
| 69 | + |
| 70 | +def test_instance_flatten_preserves_array_data(): |
| 71 | + """Sanity check: filtering cached_property names does not collateral- |
| 72 | + damage real array data. The underlying numpy array must still appear |
| 73 | + in the leaves.""" |
| 74 | + |
| 75 | + arr = _FakeArray([1.0, 2.0, 3.0]) |
| 76 | + _ = arr.heavy_summary # poison the cache before flattening |
| 77 | + |
| 78 | + leaves, keys = _FakeArray.instance_flatten(arr) |
| 79 | + |
| 80 | + assert "_array" in keys |
| 81 | + array_index = keys.index("_array") |
| 82 | + np.testing.assert_array_equal(leaves[array_index], np.asarray([1.0, 2.0, 3.0])) |
0 commit comments