
Conversation

@inailuig (Contributor) commented Sep 5, 2025

We are using folx to compute the Laplacian in batches, using a function similar to folx.batched_vmap. Our samples are sharded along the batch axis, so we need to run the batched_vmap inside a shard_map, so that each JAX device only loops over batches of its own samples.

Inside the shard_map the samples are varying, but some of the arrays folx creates from thin air (jnp.eye, jnp.ones, jnp.zeros, etc.) are not, causing JAX to error (see below).

This PR adds a few jax.lax.pvary statements, setting the varying mesh axes correctly to make this work.

Example:

import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import *
from jax import shard_map, P
import numpy as np

import folx

# Make sure there are several devices to shard over.
if jax.config.jax_num_cpu_devices <= 1:
    jax.config.update("jax_num_cpu_devices", 4)

mesh = jax.make_mesh((4,), ("i",), axis_types=(AxisType.Explicit,))
jax.sharding.set_mesh(mesh)

x = np.random.normal(size=(1024, 16 * 16))
w = np.random.normal(size=(16 * 16, 16 * 16))

# Samples are sharded along the batch axis, the weights are replicated.
x_sh = reshard(x, P('i'))
w_sh = reshard(w, P())

sparsity = 0
# sparsity = 0.5  # requires jax>=0.7.2

def f(w, x):
    return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16)))

@jax.jit
@partial(jax.shard_map, in_specs=(P(), P('i')), out_specs=P('i'))
@partial(folx.batched_vmap, in_axes=(None, 0), max_batch_size=64)
def forward_laplacian_sh(w, x):
    return folx.forward_laplacian(partial(f, w), sparsity)(x)

forward_laplacian_sh(w_sh, x_sh)

Before this PR, this errored with:

ValueError: unexpected JAX type (e.g. shape/dtype) for argument to vjp function: got float32[1], but expected float32[1]{i} because the corresponding output of the function flat_f had JAX type float32[1]{i}

In this PR I only set the vma in the places where I was able to trigger the error with the test, but it might be necessary elsewhere too (e.g. ed77b3a and e580d59 are a few places).

One nontrivial place is this:

folx/folx/ad.py, lines 92 to 96 in 30b053a:

def jvp_fun(s):
    return jax.jvp(f, primals, unravel(s))[1]

eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
J = jax.vmap(jvp_fun, out_axes=-1)(eye)

which would need a pvary setting the vma of eye if one ever tried to linear_transpose the function; see my comment here: netket/netket#2072 (comment).
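
For illustration only, a sketch of how that excerpt could be patched; reading the varying axes off the primals via jax.typeof(...).vma is an assumption of mine, not necessarily how folx should do it:

def jvp_fun(s):
    return jax.jvp(f, primals, unravel(s))[1]

eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
# Sketch (assumption): `eye` is created from thin air and thus does not vary over any
# manual mesh axis; tag it with the same varying axes as the primals so that e.g.
# linear_transpose of this function also works under shard_map.
vma = tuple(jax.typeof(flat_primals).vma)  # assumption: .vma holds the varying manual axes
eye = jax.lax.pvary(eye, vma) if vma else eye
J = jax.vmap(jvp_fun, out_axes=-1)(eye)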

@inailuig (Contributor, Author) commented Sep 5, 2025

@microsoft-github-policy-service agree

@PhilipVinc

Ping @n-gao


@gcarleo commented Sep 24, 2025

@n-gao would be great if we could merge this

@n-gao (Collaborator) commented Sep 24, 2025

I am very sorry, GitHub stopped sending me emails for this repository. Hugely annoying; please feel free to write me an email in the future!
I will have a look later today. Thanks a lot for the fix.

@PhilipVinc

Thanks!
FYI (in case you are not familiar with this sharding stuff): jax.lax.pvary is a no-op when there are no manual meshes (i.e., when not under a shard_map), so this code should not affect use cases outside of it.
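
Roughly, that is because outside of a shard_map nothing is varying, so the set of axes these pvary calls get is empty and they reduce to the identity. A minimal sketch of that (the .vma attribute and the guard are assumptions about the current API, not folx code):

import jax
import jax.numpy as jnp

x = jnp.ones((4,))
axes = tuple(jax.typeof(x).vma)   # assumption: .vma holds the varying manual axes
print(axes)                       # () -- no shard_map here, so nothing varies
y = jax.lax.pvary(x, axes) if axes else x
assert y is x                     # identity outside of shard_map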

@n-gao (Collaborator) commented Sep 25, 2025

The failed tests on the newest JAX versions seem unrelated to this PR and are due to changes in JAX.

But it would be nice to clean up the pre-commit issues.

@n-gao (Collaborator) commented Sep 25, 2025

I am quite curious why this fails, since I've used shard_map quite extensively in the past jointly with folx. But the changes look reasonable to me. Thanks, and again sorry for the delay. Feel free to email me in the future.

@n-gao (Collaborator) commented Sep 25, 2025

I fixed the compiler params on main, could you rebase such that CI can rerun?

@PhilipVinc

"since I've used shard_map quite extensively in the past jointly with folx"

Did you use it since 0.7.0?
The requirement to tag arrays as pvary-ing was added there, to lighten the load on the compiler, which sometimes could not figure it out.

@n-gao (Collaborator) commented Sep 25, 2025

Ah, I think I had set check_vma=False, in which case it runs through even without these changes. But I agree that handling it is preferred.
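
For reference, on the repro above that opt-out would look roughly like this (check_vma=True is the default; setting it to False only skips the varying-axes check, it does not fix the annotations):

@jax.jit
@partial(jax.shard_map, in_specs=(P(), P('i')), out_specs=P('i'), check_vma=False)
@partial(folx.batched_vmap, in_axes=(None, 0), max_batch_size=64)
def forward_laplacian_sh(w, x):
    return folx.forward_laplacian(partial(f, w), sparsity)(x)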

@inailuig (Contributor, Author)

"I fixed the compiler params on main, could you rebase such that CI can rerun?"

Done, can you re-run the CI?

@n-gao merged commit d05c107 into microsoft:main on Sep 25, 2025
10 checks passed
@inailuig deleted the pvary branch on September 25, 2025 08:33
@PhilipVinc

@n-gao could you tag a new release?
