Hotfix for compatibility with older jax versions #1926
…erent api before jax 0.6.2
… is different api before jax 0.6.2" This reverts commit 6fdb6e7.
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 3.40 % | 3.908e+03 | 4.041e+03 | 133.04 | 37.66 | 33.64 |
| test_proximal_jac_w7x_with_eq_update | 0.63 % | 6.828e+03 | 6.871e+03 | 42.95 | 173.83 | 173.53 |
| test_proximal_freeb_jac | 0.30 % | 1.322e+04 | 1.326e+04 | 39.00 | 86.43 | 86.13 |
| test_proximal_freeb_jac_blocked | 0.63 % | 7.652e+03 | 7.701e+03 | 48.16 | 78.82 | 78.94 |
| test_proximal_freeb_jac_batched | -0.13 % | 7.623e+03 | 7.613e+03 | -9.88 | 78.61 | 78.82 |
| test_proximal_jac_ripple | 0.00 % | 7.728e+03 | 7.728e+03 | 0.34 | 70.11 | 70.52 |
| test_proximal_jac_ripple_spline | 1.46 % | 3.639e+03 | 3.692e+03 | 53.25 | 70.41 | 69.92 |
| test_eq_solve | 2.89 % | 2.041e+03 | 2.100e+03 | 59.04 | 115.92 | 113.81 |

For the memory plots, go to the summary of
|
some comments
so basically be careful |
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #1926 +/- ##
==========================================
- Coverage 95.80% 95.77% -0.04%
==========================================
Files 100 100
Lines 27561 27541 -20
==========================================
- Hits 26405 26377 -28
- Misses 1156 1164 +8
|
Fair, I need to think more. But if the master version of the code does not work with any JAX before 0.6.2, then the package requirements we list are incorrect. Certain clusters also seem not to support newer JAX versions, which means we cannot run the latest DESC correctly there. I'm just trying to think of a way around this. I welcome any suggestions, but right now, if you install jax 0.6.0 and run the tests locally on master, there are failures due to import errors, which is problematic in my opinion. We either need to revert the changes that caused these specific import errors or change our requirements to jax ==0.6.2.
|
Can you explain this part? I don't follow.
I'm saying don't lose the current state of master. Switch it to a PR if you're reverting.
so basically be careful
Fair, I need to think more. I welcome any suggestions
we either need to
revert the changes that caused these specific import errors or
change our requirements to jax ==0.6.2
Here is how I would have done the original PR if the jax tests had failed there.
There are 3 functions on master that we took from the newest jax version.
If you want to preserve DESC's ability to run old jax versions (even though these old jax versions have bugs that affect DESC), and there is no simple fix, then it's safer to make new legacy batching functions for those 3 functions.
Call those legacy functions only when jax has an old version.
See the top of the desc.batching file to see how to check the jax version.
Also update the citation of those 3 functions to include the version you are taking the code from.
if jax version > compatible version:
    from jax import batch_and_remainder
else:
    def batch_and....
        """..... jax version xxxx"""
        stuff
That may turn your loud error into potentially silent and dangerous bugs. Backporting changes in an external code to support new features on old versions of the external code is frowned upon because it requires lots of development time to do properly. That's why there are dedicated teams that work on backporting security updates to software that is marked to no longer receive feature updates. Even that is frowned upon though, because the backports also introduce bugs. The same is true for mixing source code from different versions of external code. That's why I suggest the thing above.
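The suggestion above (keep a clearly labelled legacy copy and call it only on old jax) could look roughly like the sketch below. It is a minimal illustration, not DESC's or JAX's actual code: `parse_version`, `legacy_batch_and_remainder`, `modern_batch_and_remainder`, and `select_batch_and_remainder` are hypothetical names, the list-splitting body is only a stand-in for the real batching logic, and the 0.6.2 threshold is taken from this discussion. In real use the version string would come from `jax.__version__`.

```python
from itertools import takewhile

# Threshold taken from the discussion in this thread; treat it as an assumption.
MIN_MODERN_JAX = "0.6.2"


def parse_version(version):
    """Convert a string such as '0.6.2' or '0.7.0.dev1' to a tuple of three ints."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits so suffixes like '2rc1' still parse.
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def legacy_batch_and_remainder(items, batch_size):
    """Hypothetical legacy helper (stand-in for code copied from jax version xxxx).

    Splits ``items`` into full batches of ``batch_size`` plus the leftover items.
    """
    cut = (len(items) // batch_size) * batch_size
    batches = [items[i:i + batch_size] for i in range(0, cut, batch_size)]
    return batches, items[cut:]


def modern_batch_and_remainder(items, batch_size):
    """Placeholder for the implementation that would be used with new jax."""
    return legacy_batch_and_remainder(items, batch_size)


def select_batch_and_remainder(installed_version):
    """Pick the implementation once, based on the installed version string."""
    if parse_version(installed_version) >= parse_version(MIN_MODERN_JAX):
        return modern_batch_and_remainder
    return legacy_batch_and_remainder


batch_and_remainder = select_batch_and_remainder("0.6.0")  # legacy path
print(batch_and_remainder([1, 2, 3, 4, 5], 2))
```

Keeping the choice in one place means the rest of the code base never branches on the jax version, and the docstring of the legacy copy records which release it was taken from.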
More questions, just to remind me: is this specifically about large batch sizes giving buggy results? Did #1869 fix those bugs? I just thought you had said #1869 would be laying the groundwork for future sharding.
we will do this @YigitElma |
f0uriest
left a comment
@unalmis is there any need from your end to bump our minimum JAX version? I.e., bugs you've found and added workarounds for in DESC that we can avoid if we require newer JAX?
Otherwise I'm fine with this
|
I try not to add workarounds because it would pollute our code base. From my memory, there's a series of bugs with FFT accuracy before 0.6.2 and AD with indexing multi-dim arrays. There's 2 in an open PR but that's because the versions with bug fixes aren't out yet. |
|
Ok, do we think we hit any of those paths in DESC currently? Also, we should note the new min/max JAX versions in the changelog.
|
yes |
|
Should we update the docs to say that some bugs exist before 0.6.2 and recommend 0.6.2?
If jax doesn't automatically install 0.6.2, they can't do anything anyway.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
Ok no more changes, once tests pass and we confirm they all are passing we can merge |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_midres | -0.70 +/- 4.55 | -4.87e-03 +/- 3.17e-02 | 6.91e-01 +/- 1.7e-02 | 6.96e-01 +/- 2.7e-02 |
| test_build_transform_fft_highres | -0.00 +/- 2.47 | -4.55e-05 +/- 2.30e-02 | 9.32e-01 +/- 1.3e-02 | 9.32e-01 +/- 1.9e-02 |
| test_equilibrium_init_lowres | -0.90 +/- 1.46 | -4.16e-02 +/- 6.76e-02 | 4.60e+00 +/- 5.1e-02 | 4.64e+00 +/- 4.5e-02 |
| test_objective_compile_atf | -0.03 +/- 2.86 | -1.80e-03 +/- 1.77e-01 | 6.18e+00 +/- 7.1e-02 | 6.19e+00 +/- 1.6e-01 |
| test_objective_compute_atf | +7.73 +/- 25.31 | +1.82e-04 +/- 5.95e-04 | 2.53e-03 +/- 5.3e-04 | 2.35e-03 +/- 2.8e-04 |
| test_objective_jac_atf | +1.10 +/- 1.35 | +1.93e-02 +/- 2.36e-02 | 1.77e+00 +/- 1.5e-02 | 1.75e+00 +/- 1.9e-02 |
| test_perturb_1 | -0.35 +/- 2.44 | -4.99e-02 +/- 3.43e-01 | 1.40e+01 +/- 2.4e-01 | 1.41e+01 +/- 2.4e-01 |
| test_proximal_jac_atf | -0.53 +/- 1.37 | -3.00e-02 +/- 7.71e-02 | 5.60e+00 +/- 6.0e-02 | 5.63e+00 +/- 4.8e-02 |
| test_proximal_freeb_compute | -0.05 +/- 2.58 | -8.92e-05 +/- 4.23e-03 | 1.64e-01 +/- 3.6e-03 | 1.64e-01 +/- 2.2e-03 |
| test_solve_fixed_iter | -0.70 +/- 1.45 | -2.07e-01 +/- 4.27e-01 | 2.91e+01 +/- 2.7e-01 | 2.93e+01 +/- 3.3e-01 |
| test_objective_compute_ripple | +0.19 +/- 0.97 | +4.88e-03 +/- 2.55e-02 | 2.63e+00 +/- 1.7e-02 | 2.62e+00 +/- 1.9e-02 |
| test_objective_grad_ripple | -0.87 +/- 1.09 | -4.13e-02 +/- 5.17e-02 | 4.72e+00 +/- 2.7e-02 | 4.76e+00 +/- 4.4e-02 |
| test_build_transform_fft_lowres | +4.25 +/- 5.83 | +2.38e-02 +/- 3.26e-02 | 5.84e-01 +/- 3.1e-02 | 5.60e-01 +/- 1.1e-02 |
| test_equilibrium_init_medres | +3.00 +/- 3.61 | +1.49e-01 +/- 1.79e-01 | 5.10e+00 +/- 1.5e-01 | 4.96e+00 +/- 9.1e-02 |
| test_equilibrium_init_highres | +0.15 +/- 3.93 | +8.54e-03 +/- 2.24e-01 | 5.70e+00 +/- 7.1e-02 | 5.69e+00 +/- 2.1e-01 |
| test_objective_compile_dshape_current | -3.13 +/- 1.70 | -1.11e-01 +/- 6.03e-02 | 3.44e+00 +/- 4.2e-02 | 3.55e+00 +/- 4.3e-02 |
| test_objective_compute_dshape_current | +9.33 +/- 8.60 | +6.82e-05 +/- 6.28e-05 | 7.99e-04 +/- 5.1e-05 | 7.31e-04 +/- 3.6e-05 |
| test_objective_jac_dshape_current | -0.73 +/- 19.91 | -2.40e-04 +/- 6.58e-03 | 3.28e-02 +/- 4.1e-03 | 3.31e-02 +/- 5.2e-03 |
| test_perturb_2 | -1.48 +/- 1.50 | -2.62e-01 +/- 2.66e-01 | 1.74e+01 +/- 2.3e-01 | 1.77e+01 +/- 1.3e-01 |
| test_proximal_jac_atf_with_eq_update | -0.47 +/- 0.85 | -6.39e-02 +/- 1.16e-01 | 1.36e+01 +/- 1.0e-01 | 1.37e+01 +/- 5.1e-02 |
| test_proximal_freeb_jac | +0.78 +/- 1.75 | +3.90e-02 +/- 8.73e-02 | 5.02e+00 +/- 7.7e-02 | 4.98e+00 +/- 4.1e-02 |
| test_solve_fixed_iter_compiled | +1.01 +/- 1.27 | +1.71e-01 +/- 2.14e-01 | 1.70e+01 +/- 1.5e-01 | 1.68e+01 +/- 1.5e-01 |
| test_LinearConstraintProjection_build | +2.10 +/- 3.03 | +1.73e-01 +/- 2.49e-01 | 8.37e+00 +/- 1.6e-01 | 8.20e+00 +/- 1.9e-01 |
| test_objective_compute_ripple_spline | +1.87 +/- 3.29 | +5.40e-03 +/- 9.48e-03 | 2.94e-01 +/- 6.3e-03 | 2.88e-01 +/- 7.1e-03 |
| test_objective_grad_ripple_spline | -0.61 +/- 1.73 | -6.84e-03 +/- 1.93e-02 | 1.11e+00 +/- 8.8e-03 | 1.11e+00 +/- 1.7e-02 |
Changes in PlasmaControl#1869 effectively bumped the minimum JAX version from 0.4.29 to 0.6.2. This PR adds a compatibility layer and removes support for versions 0.4.28 and lower which DESC did not actually support.

- During testing, it's been noticed that we don't support 0.4.28 and older anymore, since they are more than a year old, I just removed them.
- JAX 0.7.0 passes the tests but as noted in patrick-kidger/diffrax#680 can cause problems. New requirements skip that.
- JAX 0.7.1 fails tests and as noted in patrick-kidger/diffrax#680 has poor performance for loops. Again, skipping it.
- JAX 0.7.2 seems to be fine, but there are likely unknown bugs and has issues with PlasmaControl#1857

Resolves PlasmaControl#1925

---------

Co-authored-by: YigitElma <yigitelmacioglu@gmail.com>
Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com>
Co-authored-by: Kaya Unalmis <kayaunalmis@proton.me>
Co-authored-by: Rahul Gaur <19224702+rahulgaur104@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Changes in #1869 effectively bumped the minimum JAX version from 0.4.29 to 0.6.2. This PR adds a compatibility layer and removes support for versions 0.4.28 and lower which DESC did not actually support.
Resolves #1925
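The version policy described in the merge commit (drop 0.4.28 and older, skip 0.7.0 and 0.7.1) can be sketched as a small check. This is only an illustration under stated assumptions: `as_tuple` and `is_supported_jax` are hypothetical helpers, the minimum of 0.4.29 is inferred from "removes support for versions 0.4.28 and lower", and no upper bound is modelled because the thread does not state one. The actual pins live in the project's requirements files.

```python
from itertools import takewhile

# Assumed from the commit message above; not copied from the real requirements.
MINIMUM_JAX = (0, 4, 29)
EXCLUDED_JAX = {(0, 7, 0), (0, 7, 1)}


def as_tuple(version):
    """Convert a version string such as '0.7.1' to a tuple of three ints."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only leading digits so pre-release suffixes do not break parsing.
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def is_supported_jax(version):
    """Return True if ``version`` passes the assumed minimum and exclusion list."""
    parsed = as_tuple(version)
    return parsed >= MINIMUM_JAX and parsed not in EXCLUDED_JAX


for candidate in ["0.4.28", "0.4.29", "0.6.2", "0.7.0", "0.7.1", "0.7.2"]:
    print(candidate, is_supported_jax(candidate))
```

A check like this fails loudly at import time on an unsupported version, which matches the preference expressed in the thread for loud errors over silent bugs.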