Support true second order methods #225
jpbrodrick89 wants to merge 4 commits into patrick-kidger:main
Conversation
Note I currently don't support backward-over-backward, because that would typically mean you need a `custom_vjp` of a `custom_vjp`, which I think would be fairly rare. However, if this would actually work for a recursive-checkpointing Diffrax sim then maybe it's worth supporting.
Oh wow! This is fairly gigantic 😅 I have to be honest, I think this is probably out-of-scope here. I think it's much larger than I can commit to maintaining. (Plus, you are currently sending PRs faster than I can find time to review them :p I am very aware of the stack of Lineax PRs I am slowly working through...) This kind of thing might find a good home in some kind of 'optimistixtra' (made-up name). Or possibly we should introduce an […]. Over on the technical side, it's also not immediately clear to me how this differs from the existing […]. WDYT?
I think that's a pretty fair take 😅 I guess I got a bit carried away showing how much ground this can potentially cover, rather than focussing on a digestible incremental approach. In another sense, though, it's probably healthy to see a preview of a minimal working API before committing to it blindly, rather than realising later that it's beyond maintenance appetite. That said, if you do change your mind on this I'd like to allay any pressure on timelines, with my full understanding that this would be a >(>)6 month project that we would just chew through slowly as we have time. (I have no urgent need for this in my own work right now.)
The idea of just passing a root finder to […]
While I am very keen to eventually create a lineaxtra as I pull together enough useful features, I don't have a burning desire to add an optimistixtra as a separate repo to support all on my own right now (not least because writing this PR made me realise how little I know about optimisation algorithms). One alternative: to […]. As such, before we move on it's probably worthwhile exploring "what would be the MVP that lets power users use the existing API to more easily roll their own custom second-order optimisers, plus a minimal example of how to do this that minimises maintenance", rather than "what this could potentially be given many months of work and added maintenance". These are potential options in order of increasing complexity/maintenance; each change to the codebase could be done incrementally in a separate digestible PR.
Personally, I'd probably lean towards 2 being a sensible stopping point, but I'd still be very content with 1. Either of these would mean I could theoretically add `TruncatedCG` to lineaxtra and then add another example to optimistix on how to implement `SteihaugCGDescent` (basically 3 lines of algorithm, all the rest boilerplate/wiring). And of course, if you don't think it's valuable to add such an example at all and would rather keep things as they are, I am also fully sympathetic to that viewpoint and would not be overly disappointed. Eisenstat-Walker: note how `TruncatedCG` allows `rtol` and `atol` to be updated in the […]
I like this characterisation a lot! On your hierarchy of options: as a nit, I confess I don't love the name `AbstractNewtonBase`, what with 'Base' actually just being a synonym for 'Abstract'. :p More seriously, the current iterate-over-steps+accept/reject is something we have kind of copy-pasted between three different solvers at the moment (AbstractQuasiNewton, AbstractBFGS, AbstractGradientDescent). I'd lean towards either copy-pasting this again for Newton, or finding a way to unify these so this code only appears once. If possible, that is... I might be missing something in why I did it this way originally. (NB, I'm also conscious of one other mistake I made here: our 'latest' iterate is actually held in […]) On rtol/atol, I lean towards not making these dynamic, as the dominant use case is actually the opposite way around, I think: see for example how Diffrax allows you to put rtol/atol just on the step size controller, and lets everything else inherit them from there. Fiddling with these is kind of annoying. So if dynamic stuff is needed I'd lean towards using […]
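The kind of unification discussed here could look roughly like the following sketch. All names and structure are purely illustrative, not optimistix's actual classes: the shared iterate loop is written once on an abstract base, and each concrete solver supplies only its initial state and its per-step direction.

```python
import abc


class AbstractIterativeMinimiser(abc.ABC):
    # Hypothetical sketch of the unification idea: the step wiring is
    # written once here, instead of being copy-pasted per solver.

    @abc.abstractmethod
    def init(self, y0):
        """Return the initial solver state."""

    @abc.abstractmethod
    def _prepare_step(self, y, state):
        """Return (direction, new_state)."""

    def step(self, y, state):
        # Shared wiring: get a direction from the subclass, take the step.
        direction, state = self._prepare_step(y, state)
        return y + direction, state


class ToyGradientDescent(AbstractIterativeMinimiser):
    """Toy subclass: minimises y**2 with a fixed learning rate."""

    def init(self, y0):
        return {"lr": 0.5}

    def _prepare_step(self, y, state):
        return -state["lr"] * 2.0 * y, state  # d/dy of y**2 is 2*y
```

With `lr=0.5` on `y**2` this toy solver lands on the minimiser in a single step; the point is only that `step` never needs overriding.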
Another (overrunning) weekend project, inspired by realising lineax doesn't really need a HessianLinearOperator. This is a giant PR, so I fully appreciate it may take quite some time to review, iterate on and merge (assuming this is something you want to support in optimistix; understood if not). Consider this mostly a proof of concept to show what's possible. I am very open to feedback and opinions on API and design. Note this was written with a lot of LLM support, I didn't want to overly finesse the docs until the final architecture is agreed upon, and there might be some boilerplate/wiring code (especially in `TruncatedCG`) that needs further simplification/refactoring.
Note this was motivated by trying to come up with good examples for a Hessian linear operator in a lineax issue discussion. The key behind everything is the `_make_hessian_f_info` function; everything else is just managing negative curvature, trust regions, performance, adding linear solvers to make sure this isn't a damp squib, etc. Note that I never use `JacobianLinearOperator` but `FunctionLinearOperator`, as I need to use `jax.linearize` to access the gradient value. Therefore a `hessian_linear_operator` would not actually be used here, the main use case for Hessians. We could of course just have the below function as a helper in lineax, `grad_and_hessian_op` or something, but it might be tricky to document and explain clearly.

I believe this would make optimistix the first JAX library with true jittable second-order methods (pretty sure optax doesn't have any, and jaxopt only has a non-accelerated wrapper of `scipy.optimize.minimize`).
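The core of the `jax.linearize` trick can be sketched in a few lines. This is an illustrative reconstruction, not the PR's actual `_make_hessian_f_info`: `jax.linearize` over `jax.grad` evaluates the gradient once and returns a linear map `v -> H @ v`, which is what would be wrapped in a lineax `FunctionLinearOperator` (a `JacobianLinearOperator` would not expose the gradient value).

```python
import jax
import jax.numpy as jnp


def grad_and_hvp(f, x):
    # Forward-over-reverse: one gradient evaluation plus a linearised
    # map giving Hessian-vector products, without materialising H.
    grad, hvp = jax.linearize(jax.grad(f), x)
    return grad, hvp


f = lambda x: jnp.sum(x**2) + x[0] * x[1]
g, hvp = grad_and_hvp(f, jnp.array([1.0, 2.0]))
# g is the gradient [4., 5.]; hvp([1., 0.]) is the first Hessian column [2., 1.]
```

In the lineax setting, `hvp` is exactly the function one would hand to `FunctionLinearOperator`, tagged symmetric.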
API
This PR offers two main points of entry:
`LineSearchNewton`
Uses `NewtonDescent` and a `BacktrackingArmijo` line search with the exact Hessian operator, falling back to steepest descent in regions of negative curvature.
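The negative-curvature fallback can be illustrated with a small sketch (hypothetical helper, not the PR's code): keep the Newton step only if the curvature of the Hessian along it is positive, otherwise use the steepest-descent direction.

```python
import jax.numpy as jnp


def safeguarded_direction(grad, hess_mv, newton_step):
    # If the model curvature along the Newton step is not positive, the
    # step may not be a descent direction, so fall back to -grad.
    curvature = jnp.dot(newton_step, hess_mv(newton_step))
    return jnp.where(curvature > 0.0, newton_step, -grad)


# Indefinite Hessian H = diag(1, -1):
hess_mv = lambda v: jnp.array([v[0], -v[1]])
grad = jnp.array([1.0, 1.0])
step_bad = jnp.array([0.0, 1.0])  # negative curvature along this step
d = safeguarded_direction(grad, hess_mv, step_bad)  # falls back to -grad
```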
`TrustNewton`
Uses `ClassicalTrustRegion` with either the new `SteihaugCGDescent` (to replicate scipy's trust-ncg) or `IndirectDampedNewtonDescent` (akin to scipy's trust-exact).

We also offer `AbstractNewtonMinimiser` to allow users to thread custom descents and searches (e.g. allowing a DoglegNewton akin to DoglegBFGS).

It also introduces a new linear solver, `TruncatedCG`, which detects negative curvature and is aware of trust regions. It can be used directly in `LineSearchNewton` (with `linear_solver=...`, coming close to scipy's newton-cg) or indirectly by `TrustNewton` with `use_steihaug=True`.

Summary of scipy mapping
- Newton-CG -> `LineSearchNewton(linear_solver=TruncatedCG())` (Gaps: scipy uses a Wolfe line search, we use backtracking Armijo)
- trust-ncg -> `TrustNewton(use_steihaug=True)` (pretty faithful implementation, works even for non-SPD)
- trust-exact -> Needs further development to properly support Moré-Sorensen, but for SPD cases `solver=TrustNewton(linear_solver=lx.Cholesky()), tags=frozenset({lx.positive_semidefinite_tag})` should work pretty well
- trust-krylov -> Not supported yet; requires even more complex linear solvers and wiring, I believe
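For reference, the Steihaug-Toint truncated-CG idea behind `TruncatedCG`/trust-ncg fits in a few lines. This is a plain-Python illustration under textbook assumptions, not the PR's jittable implementation: approximately solve `H p = -g`, stopping early if the iterate leaves the trust region or a direction of negative curvature is found.

```python
import jax.numpy as jnp


def _to_boundary(p, d, delta):
    # Positive tau such that ||p + tau * d|| == delta.
    a = jnp.dot(d, d)
    b = 2.0 * jnp.dot(p, d)
    c = jnp.dot(p, p) - delta**2
    return (-b + jnp.sqrt(b**2 - 4.0 * a * c)) / (2.0 * a)


def steihaug_cg(hess_mv, grad, delta, tol=1e-8, max_iter=50):
    p = jnp.zeros_like(grad)
    r = grad  # residual of H p + g (starts at g since p = 0)
    d = -r
    for _ in range(max_iter):
        Hd = hess_mv(d)
        dHd = jnp.dot(d, Hd)
        if dHd <= 0.0:
            # Negative curvature detected: follow d out to the boundary.
            return p + _to_boundary(p, d, delta) * d
        alpha = jnp.dot(r, r) / dHd
        p_next = p + alpha * d
        if jnp.linalg.norm(p_next) >= delta:
            # Step would leave the trust region: stop on the boundary.
            return p + _to_boundary(p, d, delta) * d
        r_next = r + alpha * Hd
        if jnp.linalg.norm(r_next) < tol:
            return p_next
        beta = jnp.dot(r_next, r_next) / jnp.dot(r, r)
        d = -r_next + beta * d
        p, r = p_next, r_next
    return p
```

The three early exits (negative curvature, boundary, small residual) are exactly why such a solver must be trust-region-aware rather than a generic CG.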
Design
Created a new `_AbstractNewtonBase` which acts as a parent for `_AbstractQuasiNewton` and the new `_AbstractNewtonMinimiser`; all the child classes need to do is define `init` and `_prepare_step`.

Performance
I have not benchmarked runtime/compile time, but iteration-wise these solvers beat all existing minimisers, though they typically lose against least-squares solvers. For a quadratic bowl, all these solvers except those using `TruncatedCG` complete in 3 iterations (it should be 1, but optimistix requires two steps to confirm Cauchy convergence and then only checks this at the beginning of the third step; we can experiment with pure gradient-based termination to improve this if you're interested), where the best quasi-Newton solver (LBFGS) takes 48. `TruncatedCG` takes 5 iterations. For Beale/Himmelblau they converge in 6/8 iterations (although Steihaug takes 10 for Himmelblau) against 14/13 with the best quasi-Newton solver.
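As a sanity check on the "should be 1" claim, an exact Newton step on a quadratic bowl lands on the minimiser immediately; the extra reported iterations come from convergence confirmation, not the solver. A minimal worked example (arbitrary SPD quadratic, not taken from the PR's benchmarks):

```python
import jax
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])  # SPD Hessian of the bowl
b = jnp.array([1.0, -1.0])


def f(x):
    return 0.5 * x @ A @ x - b @ x


x0 = jnp.zeros(2)
g = jax.grad(f)(x0)
H = jax.hessian(f)(x0)
x1 = x0 - jnp.linalg.solve(H, g)  # a single exact Newton step
# jax.grad(f)(x1) is ~0: x1 is already the minimiser
```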
For completeness, here are the minimiser comparisons: [plot not shown]

And the least-squares comparisons: [plot not shown]
Questions
Assuming we want to support second-order methods, these questions come to mind first as things to iron out to get this PR ready.
a) Make `_AbstractNewtonMinimiser` the only public API (with a better name); users can just plug and play with descents and searches.
b) (This PR) Concretise on search type only: `LineSearchNewton` and `TrustNewton`.
c) Concretise very specific solvers (e.g. `NewtonCG`, `TrustNCG`, `TrustExact`) and maybe add even more abstract classes.
If we go with b) there is currently a bit of a gotcha: `use_steihaug=True` ALWAYS uses `TruncatedCG`, requires `linear_solver` to be the sentinel `_UNSET`, and errors if it isn't. As such, having the `linear_solver` argument could be a bit confusing, and separating Steihaug out as `TrustNCG` (option c) might be sensible.

`optimistix.compat.minimize` (including those you already support)

Tags
[…] `_AbstractNewtonBase` minimisers to be `is_symmetric` (should be generally true), or just silently thread it in as I am doing currently in `_grad_hessian`?

Future Linear Solvers (out of scope for this PR)
`LineSearchNewton` would arguably benefit from a Moré-Sorensen approach such as `ModifiedCholesky` or `ModifiedLDL`; let me know how keen you are on having those in optimistix/lineax. I am currently working on exposing `sytrf` in jax to support LDL, but it performs pretty poorly on GPU compared to LU.
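For context on what a `ModifiedCholesky`-style solver involves, here is the textbook "shift the diagonal until the factorisation succeeds" scheme (Nocedal & Wright, Algorithm 3.3). This is a rough dense sketch for illustration, not a proposed lineax implementation; note that `jnp.linalg.cholesky` signals failure via NaNs rather than raising.

```python
import jax.numpy as jnp


def modified_cholesky(H, beta=1e-3, max_tries=30):
    # Increase tau until H + tau*I is positive definite, so the Newton
    # system can always be solved with a descent-producing factor.
    min_diag = jnp.min(jnp.diag(H))
    tau = 0.0 if min_diag > 0 else float(beta - min_diag)
    identity = jnp.eye(H.shape[0])
    for _ in range(max_tries):
        L = jnp.linalg.cholesky(H + tau * identity)
        if not jnp.any(jnp.isnan(L)):
            return L, tau
        tau = max(2.0 * tau, beta)
    raise RuntimeError("could not find a positive definite shift")
```

A Moré-Sorensen solver refines this by choosing the shift against a trust-region radius rather than just the smallest workable one.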