Skip to content

Conversation

@ae-foster
Copy link
Collaborator

Closes #33

The fix for the latest version of jax is actually already available in jax==0.4.34, which I believe is the lowest version we support. So we can simply switch to the updated API across all versions. Tested on 0.4.34 and everything works

@ae-foster ae-foster requested review from n-gao and szbernat May 27, 2025 11:14
@ae-foster
Copy link
Collaborator Author

Ah, I see you are supporting 0.4.10 on the lower Python versions.

@ae-foster
Copy link
Collaborator Author

I propose dropping support for lower versions. But I am happy to also write some workaround code for the 0.4.10 support. Lmk which you prefer

@n-gao
Copy link
Collaborator

n-gao commented May 27, 2025

thanks a lot! I would actually prefer to keep the current minimum since I am using 0.4.29 in some project.

@ae-foster ae-foster force-pushed the ae-foster/fix-jax-version-pallas branch from 92b62bc to 8ae452e Compare May 28, 2025 12:09
@ae-foster
Copy link
Collaborator Author

Just pushed a fix that doesn't change the versions :) Hopefully we can remove some of these hacks at a later stage, but everything is now working across all supported versions-yay

@ae-foster ae-foster force-pushed the ae-foster/fix-jax-version-pallas branch from 8ae452e to 2634e24 Compare May 28, 2025 12:25
Copy link
Collaborator

@n-gao n-gao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, looks great! :)

@n-gao n-gao merged commit f9b237a into main May 28, 2025
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Address breaking change in pallas API

3 participants