Skip to content

Conversation

@YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented Aug 10, 2025

Includes some of the changes realized during #1820

  • The lambda function definition in the field_line_integrate can cause recompilation if one uses it in an objective. This PR moves the lambda function out, passing all the arguments in args.
  • Uses vmap_chunked instead of jnp.vectorize
  • Adds more arguments of diffrax to docs (will be useful for objectives)
  • Renames maxsteps to max_steps and updates the docs
  • Bumps the minimum diffrax version to 0.6.0 (Event is added in that release)

Resolves #1759

@github-actions
Copy link
Contributor

github-actions bot commented Aug 10, 2025

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |    6.74 %    |     3.801e+03      |     4.057e+03      |    256.33    |       35.76        |       32.24        |
  test_proximal_jac_w7x_with_eq_update   |   -0.54 %    |     6.794e+03      |     6.757e+03      |    -36.72    |       168.19       |       162.54       |
  test_proximal_freeb_jac                |    0.13 %    |     1.319e+04      |     1.321e+04      |    16.90     |       80.34        |       77.31        |
  test_proximal_freeb_jac_blocked        |    0.14 %    |     7.584e+03      |     7.595e+03      |    10.96     |       71.62        |       69.21        |
  test_proximal_freeb_jac_batched        |   -0.39 %    |     7.673e+03      |     7.643e+03      |    -30.02    |       71.52        |       69.21        |
  test_proximal_jac_ripple               |   -1.17 %    |     7.659e+03      |     7.570e+03      |    -89.99    |       73.96        |       71.43        |
  test_proximal_jac_ripple_spline        |    0.75 %    |     3.402e+03      |     3.427e+03      |    25.67     |       74.14        |       75.55        |
  test_eq_solve                          |   -1.72 %    |     2.053e+03      |     2.018e+03      |    -35.34    |       125.89       |       124.77       |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@dpanici
Copy link
Collaborator

dpanici commented Sep 3, 2025

Should I just point #1855 to this PR?

@YigitElma
Copy link
Collaborator Author

Should I just point #1855 to this PR?

Yeah, or you can just merge these changes. But also if you are gonna make a new function, I don't know how much that would help.

@YigitElma YigitElma marked this pull request as ready for review September 4, 2025 21:15
@YigitElma YigitElma requested review from a team, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team September 4, 2025 21:15
@codecov
Copy link

codecov bot commented Sep 4, 2025

Codecov Report

❌ Patch coverage is 84.00000% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.74%. Comparing base (019a955) to head (ada3aae).
⚠️ Report is 59 commits behind head on master.

Files with missing lines Patch % Lines
desc/plotting.py 57.89% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1839      +/-   ##
==========================================
- Coverage   95.76%   95.74%   -0.03%     
==========================================
  Files         100      100              
  Lines       27541    27565      +24     
==========================================
+ Hits        26375    26392      +17     
- Misses       1166     1173       +7     
Files with missing lines Coverage Δ
desc/magnetic_fields/_core.py 96.47% <100.00%> (+0.02%) ⬆️
desc/plotting.py 95.41% <57.89%> (-0.48%) ⬇️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@YigitElma YigitElma requested a review from Copilot September 5, 2025 00:59
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR improves the field_line_integrate function to address compilation issues and provide better control over integration parameters.

  • Refactors lambda function to prevent recompilation during JAX transforms
  • Replaces jnp.vectorize with vmap_chunked for better performance
  • Renames maxsteps parameter to max_steps for consistency with diffrax

Reviewed Changes

Copilot reviewed 4 out of 5 changed files in this pull request and generated 2 comments.

