Description
Thank you for implementing the forward Laplacian; it significantly accelerates the Laplacian calculation. I encountered a TypeError when computing the Laplacian of a function involving jnp.linalg.slogdet using folx. The error only occurs in a distributed/sharded environment (using jax.sharding.Mesh).
The code works perfectly fine in a single-device setting (sketched below). It also works fine in the distributed setting if I replace jnp.linalg.slogdet with other operations such as jnp.sum (see the sketch after the reproduction script).
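For reference, a minimal sketch of the single-device case that succeeds; the matrix value and variable names here are illustrative, not taken from the original report:

import jax.numpy as jnp
from folx import forward_laplacian

def single_sample_f(x):
    sign, logdet = jnp.linalg.slogdet(x)
    return logdet

x = jnp.eye(10) * 2.0               # one unsharded sample (illustrative value)
fwd_f = forward_laplacian(single_sample_f)
result = fwd_f(x)                   # completes without error on a single device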
Minimal Reproduction Script
Here is a minimal script to reproduce the issue. It sets up a JAX Mesh and shards the input batch across devices.
import os

import jax
import jax.numpy as jnp
from folx import forward_laplacian
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# jax.config.update("jax_enable_x64", True)

def test_parallel():
    print("=== Running Distributed (Mesh/Sharding) Test ===")
    # Initialize multi-process JAX when launched under SLURM.
    if 'SLURM_PROCID' in os.environ:
        try:
            jax.distributed.initialize()
        except Exception:
            pass

    mesh = Mesh(jax.devices(), axis_names=('data',))

    N = 10
    Batch = 8
    # Batch of well-conditioned diagonal matrices.
    x_host = jnp.stack([jnp.eye(N) * (2.0 + i * 0.1) for i in range(Batch)])

    # Shard the batch dimension across devices.
    sharding = NamedSharding(mesh, P('data'))
    x_sharded = jax.device_put(x_host, sharding)

    def single_sample_f(x):
        sign, logdet = jnp.linalg.slogdet(x)
        return logdet

    fwd_op = forward_laplacian(single_sample_f)
    fwd_f_vmap = jax.vmap(fwd_op)

    # Raises the TypeError below when x_sharded is distributed across devices.
    result = fwd_f_vmap(x_sharded)

if __name__ == "__main__":
    test_parallel()
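For comparison, a minimal sketch of the working variant mentioned above, reusing the mesh and x_sharded from the script and replacing jnp.linalg.slogdet with jnp.sum; the function name is illustrative:

import jax
import jax.numpy as jnp
from folx import forward_laplacian

# Reuses x_sharded from the reproduction script above.
def single_sample_sum(x):
    # A plain reduction instead of slogdet avoids the error.
    return jnp.sum(x)

fwd_sum = jax.vmap(forward_laplacian(single_sample_sum))
result_ok = fwd_sum(x_sharded)      # completes without the TypeError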
Error Traceback
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal float32[] with tangent float32[], expecting tangent float32[]
Environment
JAX 0.7.0 + latest folx