Enable earlier exit from Newton-Chord including one step exit for non-Cauchy linear problems#219
jpbrodrick89 wants to merge 3 commits into patrick-kidger:main
Conversation
Let me know if you'd like me to add tests to confirm one step exit for non-Cauchy and linear problems.
patrick-kidger
left a comment
Sorry, this has taken me a long time to get around to 😅
> Note that after e9c9abb non-Cauchy NEVER performs more than two steps
This is not intended. I've not tried digging into this but it's probably an issue with the divergence detection being triggered too easily? IIRC scipy and Julia both have this as well (with different thresholds).
> I don't have access to Hairer and Wanner
Minor side note: it's diffrax.VeryChord that is the H&W algorithm. The difference from optx.Chord is that it re-uses the initial Jacobian linearization even between different stages of the Runge-Kutta solver.
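For readers following along, the chord idea being discussed can be sketched roughly like this (a hypothetical NumPy illustration, not the optimistix/diffrax API — `chord_solve`, `f`, and `jac` are names invented here): linearize once at the starting point and re-use that single frozen Jacobian for every subsequent iteration, so each step costs only a solve.

```python
import numpy as np

def chord_solve(f, jac, y0, steps=10):
    """Chord iteration: re-use one Jacobian linearization for all steps."""
    J0 = jac(y0)  # linearize once at y0...
    y = y0
    for _ in range(steps):
        y = y - np.linalg.solve(J0, f(y))  # ...and re-use J0 every step
    return y
```

For a linear residual the chord and Newton iterations coincide and land on the solution in one step, which is exactly the early-exit opportunity this PR targets.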
I've avoided making too many comments here as it seems that the 'never take more than 2 steps' problem might affect the results here / might need fixing first. Do you have any idea what's causing that?
```diff
 else:
-    f_val = None
+    # In all other cases syncing f can allow earlier exit
+    f_val, _ = fn(y, args)
```
Note that this implies another compilation of fn, which is generally undesirable.
Staring at the suggested implementation of step it looks like fn can actually be compiled again then too; this might be quite a large pessimization?
(Don't forget that in some cases, fn is not a small piece of algebra, but actually a whole diffrax.diffeqsolve or similar. Compilation times matter!)
Not sure I understand this. Aren't y and args the same shape/dtype/tree structure as when called in step with new_y and args? If so, and fn is the same jitted function, shouldn't it already be compiled? What triggers the additional compilation?
Also, for a diffrax sim in particular, isn't compilation largely made up of the time for compiling a single step? The vector field is guaranteed to always be the same function, and diffrax errors if the pytree structure changes and causes recompilation.
Agree that this needs to wait until after we resolve the Hairer-Wanner behaviour; I had a feeling it was a bug and not a feature but couldn't be sure! 😅 I will continue discussion about that on #221 so as not to pollute this PR unnecessarily.
JiwaniZakir
left a comment
The _check_f_small helper in terminate only divides by self.atol, but the rest of the termination logic uses a mixed absolute/relative scale — see how scale is computed just above as self.atol + self.rtol * ω(new_y).call(jnp.abs). For problems where the residual magnitude scales with the solution values, this purely absolute check could produce spurious early exits when atol is loose relative to the actual problem scale, or never trigger when atol is tight. It should likely mirror the diffsize/scale pattern.
Additionally, in the non-Cauchy step branch, the _skip lambda captures fx from the outer scope (the current step's residual before the update), so after step 0 state.f will hold a stale value. This is safe only because non-Cauchy is capped at 2 iterations and terminate gates the _check_f_small path on state.step == 1, but this implicit coupling is fragile — a comment explaining why stale f is acceptable here would prevent future confusion.
Finally, _check_f_small is conditioned on state.step == 1, but for the Cauchy Chord path (else at the bottom of step) f_val, aux = fn(new_y, args) is called unconditionally every step, meaning state.f always holds fresh data. It's worth confirming that the f_small gate in terminate is actually reachable for the Cauchy Chord case, or whether it's dead code for that branch.
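The mixed absolute/relative pattern referred to above could be mirrored roughly as follows (a sketch with illustrative names, not the actual optimistix internals): scale the residual componentwise by `atol + rtol * |y|` before taking the norm, instead of dividing by `atol` alone.

```python
import numpy as np

def f_small(f_val, y, atol, rtol, norm=np.max):
    # mixed absolute/relative scale, mirroring the diffsize/scale pattern
    scale = atol + rtol * np.abs(y)
    return norm(np.abs(f_val) / scale) < 1.0
```

With this form, a residual component of 1e-5 against a solution component of 1e4 still counts as converged under rtol=1e-8, whereas a purely absolute `f / atol` check would reject it.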
I've lost count of the number of times I stared at this code trying to work out how to address the TODO, so I decided to finally grit my teeth and give it a go. It took me three attempts, going down wrong alleyways and being surprised by lax.cond reducing XLA's CSE and massively slowing performance. I think this attempt works and only has a minuscule negative impact on Newton non-Cauchy for nonlinear problems, by replacing jax.linearize with a separate primal call and JacobianLinearOperator.
Note that after e9c9abb non-Cauchy NEVER performs more than two steps. I don't have access to Hairer and Wanner to confirm whether this is intended or not, but I have made this behaviour more explicit and easier to understand. If the current behaviour is incorrect please let me know, as it may mean this PR has a proportionally greater negative impact for the non-Cauchy case if more than 2 steps could happen.
The essential idea here is simply to compute f at the END of each step and check whether it's small, to allow for early exit, saving a linear solve (and a Jacobian materialisation for Newton). However, as we never perform more than two steps for non-Cauchy, it is pointless doing the extra function evaluation there, so I wrap it in a lax.cond. As this only fires once I don't think we're losing out on any material compiler optimisations here, but for very expensive f evaluations (especially if these are somehow more expensive than Jacobian materialisation and linear solve) I might be wrong.
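A framework-free sketch of that guard (names hypothetical; the real code uses jax.lax.cond, under which both branches are traced but only one executes): evaluate f(new_y) at the end of the step only when the result could actually enable an early exit.

```python
def end_of_step_residual(fn, new_y, args, may_exit_early):
    """Sketch: conditionally sync f at the end of a step."""
    if may_exit_early:
        # extra evaluation: lets the terminate check exit before
        # paying for the next linear solve
        f_val, _ = fn(new_y, args)
    else:
        # skip the evaluation entirely when it cannot be used
        f_val = None
    return f_val
```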
Speedup is most noticeable for Newton non-Cauchy and can exceed 25% in the problems I've tested (large dense matrix linear problems). We didn't get the desired 2x speedup because Jacobian materialisation took 3 times longer than the linear solve, and for a linear problem CSE was able to completely eliminate this on the second iteration. After we merge my lineax tridiagonal PR it is likely we can get much closer to the desired 2x speedup. For Chord it is less noticeable as the bottleneck is in solver.init (O(N^3)), but if solver.step (which is O(N^2)) is isolated we still see consistent improvement, especially at smaller matrix sizes.
Newton Cauchy is unaffected by this PR as it would only benefit from very ill-posed Jacobians (basically ones with eigenvalues less than sqrt(2*atol)). In well-posed cases y always converges more slowly than f, so there is no point syncing f and losing out on any benefit jax.linearize might give us.
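To illustrate the linear-problem case from the PR title: for a residual f(y) = A @ y - b, a single Newton step from any starting point lands exactly on the solution, so an end-of-step residual check permits a one-step exit (a NumPy sketch, not the optimistix code).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 4.0 * np.eye(4)  # well-conditioned
b = rng.standard_normal(4)
f = lambda y: A @ y - b  # linear residual: Jacobian is A everywhere

y0 = np.zeros(4)
y1 = y0 - np.linalg.solve(A, f(y0))  # one Newton step
resid = np.max(np.abs(f(y1)))        # residual after one step
```

Checking `resid` at the end of the step skips the second linear solve (and, for Newton, the second Jacobian materialisation) that a Cauchy-style y-difference criterion would otherwise require.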