-
Notifications
You must be signed in to change notification settings - Fork 40
Updates to field_line_integrate
#1839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…input/docs, add tests
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | 6.74 % | 3.801e+03 | 4.057e+03 | 256.33 | 35.76 | 32.24 |
test_proximal_jac_w7x_with_eq_update | -0.54 % | 6.794e+03 | 6.757e+03 | -36.72 | 168.19 | 162.54 |
test_proximal_freeb_jac | 0.13 % | 1.319e+04 | 1.321e+04 | 16.90 | 80.34 | 77.31 |
test_proximal_freeb_jac_blocked | 0.14 % | 7.584e+03 | 7.595e+03 | 10.96 | 71.62 | 69.21 |
test_proximal_freeb_jac_batched | -0.39 % | 7.673e+03 | 7.643e+03 | -30.02 | 71.52 | 69.21 |
test_proximal_jac_ripple | -1.17 % | 7.659e+03 | 7.570e+03 | -89.99 | 73.96 | 71.43 |
test_proximal_jac_ripple_spline | 0.75 % | 3.402e+03 | 3.427e+03 | 25.67 | 74.14 | 75.55 |
test_eq_solve | -1.72 % | 2.053e+03 | 2.018e+03 | -35.34 | 125.89 | 124.77 |For the memory plots, go to the summary of |
|
Should I just point #1855 to this PR? |
Yeah, or you can just merge these changes. But also if you are gonna make a new function, I don't know how much that would help. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1839 +/- ##
==========================================
- Coverage 95.76% 95.74% -0.03%
==========================================
Files 100 100
Lines 27541 27565 +24
==========================================
+ Hits 26375 26392 +17
- Misses 1166 1173 +7
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR improves the field_line_integrate function to address compilation issues and provide better control over integration parameters.
- Refactors lambda function to prevent recompilation during JAX transforms
- Replaces
jnp.vectorizewithvmap_chunkedfor better performance - Renames
maxstepsparameter tomax_stepsfor consistency with diffrax
Reviewed Changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| desc/magnetic_fields/_core.py | Major refactoring of field_line_integrate function, moving lambda out and adding new parameters |
| tests/test_magnetic_fields.py | Updates tests to use new max_steps parameter and adds recompilation test |
| tests/test_plotting.py | Removes chunk_size parameter and adds options dict test |
| CHANGELOG.md | Documents API changes and bug fixes |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | +0.04 +/- 2.98 | +2.66e-04 +/- 1.96e-02 | 6.58e-01 +/- 1.5e-02 | 6.57e-01 +/- 1.2e-02 |
test_build_transform_fft_highres | +0.66 +/- 2.16 | +5.94e-03 +/- 1.93e-02 | 9.00e-01 +/- 1.1e-02 | 8.94e-01 +/- 1.6e-02 |
test_equilibrium_init_lowres | -0.08 +/- 1.48 | -3.35e-03 +/- 6.35e-02 | 4.30e+00 +/- 4.8e-02 | 4.30e+00 +/- 4.1e-02 |
test_objective_compile_atf | +0.64 +/- 2.86 | +3.81e-02 +/- 1.71e-01 | 6.02e+00 +/- 1.2e-01 | 5.98e+00 +/- 1.2e-01 |
test_objective_compute_atf | +0.51 +/- 3.25 | +1.01e-05 +/- 6.44e-05 | 1.99e-03 +/- 3.2e-05 | 1.98e-03 +/- 5.6e-05 |
test_objective_jac_atf | -1.66 +/- 2.97 | -2.90e-02 +/- 5.19e-02 | 1.71e+00 +/- 2.6e-02 | 1.74e+00 +/- 4.5e-02 |
test_perturb_1 | -0.45 +/- 2.32 | -6.09e-02 +/- 3.14e-01 | 1.35e+01 +/- 2.1e-01 | 1.35e+01 +/- 2.3e-01 |
test_proximal_jac_atf | +0.02 +/- 1.53 | +1.01e-03 +/- 8.52e-02 | 5.56e+00 +/- 6.4e-02 | 5.56e+00 +/- 5.6e-02 |
test_proximal_freeb_compute | +0.96 +/- 2.53 | +1.54e-03 +/- 4.05e-03 | 1.62e-01 +/- 2.7e-03 | 1.61e-01 +/- 3.0e-03 |
test_solve_fixed_iter | +0.17 +/- 1.30 | +4.93e-02 +/- 3.69e-01 | 2.84e+01 +/- 3.3e-01 | 2.84e+01 +/- 1.7e-01 |
test_objective_compute_ripple | +0.47 +/- 1.14 | +1.21e-02 +/- 2.97e-02 | 2.61e+00 +/- 1.3e-02 | 2.60e+00 +/- 2.7e-02 |
test_objective_grad_ripple | +0.01 +/- 1.91 | +3.03e-04 +/- 8.97e-02 | 4.69e+00 +/- 5.6e-02 | 4.69e+00 +/- 7.0e-02 |
test_build_transform_fft_lowres | -0.23 +/- 2.75 | -1.25e-03 +/- 1.52e-02 | 5.53e-01 +/- 1.1e-02 | 5.54e-01 +/- 1.1e-02 |
test_equilibrium_init_medres | +0.02 +/- 0.82 | +9.65e-04 +/- 3.97e-02 | 4.82e+00 +/- 2.4e-02 | 4.82e+00 +/- 3.1e-02 |
test_equilibrium_init_highres | +1.45 +/- 1.95 | +7.86e-02 +/- 1.06e-01 | 5.52e+00 +/- 1.0e-01 | 5.44e+00 +/- 3.0e-02 |
test_objective_compile_dshape_current | +0.03 +/- 0.94 | +1.15e-03 +/- 3.12e-02 | 3.31e+00 +/- 1.8e-02 | 3.31e+00 +/- 2.6e-02 |
test_objective_compute_dshape_current | -3.55 +/- 6.41 | -2.79e-05 +/- 5.04e-05 | 7.59e-04 +/- 2.3e-05 | 7.87e-04 +/- 4.5e-05 |
test_objective_jac_dshape_current | -0.37 +/- 16.20 | -1.19e-04 +/- 5.22e-03 | 3.21e-02 +/- 2.8e-03 | 3.22e-02 +/- 4.4e-03 |
test_perturb_2 | -1.46 +/- 2.16 | -2.49e-01 +/- 3.68e-01 | 1.68e+01 +/- 1.4e-01 | 1.70e+01 +/- 3.4e-01 |
test_proximal_jac_atf_with_eq_update | -0.73 +/- 0.72 | -9.89e-02 +/- 9.76e-02 | 1.35e+01 +/- 9.3e-02 | 1.36e+01 +/- 3.0e-02 |
test_proximal_freeb_jac | -0.24 +/- 1.67 | -1.20e-02 +/- 8.36e-02 | 4.99e+00 +/- 4.9e-02 | 5.00e+00 +/- 6.8e-02 |
test_solve_fixed_iter_compiled | -0.32 +/- 1.20 | -5.38e-02 +/- 2.05e-01 | 1.70e+01 +/- 5.8e-02 | 1.71e+01 +/- 2.0e-01 |
test_LinearConstraintProjection_build | -2.05 +/- 3.27 | -1.74e-01 +/- 2.79e-01 | 8.34e+00 +/- 9.8e-02 | 8.51e+00 +/- 2.6e-01 |
test_objective_compute_ripple_spline | +0.14 +/- 2.34 | +4.18e-04 +/- 6.89e-03 | 2.94e-01 +/- 5.1e-03 | 2.94e-01 +/- 4.7e-03 |
test_objective_grad_ripple_spline | -0.24 +/- 3.24 | -2.70e-03 +/- 3.60e-02 | 1.11e+00 +/- 1.7e-02 | 1.11e+00 +/- 3.1e-02 |G C p c b n W e t b d s t t i a |
tests/test_magnetic_fields.py
Outdated
|
|
||
| # check if it is jittable | ||
| r, z = jit(field_line_integrate)(r0, z0, phis, field) | ||
| r, z = jit(fun0)(r0, z0, field) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think you don't want to jit fun0 here if you're trying to check recompilation. This should pass on master I think? since its only caching based on r0, z0 and field.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the test, but basically I need to close over some of the variables that are not hashable (declaring static doesn't work). One option could be writing a dummy objective, and calling compute and grad, because basically what I care is not recompiling the objective function.
dpanici
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still reviewing but add bs_chunk_size to retain that functionality
Includes some of the changes realized during PlasmaControl#1820 - The lambda function definition in the `field_line_integrate` can cause recompilation if one uses it in an objective. This PR moves the lambda function out, passing all the arguments in `args`. - Uses `vmap_chunked` instead of `jnp.vectorize` - Adds more arguments of diffrax to docs (will be useful for objectives) - Renames `maxsteps` to `max_steps` and updates the docs - Bumps the minimum `diffrax` version to 0.6.0 (`Event` is added in that release) Resolves PlasmaControl#1759
Includes some of the changes realized during #1820
field_line_integratecan cause recompilation if one uses it in an objective. This PR moves the lambda function out, passing all the arguments inargs.vmap_chunkedinstead ofjnp.vectorizemaxstepstomax_stepsand updates the docsdiffraxversion to 0.6.0 (Eventis added in that release)Resolves #1759