diff --git a/folx/ad.py b/folx/ad.py index d28e078..f1dc97d 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -71,9 +71,10 @@ def flat_f(x): out = flat_f(flat_primals) - result = jax.vmap(vjp(flat_f, flat_primals))( - jnp.eye(out.size, dtype=out.dtype) - )[0] + eye = jnp.eye(out.size, dtype=out.dtype) + if hasattr(jax.lax, 'pvary'): + eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma)) + result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0] result = jax.vmap(unravel, out_axes=0)(result) if len(primals) == 1: return result[0] diff --git a/folx/api.py b/folx/api.py index fd677d2..0e9b964 100644 --- a/folx/api.py +++ b/folx/api.py @@ -131,7 +131,18 @@ def get_indices(mask, out_mask): if isinstance(outputs, np.ndarray): with jax.ensure_compile_time_eval(): - result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T + if hasattr(jax.sharding, 'use_abstract_mesh'): # jax>=0.7.2 + # see https://github.com/jax-ml/jax/discussions/31461 + with jax.sharding.use_abstract_mesh( + jax.sharding.AbstractMesh((), ()) + ): + result = np.asarray( + get_indices(flat_mask, flat_outputs), dtype=int + ).T + else: + result = np.asarray( + get_indices(flat_mask, flat_outputs), dtype=int + ).T else: result = get_indices(flat_mask, flat_outputs).T return result.reshape(mask.shape) diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py index 59ff887..bf6eca9 100644 --- a/folx/wrapped_functions.py +++ b/folx/wrapped_functions.py @@ -227,6 +227,10 @@ def custom_jvp(jacobian, tangent, sign): log_det_jvp = jac_dot_tangent.real else: sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype) + if hasattr(jax.lax, 'pvary'): + sign_jvp = jax.lax.pvary( + sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma) + ) log_det_jvp = jac_dot_tangent return (sign_jvp, log_det_jvp) diff --git a/test/test_layers.py b/test/test_layers.py index 46d0286..ef1b29d 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -1,10 +1,12 @@ import functools +from functools import partial import jax import jax.numpy as jnp import jax.tree_util as jtu import numpy as np from laplacian_testcase import LaplacianTestCase +from packaging.version import Version from parameterized import parameterized from folx import ( @@ -174,28 +176,66 @@ def test_slogdet(self, test_complex: bool): w = w + 1j * np.random.normal(size=w.shape) @jax.jit - def f(x): + def _f(w, x): return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16))) + f = partial(_f, w) + for sparsity in [0, x.size]: - with self.subTest(sparsity=sparsity): - sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x) - self.assertEqual(log_y.x.shape, f(x)[1].shape) - self.assert_allclose(log_y.x, f(x)[1]) - self.assert_allclose( - log_y.jacobian.dense_array, self.jacobian(f, x)[1].T - ) - self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1]) - - self.assertEqual(sign_y.shape, log_y.x.shape) - if test_complex: - self.assertIsInstance(sign_y, FwdLaplArray) + for use_shard_map in [False, True]: + with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map): + if use_shard_map and ( + Version(jax.__version__) < Version('0.7.1') + or ( + sparsity != 0 + and Version(jax.__version__) < Version('0.7.2') + ) + ): + self.skipTest('jax version too old') + if use_shard_map: + mesh = jax.sharding.Mesh( + jax.devices()[:1], + 'i', + axis_types=jax.sharding.AxisType.Explicit, + ) + + @jax.jit + @partial( + jax.shard_map, + in_specs=(jax.P(), jax.P('i')), + out_specs=jax.P('i'), + ) + @partial(jax.vmap, in_axes=(None, 0)) + def forward_laplacian_sh(w, x): + return forward_laplacian(partial(_f, w), sparsity)(x) + + with jax.set_mesh(mesh): + x_sh = jax.sharding.reshard(x[None], jax.P('i')) + w_sh = jax.sharding.reshard(w, jax.P()) + sign_y, log_y = jax.tree.map( + lambda x: x[0], forward_laplacian_sh(w_sh, x_sh) + ) + else: + sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x) + + self.assertEqual(log_y.x.shape, f(x)[1].shape) + self.assert_allclose(log_y.x, f(x)[1]) self.assert_allclose( - sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T + log_y.jacobian.dense_array, self.jacobian(f, x)[1].T ) - self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0]) - else: - self.assertIsInstance(sign_y, jax.Array) + self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1]) + + self.assertEqual(sign_y.shape, log_y.x.shape) + if test_complex: + self.assertIsInstance(sign_y, FwdLaplArray) + self.assert_allclose( + sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T + ) + self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0]) + else: + self.assertIsInstance(sign_y, jax.Array) + del sign_y + del log_y def test_custom_hessian(self): x = np.random.normal(size=(16,))