File Description
desc/magnetic_fields/_core.py Major refactoring of field_line_integrate function, moving lambda out and adding new parameters
tests/test_magnetic_fields.py Updates tests to use new max_steps parameter and adds recompilation test
tests/test_plotting.py Removes chunk_size parameter and adds options dict test
CHANGELOG.md Documents API changes and bug fixes

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@YigitElma YigitElma self-assigned this Sep 5, 2025
@YigitElma YigitElma added the run_benchmarks Run timing benchmarks on this PR against current master branch label Sep 5, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Sep 5, 2025

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +0.04 +/- 2.98     | +2.66e-04 +/- 1.96e-02 |  6.58e-01 +/- 1.5e-02  |  6.57e-01 +/- 1.2e-02  |
 test_build_transform_fft_highres        |     +0.66 +/- 2.16     | +5.94e-03 +/- 1.93e-02 |  9.00e-01 +/- 1.1e-02  |  8.94e-01 +/- 1.6e-02  |
 test_equilibrium_init_lowres            |     -0.08 +/- 1.48     | -3.35e-03 +/- 6.35e-02 |  4.30e+00 +/- 4.8e-02  |  4.30e+00 +/- 4.1e-02  |
 test_objective_compile_atf              |     +0.64 +/- 2.86     | +3.81e-02 +/- 1.71e-01 |  6.02e+00 +/- 1.2e-01  |  5.98e+00 +/- 1.2e-01  |
 test_objective_compute_atf              |     +0.51 +/- 3.25     | +1.01e-05 +/- 6.44e-05 |  1.99e-03 +/- 3.2e-05  |  1.98e-03 +/- 5.6e-05  |
 test_objective_jac_atf                  |     -1.66 +/- 2.97     | -2.90e-02 +/- 5.19e-02 |  1.71e+00 +/- 2.6e-02  |  1.74e+00 +/- 4.5e-02  |
 test_perturb_1                          |     -0.45 +/- 2.32     | -6.09e-02 +/- 3.14e-01 |  1.35e+01 +/- 2.1e-01  |  1.35e+01 +/- 2.3e-01  |
 test_proximal_jac_atf                   |     +0.02 +/- 1.53     | +1.01e-03 +/- 8.52e-02 |  5.56e+00 +/- 6.4e-02  |  5.56e+00 +/- 5.6e-02  |
 test_proximal_freeb_compute             |     +0.96 +/- 2.53     | +1.54e-03 +/- 4.05e-03 |  1.62e-01 +/- 2.7e-03  |  1.61e-01 +/- 3.0e-03  |
 test_solve_fixed_iter                   |     +0.17 +/- 1.30     | +4.93e-02 +/- 3.69e-01 |  2.84e+01 +/- 3.3e-01  |  2.84e+01 +/- 1.7e-01  |
 test_objective_compute_ripple           |     +0.47 +/- 1.14     | +1.21e-02 +/- 2.97e-02 |  2.61e+00 +/- 1.3e-02  |  2.60e+00 +/- 2.7e-02  |
 test_objective_grad_ripple              |     +0.01 +/- 1.91     | +3.03e-04 +/- 8.97e-02 |  4.69e+00 +/- 5.6e-02  |  4.69e+00 +/- 7.0e-02  |
 test_build_transform_fft_lowres         |     -0.23 +/- 2.75     | -1.25e-03 +/- 1.52e-02 |  5.53e-01 +/- 1.1e-02  |  5.54e-01 +/- 1.1e-02  |
 test_equilibrium_init_medres            |     +0.02 +/- 0.82     | +9.65e-04 +/- 3.97e-02 |  4.82e+00 +/- 2.4e-02  |  4.82e+00 +/- 3.1e-02  |
 test_equilibrium_init_highres           |     +1.45 +/- 1.95     | +7.86e-02 +/- 1.06e-01 |  5.52e+00 +/- 1.0e-01  |  5.44e+00 +/- 3.0e-02  |
 test_objective_compile_dshape_current   |     +0.03 +/- 0.94     | +1.15e-03 +/- 3.12e-02 |  3.31e+00 +/- 1.8e-02  |  3.31e+00 +/- 2.6e-02  |
 test_objective_compute_dshape_current   |     -3.55 +/- 6.41     | -2.79e-05 +/- 5.04e-05 |  7.59e-04 +/- 2.3e-05  |  7.87e-04 +/- 4.5e-05  |
 test_objective_jac_dshape_current       |     -0.37 +/- 16.20    | -1.19e-04 +/- 5.22e-03 |  3.21e-02 +/- 2.8e-03  |  3.22e-02 +/- 4.4e-03  |
 test_perturb_2                          |     -1.46 +/- 2.16     | -2.49e-01 +/- 3.68e-01 |  1.68e+01 +/- 1.4e-01  |  1.70e+01 +/- 3.4e-01  |
 test_proximal_jac_atf_with_eq_update    |     -0.73 +/- 0.72     | -9.89e-02 +/- 9.76e-02 |  1.35e+01 +/- 9.3e-02  |  1.36e+01 +/- 3.0e-02  |
 test_proximal_freeb_jac                 |     -0.24 +/- 1.67     | -1.20e-02 +/- 8.36e-02 |  4.99e+00 +/- 4.9e-02  |  5.00e+00 +/- 6.8e-02  |
 test_solve_fixed_iter_compiled          |     -0.32 +/- 1.20     | -5.38e-02 +/- 2.05e-01 |  1.70e+01 +/- 5.8e-02  |  1.71e+01 +/- 2.0e-01  |
 test_LinearConstraintProjection_build   |     -2.05 +/- 3.27     | -1.74e-01 +/- 2.79e-01 |  8.34e+00 +/- 9.8e-02  |  8.51e+00 +/- 2.6e-01  |
 test_objective_compute_ripple_spline    |     +0.14 +/- 2.34     | +4.18e-04 +/- 6.89e-03 |  2.94e-01 +/- 5.1e-03  |  2.94e-01 +/- 4.7e-03  |
 test_objective_grad_ripple_spline       |     -0.24 +/- 3.24     | -2.70e-03 +/- 3.60e-02 |  1.11e+00 +/- 1.7e-02  |  1.11e+00 +/- 3.1e-02  |

