
TypeError when using jnp.linalg.slogdet with forward_laplacian in a distributed/sharded JAX environment #37

@zhangylch


Thank you for implementing the forward Laplacian; it significantly speeds up Laplacian calculations. I encountered a TypeError when using folx to compute the Laplacian of a function that involves jnp.linalg.slogdet. The error occurs only in a distributed/sharded environment (using jax.sharding.Mesh).

The code works fine on a single device. It also works in the distributed setting if I replace jnp.linalg.slogdet with another operation such as jnp.sum.
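A folx-free reference value can help confirm which results are correct when comparing the single-device and sharded runs. The sketch below uses a hypothetical helper, `reference_laplacian` (not part of folx), that builds a dense `jax.hessian` on one device and takes its trace, for both the `slogdet` function and the `jnp.sum` control. For X = c·I_N the exact Laplacian of log|det X| over all N² entries is -‖X⁻¹‖²_F = -N/c², and the Laplacian of the linear `jnp.sum` is 0. The dense Hessian has (N²)² entries, so this is only practical for small N.

```python
import jax
import jax.numpy as jnp


def reference_laplacian(f, x):
    """Hypothetical helper: Laplacian of scalar f over every entry of x via a dense Hessian."""
    flat_f = lambda v: f(v.reshape(x.shape))
    hess = jax.hessian(flat_f)(x.reshape(-1))  # shape (x.size, x.size)
    return jnp.trace(hess)


def logdet_f(x):
    sign, logdet = jnp.linalg.slogdet(x)
    return logdet


def sum_f(x):
    return jnp.sum(x)


N, Batch = 10, 8
scales = [2.0 + i * 0.1 for i in range(Batch)]
x = jnp.stack([jnp.eye(N) * c for c in scales])

lap_logdet = jax.vmap(lambda xi: reference_laplacian(logdet_f, xi))(x)
lap_sum = jax.vmap(lambda xi: reference_laplacian(sum_f, xi))(x)

# Closed form for X = c * I_N: -||X^{-1}||_F^2 = -N / c**2
expected = jnp.array([-N / c**2 for c in scales])

print(lap_logdet)
print(expected)
print(lap_sum)
```

These values can then be compared against the Laplacian that `forward_laplacian` reports on a single device.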

### Minimal Reproduction Script
Here is a minimal script that reproduces the issue. It sets up a JAX Mesh and shards the input batch along its leading dimension.

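```python
import os

import jax
import jax.numpy as jnp
from folx import forward_laplacian
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# jax.config.update("jax_enable_x64", True)


def test_parallel():
    print("=== Running Distributed (Mesh/Sharding) Test ===")

    # Multi-process runs under SLURM need the distributed runtime first.
    if "SLURM_PROCID" in os.environ:
        try:
            jax.distributed.initialize()
        except Exception:
            # e.g. already initialized by the launcher; continue anyway.
            pass

    # One mesh axis spanning all devices; Batch must be divisible by the device count.
    mesh = Mesh(jax.devices(), axis_names=("data",))

    N = 10
    Batch = 8

    # Batch of scaled identity matrices, shape (Batch, N, N).
    x_host = jnp.stack([jnp.eye(N) * (2.0 + i * 0.1) for i in range(Batch)])

    # Shard the leading (batch) dimension across the "data" mesh axis.
    sharding = NamedSharding(mesh, P("data"))
    x_sharded = jax.device_put(x_host, sharding)

    def single_sample_f(x):
        sign, logdet = jnp.linalg.slogdet(x)
        return logdet

    fwd_op = forward_laplacian(single_sample_f)
    fwd_f_vmap = jax.vmap(fwd_op)

    # Raises the TypeError reported below; the same call works on a single device.
    result = fwd_f_vmap(x_sharded)


if __name__ == "__main__":
    test_parallel()
```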
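To exercise the sharded path without a cluster or SLURM, XLA's `--xla_force_host_platform_device_count` flag can make a single CPU host expose several devices. This is a sketch under the assumption that a CPU-only run is representative; it has not been verified to trigger this particular TypeError. It only builds the same mesh and sharding as the script above with 8 emulated devices, after which `forward_laplacian` can be applied to `x_sharded` as before.

```python
import os

# Emulate 8 devices on one CPU host. These variables must be set
# before JAX initializes its backends, i.e. before the first JAX call.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

N, Batch = 10, 8
x_host = jnp.stack([jnp.eye(N) * (2.0 + i * 0.1) for i in range(Batch)])
x_sharded = jax.device_put(x_host, NamedSharding(mesh, P("data")))

# Each emulated device should hold one (1, N, N) slice of the batch.
print(len(jax.devices()), [s.data.shape for s in x_sharded.addressable_shards])
```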

### Error Traceback
```
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal float32[] with tangent float32[], expecting tangent float32[]
```

### Environment
JAX 0.7.0 + latest folx
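Because "latest folx" can refer to different releases or commits over time, recording exact installed versions makes the report easier to reproduce. A minimal sketch using only the standard library; `package_versions` is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names=("jax", "jaxlib", "folx")):
    """Hypothetical helper: map each package name to its installed version string."""
    out = {}
    for name in names:
        try:
            out[name] = version(name)
        except PackageNotFoundError:
            out[name] = "not installed"
    return out


print(package_versions())
```

For a folx install from a Git checkout, the commit hash is worth adding alongside the version string.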
