-
Notifications
You must be signed in to change notification settings - Fork 19
shard_map compatibility #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@microsoft-github-policy-service agree |
|
Ping @n-gao |
1 similar comment
|
Ping @n-gao |
|
@n-gao would be great if we could merge this |
|
I am very sorry, github stopped sending me emails for this repository. Hugely annoying, please feel free to write me a mail in the future! |
|
THanks! |
|
The failed tests on the newest JAX versions seems unrelated to this PR and are due to changes in JAX. But, it would be nice to clean up the pre-commit issues. |
|
I am quite curious on why this fails. Since I've used shard_map quite extensively in the past jointly with folx but the changes look reasonable to me. Thanks and again sorry for the delay. Feel free to mail me in the future. |
|
I fixed the compiler params on main, could you rebase such that CI can rerun? |
Did you since 0.7.0? |
|
Ah I think I've set |
done, can you re-run the CI? |
|
@n-gao could you tag a new release? |
We are using folx to compute the laplacian in batches using a function similar to
folx.batched_vmap,Our samples are sharded along the batch axis, and thus we need to run the
batched_vmap insideofshard_map, so that each jax device only loops over batches of its own samples.Inside of the
shard_mapthe samples are varying, but some of the arrays of folx created from thin air (jnp.eye, jnp.ones, jnp,zeros, etc) are not, causing jax to error (see below).This PR adds a few
jax.lax.pvary's statement, setting the varying mesh axes correctly to make this work.Example:
Before this PR this errored with
In this PR I only set the vma in places I was able to trigger the error with the test, but It might be necessary elsewhere too (e.g. ed77b3a and e580d59 are a few places)
One nontrivial one is this
folx/folx/ad.py
Lines 92 to 96 in 30b053a
which would need a
pvarysetting the vma ofeyeif one ever tried tolinear_transposethe function, see my comment here netket/netket#2072 (comment) .