G
i
t
h
u
b

C
I

p
e
r
f
o
r
m
a
n
c
e

c
a
n

b
e

n
o
i
s
y
.

W
h
e
n

e
v
a
l
u
a
t
i
n
g

t
h
e

b
e
n
c
h
m
a
r
k
s
,

d
e
v
e
l
o
p
e
r
s

s
h
o
u
l
d

t
a
k
e

t
h
i
s

i
n
t
o

a
c
c
o
u
n
t
.


# check if it is jittable
r, z = jit(field_line_integrate)(r0, z0, phis, field)
r, z = jit(fun0)(r0, z0, field)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think you don't want to jit fun0 here if you're trying to check recompilation. This should pass on master I think? since its only caching based on r0, z0 and field.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the test, but basically I need to close over some of the variables that are not hashable (declaring static doesn't work). One option could be writing a dummy objective, and calling compute and grad, because basically what I care is not recompiling the objective function.

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still reviewing but add bs_chunk_size to retain that functionality

@YigitElma YigitElma requested review from ddudt and dpanici September 26, 2025 17:05
@YigitElma YigitElma requested a review from dpanici October 1, 2025 02:33
f0uriest
f0uriest previously approved these changes Oct 1, 2025
@YigitElma YigitElma merged commit 5cf5ca0 into master Oct 1, 2025
27 of 28 checks passed
@YigitElma YigitElma deleted the yge/field-integrate branch October 1, 2025 17:34
@unalmis unalmis added the test_jax Run tests against different versions of JAX label Oct 2, 2025
@unalmis unalmis restored the yge/field-integrate branch October 2, 2025 06:25
@unalmis unalmis added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 2, 2025
@unalmis unalmis deleted the yge/field-integrate branch October 2, 2025 06:27
DMCXE pushed a commit to DMCXE/DESC-OOPS that referenced this pull request Oct 14, 2025
Includes some of the changes realized during PlasmaControl#1820 

- The lambda function definition in the `field_line_integrate` can cause
recompilation if one uses it in an objective. This PR moves the lambda
function out, passing all the arguments in `args`.
- Uses `vmap_chunked` instead of `jnp.vectorize`
- Adds more arguments of diffrax to docs (will be useful for objectives)
- Renames `maxsteps` to `max_steps` and updates the docs
- Bumps the minimum `diffrax` version to 0.6.0 (`Event` is added in that
release)

Resolves PlasmaControl#1759
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

easy Short and simple to code or review run_benchmarks Run timing benchmarks on this PR against current master branch test_jax Run tests against different versions of JAX

Projects

None yet

Development

Successfully merging this pull request may close these issues.

diffrax.diffeqsolve(throw=False)

6 participants