Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions folx/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def flat_f(x):

out = flat_f(flat_primals)

result = jax.vmap(vjp(flat_f, flat_primals))(
jnp.eye(out.size, dtype=out.dtype)
)[0]
eye = jnp.eye(out.size, dtype=out.dtype)
if hasattr(jax.lax, 'pvary'):
eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
result = jax.vmap(unravel, out_axes=0)(result)
if len(primals) == 1:
return result[0]
Expand Down
13 changes: 12 additions & 1 deletion folx/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,18 @@ def get_indices(mask, out_mask):

if isinstance(outputs, np.ndarray):
with jax.ensure_compile_time_eval():
result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
if hasattr(jax.sharding, 'use_abstract_mesh'): # jax>=0.7.2
# see https://github.com/jax-ml/jax/discussions/31461
with jax.sharding.use_abstract_mesh(
jax.sharding.AbstractMesh((), ())
):
result = np.asarray(
get_indices(flat_mask, flat_outputs), dtype=int
).T
else:
result = np.asarray(
get_indices(flat_mask, flat_outputs), dtype=int
).T
else:
result = get_indices(flat_mask, flat_outputs).T
return result.reshape(mask.shape)
Expand Down
4 changes: 4 additions & 0 deletions folx/wrapped_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def custom_jvp(jacobian, tangent, sign):
log_det_jvp = jac_dot_tangent.real
else:
sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype)
if hasattr(jax.lax, 'pvary'):
sign_jvp = jax.lax.pvary(
sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma)
)
log_det_jvp = jac_dot_tangent
return (sign_jvp, log_det_jvp)

Expand Down
74 changes: 57 additions & 17 deletions test/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
from functools import partial

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from laplacian_testcase import LaplacianTestCase
from packaging.version import Version
from parameterized import parameterized

from folx import (
Expand Down Expand Up @@ -174,28 +176,66 @@ def test_slogdet(self, test_complex: bool):
w = w + 1j * np.random.normal(size=w.shape)

@jax.jit
def f(x):
def _f(w, x):
return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16)))

f = partial(_f, w)

for sparsity in [0, x.size]:
with self.subTest(sparsity=sparsity):
sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
self.assertEqual(log_y.x.shape, f(x)[1].shape)
self.assert_allclose(log_y.x, f(x)[1])
self.assert_allclose(
log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
)
self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])

self.assertEqual(sign_y.shape, log_y.x.shape)
if test_complex:
self.assertIsInstance(sign_y, FwdLaplArray)
for use_shard_map in [False, True]:
with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map):
if use_shard_map and (
Version(jax.__version__) < Version('0.7.1')
or (
sparsity != 0
and Version(jax.__version__) < Version('0.7.2')
)
):
self.skipTest('jax version too old')
if use_shard_map:
mesh = jax.sharding.Mesh(
jax.devices()[:1],
'i',
axis_types=jax.sharding.AxisType.Explicit,
)

@jax.jit
@partial(
jax.shard_map,
in_specs=(jax.P(), jax.P('i')),
out_specs=jax.P('i'),
)
@partial(jax.vmap, in_axes=(None, 0))
def forward_laplacian_sh(w, x):
return forward_laplacian(partial(_f, w), sparsity)(x)

with jax.set_mesh(mesh):
x_sh = jax.sharding.reshard(x[None], jax.P('i'))
w_sh = jax.sharding.reshard(w, jax.P())
sign_y, log_y = jax.tree.map(
lambda x: x[0], forward_laplacian_sh(w_sh, x_sh)
)
else:
sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)

self.assertEqual(log_y.x.shape, f(x)[1].shape)
self.assert_allclose(log_y.x, f(x)[1])
self.assert_allclose(
sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
)
self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
else:
self.assertIsInstance(sign_y, jax.Array)
self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])

self.assertEqual(sign_y.shape, log_y.x.shape)
if test_complex:
self.assertIsInstance(sign_y, FwdLaplArray)
self.assert_allclose(
sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
)
self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
else:
self.assertIsInstance(sign_y, jax.Array)
del sign_y
del log_y

def test_custom_hessian(self):
x = np.random.normal(size=(16,))
Expand Down