
Conversation

@dpanici
Collaborator

@dpanici dpanici commented Sep 20, 2025

Changes in #1869 effectively bumped the minimum JAX version from 0.4.29 to 0.6.2. This PR adds a compatibility layer and removes support for versions 0.4.28 and lower, which DESC did not actually support.

Resolves #1925

@dpanici dpanici requested review from a team, YigitElma, ddudt, f0uriest, rahulgaur104 and unalmis and removed request for a team September 20, 2025 21:25
@dpanici dpanici added the test_jax Run tests against different versions of JAX label Sep 20, 2025
… is different api before jax 0.6.2"

This reverts commit 6fdb6e7.
…sed with vmap (across memory in single device) (#1869)"

This reverts commit 739598c.
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Sep 20, 2025
@github-actions
Contributor

github-actions bot commented Sep 20, 2025

Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 3.40 % | 3.908e+03 | 4.041e+03 | 133.04 | 37.66 | 33.64 |
| test_proximal_jac_w7x_with_eq_update | 0.63 % | 6.828e+03 | 6.871e+03 | 42.95 | 173.83 | 173.53 |
| test_proximal_freeb_jac | 0.30 % | 1.322e+04 | 1.326e+04 | 39.00 | 86.43 | 86.13 |
| test_proximal_freeb_jac_blocked | 0.63 % | 7.652e+03 | 7.701e+03 | 48.16 | 78.82 | 78.94 |
| test_proximal_freeb_jac_batched | -0.13 % | 7.623e+03 | 7.613e+03 | -9.88 | 78.61 | 78.82 |
| test_proximal_jac_ripple | 0.00 % | 7.728e+03 | 7.728e+03 | 0.34 | 70.11 | 70.52 |
| test_proximal_jac_ripple_spline | 1.46 % | 3.639e+03 | 3.692e+03 | 53.25 | 70.41 | 69.92 |
| test_eq_solve | 2.89 % | 2.041e+03 | 2.100e+03 | 59.04 | 115.92 | 113.81 |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.
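
As a reading aid for the memory table, the %Δ column appears to be the memory difference divided by the master value. This is an assumption inferred from the numbers shown, not taken from the benchmark workflow itself, and the function name `percent_delta` is hypothetical. A minimal sketch:

```python
def percent_delta(master_mb, delta_mb):
    """Relative change in percent, assuming %Δ = Δ / master * 100."""
    return delta_mb / master_mb * 100.0


# Values copied from the test_objective_jac_w7x and test_eq_solve rows.
w7x = percent_delta(3.908e03, 133.04)
eq_solve = percent_delta(2.041e03, 59.04)
print(f"{w7x:.2f} % {eq_solve:.2f} %")
```

Under that assumption the two rows reproduce the reported 3.40 % and 2.89 %.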

@unalmis
Collaborator

unalmis commented Sep 20, 2025

some comments

  1. If a full reversion is being done, should make a new PR that also includes the reversion of this reversion.
  2. It sounds like this reversion would prevent use of DESC on JAX versions beyond 0.6.2?
  3. The changes we backported from JAX (I think they also include some JAX versions before 0.6.2) include JAX's bugfixes for various batching issues. As this PR intends to revert those fixes, there is a danger of shipping a Frankenstein version of the code. If the JAX version is one that assumes those are fixed and they aren't, we would get subtle issues.
  4. It is hard enough to convince developers of other dependencies to fix their bugs. Anything short of creating a GitHub issue with the actual bugfix, or even more difficult, with a minimal working example, always gets ignored. That is my experience with some JAX dependencies at least. I doubt anyone will take "Tracking bugs I reported in JAX" (#1599) seriously if we mix JAX versions, for example. I am especially hesitant because JAX is still fixing batching issues that caused subtle bugs in DESC, such as "DFTInterpolator for singular integral gives bad results" (#1522), and that may still affect "Patch to work around discretization error impeding optimizer" (#1894) in some way. These took a very long time to isolate and delayed work.

so basically be careful

@codecov

codecov bot commented Sep 20, 2025

Codecov Report

❌ Patch coverage is 22.22222% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.77%. Comparing base (8770643) to head (b2f4c4a).
⚠️ Report is 62 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| ------------------------ | ------- | ----- |
| desc/batching.py | 22.22% | 14 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1926      +/-   ##
==========================================
- Coverage   95.80%   95.77%   -0.04%     
==========================================
  Files         100      100              
  Lines       27561    27541      -20     
==========================================
- Hits        26405    26377      -28     
- Misses       1156     1164       +8     
| Files with missing lines | Coverage Δ |
| ------------------------ | ---------- |
| desc/batching.py | 82.24% <22.22%> (-6.12%) ⬇️ |

... and 2 files with indirect coverage changes


@dpanici
Collaborator Author

dpanici commented Sep 21, 2025

some comments

  1. If a full reversion is being done, should make a new PR that also includes the reversion of this reversion.
  2. It sounds like this reversion would prevent use of DESC on JAX versions beyond 0.6.2?
  3. The changes we backported from JAX (I think they also include some JAX versions before 0.6.2) include JAX's bugfixes for various batching issues. As this PR intends to revert those fixes, there is a danger of shipping a Frankenstein version of the code. If the JAX version is one that assumes those are fixed and they aren't, we would get subtle issues.
  4. It is hard enough to convince developers of other dependencies to fix their bugs. Anything short of creating a GitHub issue with the actual bugfix, or even more difficult, with a minimal working example, always gets ignored. That is my experience with some JAX dependencies at least. I doubt anyone will take "Tracking bugs I reported in JAX" (#1599) seriously if we mix JAX versions, for example. I am especially hesitant because JAX is still fixing batching issues that caused subtle bugs in DESC, such as "DFTInterpolator for singular integral gives bad results" (#1522), and that may still affect "Patch to work around discretization error impeding optimizer" (#1894) in some way. These took a very long time to isolate and delayed work.

so basically be careful

Fair, I need to think more. But if the master version of the code does not currently work with any JAX before 0.6.2, then our listed package requirements are incorrect. Certain clusters also seem not to support newer JAX versions, which means we cannot run the latest DESC correctly there. I am just trying to think of a way around this. I welcome any suggestions, but right now if you install jax 0.6.0 and then run the tests locally on master, there are failures due to import errors, which is problematic in my opinion.

we either need to

  • revert the changes that caused these specific import errors or
  • change our requirements to jax ==0.6.2

@dpanici
Collaborator Author

dpanici commented Sep 21, 2025

some comments

  1. If a full reversion is being done, should make a new PR that also includes the reversion of this reversion.

Can you explain this part? I don't follow.

@rahulgaur104
Collaborator

A trivial solution is to unmerge #1869. That should fix it but we won't have the new sharding-related function.
But I don't really know how much #1869 helps with speed or memory.

@unalmis
Collaborator

unalmis commented Sep 21, 2025

some comments

  1. If a full reversion is being done, should make a new PR that also includes the reversion of this reversion.

Can you explain this part? I don't follow.

i'm saying don't lose the current state of master. switch it to a pr if you're reverting

Collaborator

@unalmis unalmis left a comment


so basically be careful

Fair, I need to think more. I welcome any suggestions
we either need to

  • revert the changes that caused these specific import errors or

  • change our requirements to jax ==0.6.2

Here is how I would have done the original PR if the jax tests had failed there.
There are 3 functions on master that we took from the newest jax version.

If you want to preserve DESC's ability to run old jax versions (even though these old jax versions have bugs that affect DESC), and there is no simple fix, then it's safer to make new legacy batching functions for those 3 functions.

Call those legacy functions only when jax has an old version.
see the top of the desc.batching file to see how to check the jax version.

So update the citation of those 3 functions to include the version you are taking the code from.

if jax > compatible version:
    from jax import batch_and_remainder
else:
    def batch_and....
        """..... jax version xxxx"""
        stuff
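
The version gate suggested above can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the code that was merged: `parse_version`, `installed_version`, `choose_impl`, `native_impl`, `legacy_impl` and `NATIVE_SINCE` are hypothetical names, the threshold is assumed to be inclusive at 0.6.2, and the two placeholder functions stand in for the real batching functions.

```python
import importlib.metadata


def parse_version(text):
    """Return the leading numeric components of a version string as a tuple.

    "0.6.2" -> (0, 6, 2). Parsing stops at the first dot-separated
    component that does not start with a digit, e.g. "dev1".
    """
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package):
    """Return the parsed version of an installed package, or None if absent."""
    try:
        return parse_version(importlib.metadata.version(package))
    except importlib.metadata.PackageNotFoundError:
        return None


# Assumption: the newer API is usable from this version on (inclusive).
NATIVE_SINCE = (0, 6, 2)


def choose_impl(version, native, legacy, native_since=NATIVE_SINCE):
    """Pick `native` when `version` is new enough, otherwise `legacy`."""
    if version is not None and version >= native_since:
        return native
    return legacy


def native_impl():
    """Placeholder for the function taken from the newer release."""
    return "native"


def legacy_impl():
    """Placeholder for a copy pinned to an older release; cite that version here."""
    return "legacy"


jax_version = installed_version("jax")  # None when JAX is not installed
impl = choose_impl(jax_version, native_impl, legacy_impl)
```

Keeping the selection in one function means the legacy copy is only reached on old versions, so newer installs never mix source from two releases.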

@unalmis
Collaborator

unalmis commented Sep 21, 2025

A trivial solution is to unmerge #1869

That may turn your loud error into potentially silent and dangerous bugs.

Backporting changes in an external code to support new features on old versions of the external code is frowned upon because it requires lots of development time to do properly. That's why there are dedicated teams that work on backporting security updates to various software packages when those packages are marked to no longer receive feature updates. Even that is frowned upon though, because the backports also introduce bugs.

The same is true for mixing source code from different versions of external code.

That's why I suggest the thing above.

@dpanici
Collaborator Author

dpanici commented Sep 22, 2025

If you want to preserve DESC's ability to run old jax versions (even though these old jax versions have bugs that affect DESC), and there is no simple fix, then it's safer to make new legacy batching functions for those 3 functions.

More questions, just to remind me: is this specifically about large batch sizes giving buggy results? Did #1869 fix those bugs? I just thought you had said #1869 would be laying the groundwork for future sharding.

@dpanici
Collaborator Author

dpanici commented Sep 22, 2025

So update the citation of those 3 functions to include the version you are taking the code from.

if jax > compatible version:
    from jax import batch_and_remainder
else:
    def batch_and....
        """..... jax version xxxx"""
        stuff

we will do this @YigitElma

@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Sep 22, 2025
@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Sep 22, 2025
unalmis
unalmis previously approved these changes Sep 24, 2025
f0uriest
f0uriest previously approved these changes Sep 24, 2025
Member

@f0uriest f0uriest left a comment


@unalmis is there any need from your end to bump our minimum JAX version? I.e., are there bugs you've found and worked around in DESC that we could avoid if we required a newer JAX?

Otherwise I'm fine with this

@unalmis
Collaborator

unalmis commented Sep 24, 2025

I try not to add workarounds because they would pollute our code base. From memory, there's a series of bugs with FFT accuracy before 0.6.2 and with AD when indexing multi-dim arrays. There are 2 in an open PR, but that's because the versions with bug fixes aren't out yet.

@f0uriest
Member

Ok do we think we hit any of those paths in desc currently?

Also we should note the new min/max JAX versions in the changelog

@unalmis
Collaborator

unalmis commented Sep 24, 2025

yes

@dpanici dpanici dismissed stale reviews from f0uriest and unalmis via 7cf2e7a September 24, 2025 18:42
unalmis
unalmis previously approved these changes Sep 24, 2025
@dpanici
Collaborator Author

dpanici commented Sep 24, 2025

Should we update docs to say that some bugs exist before 0.6.2 and recommend 0.6.2?

@unalmis
Collaborator

unalmis commented Sep 24, 2025

Should we update docs to say that some bugs exist before 0.6.2 and recommend 0.6.2?

if jax doesn't automatically install 0.6.2 they can't do anything anyway

rahulgaur104
rahulgaur104 previously approved these changes Sep 24, 2025
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@dpanici
Collaborator Author

dpanici commented Sep 24, 2025

Ok no more changes, once tests pass and we confirm they all are passing we can merge

@unalmis unalmis added the run_benchmarks Run timing benchmarks on this PR against current master branch label Sep 25, 2025
@github-actions
Contributor

github-actions bot commented Sep 25, 2025

| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_midres | -0.70 +/- 4.55 | -4.87e-03 +/- 3.17e-02 | 6.91e-01 +/- 1.7e-02 | 6.96e-01 +/- 2.7e-02 |
| test_build_transform_fft_highres | -0.00 +/- 2.47 | -4.55e-05 +/- 2.30e-02 | 9.32e-01 +/- 1.3e-02 | 9.32e-01 +/- 1.9e-02 |
| test_equilibrium_init_lowres | -0.90 +/- 1.46 | -4.16e-02 +/- 6.76e-02 | 4.60e+00 +/- 5.1e-02 | 4.64e+00 +/- 4.5e-02 |
| test_objective_compile_atf | -0.03 +/- 2.86 | -1.80e-03 +/- 1.77e-01 | 6.18e+00 +/- 7.1e-02 | 6.19e+00 +/- 1.6e-01 |
| test_objective_compute_atf | +7.73 +/- 25.31 | +1.82e-04 +/- 5.95e-04 | 2.53e-03 +/- 5.3e-04 | 2.35e-03 +/- 2.8e-04 |
| test_objective_jac_atf | +1.10 +/- 1.35 | +1.93e-02 +/- 2.36e-02 | 1.77e+00 +/- 1.5e-02 | 1.75e+00 +/- 1.9e-02 |
| test_perturb_1 | -0.35 +/- 2.44 | -4.99e-02 +/- 3.43e-01 | 1.40e+01 +/- 2.4e-01 | 1.41e+01 +/- 2.4e-01 |
| test_proximal_jac_atf | -0.53 +/- 1.37 | -3.00e-02 +/- 7.71e-02 | 5.60e+00 +/- 6.0e-02 | 5.63e+00 +/- 4.8e-02 |
| test_proximal_freeb_compute | -0.05 +/- 2.58 | -8.92e-05 +/- 4.23e-03 | 1.64e-01 +/- 3.6e-03 | 1.64e-01 +/- 2.2e-03 |
| test_solve_fixed_iter | -0.70 +/- 1.45 | -2.07e-01 +/- 4.27e-01 | 2.91e+01 +/- 2.7e-01 | 2.93e+01 +/- 3.3e-01 |
| test_objective_compute_ripple | +0.19 +/- 0.97 | +4.88e-03 +/- 2.55e-02 | 2.63e+00 +/- 1.7e-02 | 2.62e+00 +/- 1.9e-02 |
| test_objective_grad_ripple | -0.87 +/- 1.09 | -4.13e-02 +/- 5.17e-02 | 4.72e+00 +/- 2.7e-02 | 4.76e+00 +/- 4.4e-02 |
| test_build_transform_fft_lowres | +4.25 +/- 5.83 | +2.38e-02 +/- 3.26e-02 | 5.84e-01 +/- 3.1e-02 | 5.60e-01 +/- 1.1e-02 |
| test_equilibrium_init_medres | +3.00 +/- 3.61 | +1.49e-01 +/- 1.79e-01 | 5.10e+00 +/- 1.5e-01 | 4.96e+00 +/- 9.1e-02 |
| test_equilibrium_init_highres | +0.15 +/- 3.93 | +8.54e-03 +/- 2.24e-01 | 5.70e+00 +/- 7.1e-02 | 5.69e+00 +/- 2.1e-01 |
| test_objective_compile_dshape_current | -3.13 +/- 1.70 | -1.11e-01 +/- 6.03e-02 | 3.44e+00 +/- 4.2e-02 | 3.55e+00 +/- 4.3e-02 |
| test_objective_compute_dshape_current | +9.33 +/- 8.60 | +6.82e-05 +/- 6.28e-05 | 7.99e-04 +/- 5.1e-05 | 7.31e-04 +/- 3.6e-05 |
| test_objective_jac_dshape_current | -0.73 +/- 19.91 | -2.40e-04 +/- 6.58e-03 | 3.28e-02 +/- 4.1e-03 | 3.31e-02 +/- 5.2e-03 |
| test_perturb_2 | -1.48 +/- 1.50 | -2.62e-01 +/- 2.66e-01 | 1.74e+01 +/- 2.3e-01 | 1.77e+01 +/- 1.3e-01 |
| test_proximal_jac_atf_with_eq_update | -0.47 +/- 0.85 | -6.39e-02 +/- 1.16e-01 | 1.36e+01 +/- 1.0e-01 | 1.37e+01 +/- 5.1e-02 |
| test_proximal_freeb_jac | +0.78 +/- 1.75 | +3.90e-02 +/- 8.73e-02 | 5.02e+00 +/- 7.7e-02 | 4.98e+00 +/- 4.1e-02 |
| test_solve_fixed_iter_compiled | +1.01 +/- 1.27 | +1.71e-01 +/- 2.14e-01 | 1.70e+01 +/- 1.5e-01 | 1.68e+01 +/- 1.5e-01 |
| test_LinearConstraintProjection_build | +2.10 +/- 3.03 | +1.73e-01 +/- 2.49e-01 | 8.37e+00 +/- 1.6e-01 | 8.20e+00 +/- 1.9e-01 |
| test_objective_compute_ripple_spline | +1.87 +/- 3.29 | +5.40e-03 +/- 9.48e-03 | 2.94e-01 +/- 6.3e-03 | 2.88e-01 +/- 7.1e-03 |
| test_objective_grad_ripple_spline | -0.61 +/- 1.73 | -6.84e-03 +/- 1.93e-02 | 1.11e+00 +/- 8.8e-03 | 1.11e+00 +/- 1.7e-02 |

@dpanici dpanici requested a review from f0uriest September 25, 2025 15:14
@dpanici dpanici merged commit 684de2e into master Sep 25, 2025
85 of 86 checks passed
@dpanici dpanici deleted the dp/hotfix-batching branch September 25, 2025 17:27
DMCXE pushed a commit to DMCXE/DESC-OOPS that referenced this pull request Oct 14, 2025
Changes in PlasmaControl#1869 effectively bumped the minimum JAX version from 0.4.29
to 0.6.2. This PR adds a compatibility layer and removes support for
versions 0.4.28 and lower, which DESC did not actually support.

- During testing, we noticed that we no longer support 0.4.28 and older.
Since they are more than a year old, I just removed them.
- JAX 0.7.0 passes the tests but as noted in
patrick-kidger/diffrax#680 can cause problems.
New requirements skip that.
- JAX 0.7.1 fails tests and as noted in
patrick-kidger/diffrax#680 has poor
performance for loops. Again, skipping it.
- JAX 0.7.2 seems to be fine, but there are likely unknown bugs and it has
issues with PlasmaControl#1857
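
To illustrate the version policy in the list above, here is a minimal sketch of a check that accepts 0.4.29 and newer while skipping 0.7.0 and 0.7.1. The names `MIN_VERSION`, `SKIPPED` and `is_supported` are hypothetical, and no upper bound is assumed because the text does not state one. The actual requirements file may express this differently.

```python
# Assumptions drawn from the notes above: 0.4.29 is the oldest supported
# release, and 0.7.0 and 0.7.1 are skipped. No upper bound is applied.
MIN_VERSION = (0, 4, 29)
SKIPPED = {(0, 7, 0), (0, 7, 1)}


def is_supported(version):
    """Return True if a plain 'X.Y.Z' version string passes the policy."""
    parsed = tuple(int(part) for part in version.split(".")[:3])
    return parsed >= MIN_VERSION and parsed not in SKIPPED


for candidate in ("0.4.28", "0.4.29", "0.6.2", "0.7.0", "0.7.1", "0.7.2"):
    print(candidate, is_supported(candidate))
```

The check only handles plain numeric versions; pre-release suffixes such as `rc1` would need a real version parser.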

Resolves PlasmaControl#1925

---------

Co-authored-by: YigitElma <yigitelmacioglu@gmail.com>
Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com>
Co-authored-by: Kaya Unalmis <kayaunalmis@proton.me>
Co-authored-by: Rahul Gaur <19224702+rahulgaur104@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Labels

  • run_benchmarks: Run timing benchmarks on this PR against current master branch
  • skip_changelog: No need to update changelog on this PR
  • test_jax: Run tests against different versions of JAX

Development

Successfully merging this pull request may close these issues.

Batched equilibrium solve throws an error

6 participants