From 307a472b1a5d21df7ef4489a28f0eed5e86a26c7 Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Mon, 1 Sep 2025 14:06:43 +0200
Subject: [PATCH 1/6] pvary I

---
 folx/ad.py                | 4 +++-
 folx/wrapped_functions.py | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/folx/ad.py b/folx/ad.py
index d28e078..a7e4e77 100644
--- a/folx/ad.py
+++ b/folx/ad.py
@@ -71,8 +71,10 @@ def flat_f(x):
 
     out = flat_f(flat_primals)
 
+    eye = jnp.eye(out.size, dtype=out.dtype)
+    eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
     result = jax.vmap(vjp(flat_f, flat_primals))(
-        jnp.eye(out.size, dtype=out.dtype)
+        eye
     )[0]
     result = jax.vmap(unravel, out_axes=0)(result)
     if len(primals) == 1:
diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py
index 59ff887..a9c8854 100644
--- a/folx/wrapped_functions.py
+++ b/folx/wrapped_functions.py
@@ -227,6 +227,7 @@ def custom_jvp(jacobian, tangent, sign):
         log_det_jvp = jac_dot_tangent.real
     else:
         sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype)
+        sign_jvp = jax.lax.pvary(sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma))
         log_det_jvp = jac_dot_tangent
 
     return (sign_jvp, log_det_jvp)

From 4a40f3ea4b16639148670d88a3c26fa8c989ccfb Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Mon, 1 Sep 2025 17:17:44 +0200
Subject: [PATCH 2/6] workaround to fix sparse version (jax>=0.7.2)

---
 folx/api.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/folx/api.py b/folx/api.py
index fd677d2..8acdb9b 100644
--- a/folx/api.py
+++ b/folx/api.py
@@ -131,7 +131,9 @@ def get_indices(mask, out_mask):
 
     if isinstance(outputs, np.ndarray):
         with jax.ensure_compile_time_eval():
-            result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
+            # see https://github.com/jax-ml/jax/discussions/31461
+            with jax.sharding.use_abstract_mesh(jax.sharding.AbstractMesh((), ())):
+                result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
     else:
         result = get_indices(flat_mask, flat_outputs).T
     return result.reshape(mask.shape)

From 3d1ebef2de945ad5d6990391d6f4888e203e9d53 Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Fri, 5 Sep 2025 16:55:11 +0200
Subject: [PATCH 3/6] restore compatibility with older versions of jax

---
 folx/ad.py                | 3 ++-
 folx/api.py               | 7 +++++--
 folx/wrapped_functions.py | 3 ++-
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/folx/ad.py b/folx/ad.py
index a7e4e77..a7d0680 100644
--- a/folx/ad.py
+++ b/folx/ad.py
@@ -72,7 +72,8 @@ def flat_f(x):
     out = flat_f(flat_primals)
 
     eye = jnp.eye(out.size, dtype=out.dtype)
-    eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
+    if hasattr(jax.lax, 'pvary'):
+        eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
     result = jax.vmap(vjp(flat_f, flat_primals))(
         eye
     )[0]
diff --git a/folx/api.py b/folx/api.py
index 8acdb9b..7ea3951 100644
--- a/folx/api.py
+++ b/folx/api.py
@@ -131,8 +131,11 @@ def get_indices(mask, out_mask):
 
     if isinstance(outputs, np.ndarray):
         with jax.ensure_compile_time_eval():
-            # see https://github.com/jax-ml/jax/discussions/31461
-            with jax.sharding.use_abstract_mesh(jax.sharding.AbstractMesh((), ())):
+            if hasattr(jax.sharding, 'use_abstract_mesh'): # jax>=0.7.2
+                # see https://github.com/jax-ml/jax/discussions/31461
+                with jax.sharding.use_abstract_mesh(jax.sharding.AbstractMesh((), ())):
+                    result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
+            else:
                 result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
     else:
         result = get_indices(flat_mask, flat_outputs).T
diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py
index a9c8854..dd24afd 100644
--- a/folx/wrapped_functions.py
+++ b/folx/wrapped_functions.py
@@ -227,7 +227,8 @@ def custom_jvp(jacobian, tangent, sign):
         log_det_jvp = jac_dot_tangent.real
     else:
         sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype)
-        sign_jvp = jax.lax.pvary(sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma))
+        if hasattr(jax.lax, 'pvary'):
+            sign_jvp = jax.lax.pvary(sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma))
         log_det_jvp = jac_dot_tangent
 
     return (sign_jvp, log_det_jvp)
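[Editor's note, not part of the series] Patches 1-3 make folx's reverse-mode Jacobian helper and the slogdet custom-JVP rule compatible with shard_map's varying-manual-axes ("vma") checking: arrays created from scratch inside the traced function (jnp.eye, jnp.zeros) are replicated by default, so they must be promoted with jax.lax.pvary to match the vma of the value they feed into. A minimal standalone sketch of the same pattern, assuming jax>=0.7 (jax.typeof and the .vma aval attribute); the helper name is illustrative, not folx API:

    import jax
    import jax.numpy as jnp

    def jacobian_rows(f, x):
        # Jacobian of f at x: one vjp, then vmap over a basis of cotangents.
        out, vjp_fn = jax.vjp(f, x)
        eye = jnp.eye(out.size, dtype=out.dtype).reshape(-1, *out.shape)
        if hasattr(jax.lax, 'pvary'):  # older jax has no vma tracking
            # Promote the replicated basis to the varying axes of `out`;
            # without this the cotangent type mismatches under shard_map.
            eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
        return jax.vmap(vjp_fn)(eye)[0]

Outside shard_map the vma set is empty and the pvary call is a no-op, which is why the guarded promotion is safe on every code path.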
From 29345137dd0b0b8ed91ad5dd19367a787c705357 Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Fri, 5 Sep 2025 16:06:16 +0200
Subject: [PATCH 4/6] add a test

---
 test/test_layers.py | 57 +++++++++++++++++++++++++++++++--------------
 1 file changed, 40 insertions(+), 17 deletions(-)

diff --git a/test/test_layers.py b/test/test_layers.py
index 46d0286..1b38421 100644
--- a/test/test_layers.py
+++ b/test/test_layers.py
@@ -4,6 +4,7 @@
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
+from functools import partial
 from laplacian_testcase import LaplacianTestCase
 from parameterized import parameterized
 
@@ -15,6 +16,8 @@
 )
 from folx.api import FwdLaplArray
 
+from packaging.version import Version
+
 
 class TestForwardLaplacian(LaplacianTestCase):
     @parameterized.expand([(False,), (True,)])
@@ -174,28 +177,48 @@ def test_slogdet(self, test_complex: bool):
             w = w + 1j * np.random.normal(size=w.shape)
 
         @jax.jit
-        def f(x):
+        def _f(w, x):
             return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16)))
 
+        f = partial(_f, w)
+
         for sparsity in [0, x.size]:
-            with self.subTest(sparsity=sparsity):
-                sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
-                self.assertEqual(log_y.x.shape, f(x)[1].shape)
-                self.assert_allclose(log_y.x, f(x)[1])
-                self.assert_allclose(
-                    log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
-                )
-                self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])
-
-                self.assertEqual(sign_y.shape, log_y.x.shape)
-                if test_complex:
-                    self.assertIsInstance(sign_y, FwdLaplArray)
+            for use_shard_map in [False, True]:
+                with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map):
+                    if use_shard_map and \
+                        (Version(jax.__version__) < Version('0.7.1') or \
+                        (sparsity!=0 and Version(jax.__version__) < Version('0.7.2'))):
+                        self.skipTest("jax version too old")
+                    if use_shard_map:
+                        mesh = jax.sharding.Mesh(jax.devices()[:1], "i", axis_types=jax.sharding.AxisType.Explicit)
+                        @jax.jit
+                        @partial(jax.shard_map, in_specs=(jax.P(), jax.P('i')), out_specs=jax.P('i'))
+                        @partial(jax.vmap, in_axes=(None, 0))
+                        def forward_laplacian_sh(w, x):
+                            return forward_laplacian(partial(_f, w), sparsity)(x)
+                        with jax.set_mesh(mesh):
+                            x_sh = jax.sharding.reshard(x[None], jax.P('i'))
+                            w_sh = jax.sharding.reshard(w, jax.P())
+                            sign_y_sh, log_y_sh = jax.tree.map(lambda x: x[0], forward_laplacian_sh(w_sh, x_sh))
+                    else:
+                        sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
+
+                    self.assertEqual(log_y.x.shape, f(x)[1].shape)
+                    self.assert_allclose(log_y.x, f(x)[1])
                     self.assert_allclose(
-                        sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
+                        log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
                     )
-                    self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
-                else:
-                    self.assertIsInstance(sign_y, jax.Array)
+                    self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])
+
+                    self.assertEqual(sign_y.shape, log_y.x.shape)
+                    if test_complex:
+                        self.assertIsInstance(sign_y, FwdLaplArray)
+                        self.assert_allclose(
+                            sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
+                        )
+                        self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
+                    else:
+                        self.assertIsInstance(sign_y, jax.Array)
 
     def test_custom_hessian(self):
         x = np.random.normal(size=(16,))
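[Editor's note, not part of the series] The new test drives forward_laplacian inside jax.shard_map: the weights are replicated (jax.P()), the batch is sharded along an explicit mesh axis 'i', jax.vmap maps over the per-device batch, and the singleton batch axis is stripped from the result afterwards. The same harness in miniature, assuming jax>=0.7.1, with a doubling function standing in for forward_laplacian:

    from functools import partial

    import jax
    import jax.numpy as jnp

    mesh = jax.sharding.Mesh(
        jax.devices()[:1], 'i', axis_types=jax.sharding.AxisType.Explicit
    )

    @jax.jit
    @partial(jax.shard_map, in_specs=jax.P('i'), out_specs=jax.P('i'))
    @jax.vmap
    def g(x):
        return 2.0 * x  # stand-in for forward_laplacian(f, sparsity)(x)

    with jax.set_mesh(mesh):
        x_sh = jax.sharding.reshard(jnp.ones((1, 16)), jax.P('i'))
        y = g(x_sh)[0]  # drop the singleton batch axis

As in the patch, shard_map picks up the mesh from the ambient jax.set_mesh context rather than from an explicit mesh argument.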
From 46202a3a12478343a26f2dee5306a36d40cbf1eb Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Thu, 25 Sep 2025 09:59:34 +0200
Subject: [PATCH 5/6] fmt

---
 folx/ad.py                |  4 +---
 folx/api.py               | 14 ++++++++++----
 folx/wrapped_functions.py |  4 +++-
 test/test_layers.py       | 35 +++++++++++++++++++++++++----------
 4 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/folx/ad.py b/folx/ad.py
index a7d0680..f1dc97d 100644
--- a/folx/ad.py
+++ b/folx/ad.py
@@ -74,9 +74,7 @@ def flat_f(x):
     eye = jnp.eye(out.size, dtype=out.dtype)
     if hasattr(jax.lax, 'pvary'):
         eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
-    result = jax.vmap(vjp(flat_f, flat_primals))(
-        eye
-    )[0]
+    result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
     result = jax.vmap(unravel, out_axes=0)(result)
     if len(primals) == 1:
         return result[0]
diff --git a/folx/api.py b/folx/api.py
index 7ea3951..0e9b964 100644
--- a/folx/api.py
+++ b/folx/api.py
@@ -131,12 +131,18 @@ def get_indices(mask, out_mask):
 
     if isinstance(outputs, np.ndarray):
         with jax.ensure_compile_time_eval():
-            if hasattr(jax.sharding, 'use_abstract_mesh'): # jax>=0.7.2
+            if hasattr(jax.sharding, 'use_abstract_mesh'):  # jax>=0.7.2
                 # see https://github.com/jax-ml/jax/discussions/31461
-                with jax.sharding.use_abstract_mesh(jax.sharding.AbstractMesh((), ())):
-                    result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
+                with jax.sharding.use_abstract_mesh(
+                    jax.sharding.AbstractMesh((), ())
+                ):
+                    result = np.asarray(
+                        get_indices(flat_mask, flat_outputs), dtype=int
+                    ).T
             else:
-                result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
+                result = np.asarray(
+                    get_indices(flat_mask, flat_outputs), dtype=int
+                ).T
     else:
         result = get_indices(flat_mask, flat_outputs).T
     return result.reshape(mask.shape)
diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py
index dd24afd..bf6eca9 100644
--- a/folx/wrapped_functions.py
+++ b/folx/wrapped_functions.py
@@ -228,7 +228,9 @@ def custom_jvp(jacobian, tangent, sign):
     else:
         sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype)
         if hasattr(jax.lax, 'pvary'):
-            sign_jvp = jax.lax.pvary(sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma))
+            sign_jvp = jax.lax.pvary(
+                sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma)
+            )
         log_det_jvp = jac_dot_tangent
 
     return (sign_jvp, log_det_jvp)
diff --git a/test/test_layers.py b/test/test_layers.py
index 1b38421..1856227 100644
--- a/test/test_layers.py
+++ b/test/test_layers.py
@@ -1,11 +1,12 @@
 import functools
+from functools import partial
 
 import jax
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
-from functools import partial
 from laplacian_testcase import LaplacianTestCase
+from packaging.version import Version
 from parameterized import parameterized
 
 from folx import (
@@ -16,8 +17,6 @@
 )
 from folx.api import FwdLaplArray
 
-from packaging.version import Version
-
 
 class TestForwardLaplacian(LaplacianTestCase):
     @parameterized.expand([(False,), (True,)])
@@ -185,21 +184,37 @@ def _f(w, x):
 
         for sparsity in [0, x.size]:
             for use_shard_map in [False, True]:
                 with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map):
-                    if use_shard_map and \
-                        (Version(jax.__version__) < Version('0.7.1') or \
-                        (sparsity!=0 and Version(jax.__version__) < Version('0.7.2'))):
-                        self.skipTest("jax version too old")
+                    if use_shard_map and (
+                        Version(jax.__version__) < Version('0.7.1')
+                        or (
+                            sparsity != 0
+                            and Version(jax.__version__) < Version('0.7.2')
+                        )
+                    ):
+                        self.skipTest('jax version too old')
                     if use_shard_map:
-                        mesh = jax.sharding.Mesh(jax.devices()[:1], "i", axis_types=jax.sharding.AxisType.Explicit)
+                        mesh = jax.sharding.Mesh(
+                            jax.devices()[:1],
+                            'i',
+                            axis_types=jax.sharding.AxisType.Explicit,
+                        )
+
                         @jax.jit
-                        @partial(jax.shard_map, in_specs=(jax.P(), jax.P('i')), out_specs=jax.P('i'))
+                        @partial(
+                            jax.shard_map,
+                            in_specs=(jax.P(), jax.P('i')),
+                            out_specs=jax.P('i'),
+                        )
                         @partial(jax.vmap, in_axes=(None, 0))
                         def forward_laplacian_sh(w, x):
                             return forward_laplacian(partial(_f, w), sparsity)(x)
+
                         with jax.set_mesh(mesh):
                             x_sh = jax.sharding.reshard(x[None], jax.P('i'))
                             w_sh = jax.sharding.reshard(w, jax.P())
-                            sign_y_sh, log_y_sh = jax.tree.map(lambda x: x[0], forward_laplacian_sh(w_sh, x_sh))
+                            sign_y_sh, log_y_sh = jax.tree.map(
+                                lambda x: x[0], forward_laplacian_sh(w_sh, x_sh)
+                            )
                     else:
                         sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
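[Editor's note, not part of the series] The api.py hunks in patches 2, 3 and 5 shield a compile-time index computation from the ambient mesh: on jax>=0.7.2, evaluating under jax.ensure_compile_time_eval while an explicit mesh is set breaks the sparse code path (see the jax discussion linked in the diff), so the evaluation is wrapped in an empty AbstractMesh. The pattern in isolation; eval_at_trace_time and compute_indices are hypothetical stand-ins for folx's internals:

    import jax
    import numpy as np

    def eval_at_trace_time(compute_indices, *args):
        # force eager evaluation at trace time, hidden from any
        # surrounding (possibly sharded) mesh context
        with jax.ensure_compile_time_eval():
            if hasattr(jax.sharding, 'use_abstract_mesh'):  # jax>=0.7.2
                with jax.sharding.use_abstract_mesh(jax.sharding.AbstractMesh((), ())):
                    return np.asarray(compute_indices(*args), dtype=int)
            return np.asarray(compute_indices(*args), dtype=int)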
From 288ee5eedf1af47afb4e9bb3b1e42f4a0837b061 Mon Sep 17 00:00:00 2001
From: Clemens Giuliani
Date: Thu, 25 Sep 2025 10:21:40 +0200
Subject: [PATCH 6/6] fix test

---
 test/test_layers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/test_layers.py b/test/test_layers.py
index 1856227..ef1b29d 100644
--- a/test/test_layers.py
+++ b/test/test_layers.py
@@ -212,7 +212,7 @@ def forward_laplacian_sh(w, x):
                     with jax.set_mesh(mesh):
                         x_sh = jax.sharding.reshard(x[None], jax.P('i'))
                         w_sh = jax.sharding.reshard(w, jax.P())
-                        sign_y_sh, log_y_sh = jax.tree.map(
+                        sign_y, log_y = jax.tree.map(
                             lambda x: x[0], forward_laplacian_sh(w_sh, x_sh)
                         )
                     else:
@@ -234,6 +234,8 @@ def forward_laplacian_sh(w, x):
                     self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
                 else:
                     self.assertIsInstance(sign_y, jax.Array)
+                del sign_y
+                del log_y
 
     def test_custom_hessian(self):
         x = np.random.normal(size=(16,))
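[Editor's note, not part of the series] Patch 6 closes a real hole in the new test: the shard_map branch bound its results to sign_y_sh/log_y_sh, so the assertions that follow silently re-checked the sign_y/log_y left over from the previous loop iteration. Binding both branches to the same names and deleting them at the end of each iteration makes any future regression fail with a NameError instead of passing vacuously. The pitfall in miniature; compute, compute_alt and check are hypothetical:

    for use_alt in [False, True]:
        if use_alt:
            y = compute_alt()  # before the fix this was bound to y_alt, leaving y stale
        else:
            y = compute()
        check(y)
        del y  # the next iteration cannot accidentally reuse this binding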