
Support true second order methods #225

Open

jpbrodrick89 wants to merge 4 commits into patrick-kidger:main from jpbrodrick89:jpb/second-order

Conversation


@jpbrodrick89 jpbrodrick89 commented Mar 18, 2026

Another (overrunning) weekend project, inspired by realising lineax doesn't really need a HessianLinearOperator. This is a giant PR, so I fully appreciate it may take quite some time to review, iterate on and merge (assuming second-order methods are something you want to support in optimistix; I understand if not). Consider this mostly a proof of concept to show what's possible; I am very open to feedback and opinions on API and design. Note this was written with a lot of LLM support, I didn't want to overly finesse the docs until the final architecture is agreed upon, and there may be some boilerplate/wiring code (especially in TruncatedCG) that needs further simplification/refactoring.

Note this was motivated by trying to come up with good examples for a Hessian linear operator in a lineax issue discussion. The key behind everything is the _make_hessian_f_info function; everything else is just managing negative curvature, trust regions and performance, and adding linear solvers to make sure this isn't a damp squib. Note that I never use JacobianLinearOperator but FunctionLinearOperator, as I need jax.linearize to access the gradient value. Therefore a hessian_linear_operator would not actually be used here, in the main use case for Hessians. We could of course just have the function below as a helper in lineax (grad_and_hessian_op or something), but it might be tricky to document and explain clearly.

# NB: `Fn`, `Y`, `Scalar`, `Aux` (optimistix type aliases), `PyTree`
# (jaxtyping) and the `_NoAux` wrapper are assumed from optimistix's
# internals; only `jax` and `lineax` are external imports.
import jax
import lineax as lx


def _grad_hessian(
    fn: Fn[Y, Scalar, Aux],
    y: Y,
    args: PyTree,
    tags: frozenset[object],
    *,
    autodiff_mode: str = "bwd",
) -> tuple[Y, lx.FunctionLinearOperator]:
    """Return ``(grad, hessian_operator)`` at ``y``."""
    if autodiff_mode == "bwd":
        grad_fn = jax.grad(_NoAux(fn))
    else:
        grad_fn = jax.jacfwd(_NoAux(fn))
    grad, hvp_fn = jax.linearize(lambda _y: grad_fn(_y, args), y)
    hessian = lx.FunctionLinearOperator(
        hvp_fn, jax.eval_shape(lambda: y), frozenset({lx.symmetric_tag}) | tags
    )
    return grad, hessian
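
Stripped of optimistix's pytree machinery, the core jax.linearize trick above can be demonstrated standalone (a minimal sketch assuming plain arrays; grad_and_hvp is a hypothetical name, not part of any library):

```python
import jax
import jax.numpy as jnp


def grad_and_hvp(f, y):
    # A single jax.linearize call returns the gradient *value* and a
    # Hessian-vector-product function, without materialising the Hessian.
    grad, hvp = jax.linearize(jax.grad(f), y)
    return grad, hvp


f = lambda y: jnp.sum(y**2) + y[0] * y[1]
g, hvp = grad_and_hvp(f, jnp.array([1.0, 2.0]))
# Hessian is [[2, 1], [1, 2]], so hvp applied to a basis vector
# recovers the corresponding Hessian column.
```

Wrapping `hvp` in an `lx.FunctionLinearOperator` (as above) then gives a matrix-free Hessian operator.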

I believe this would make optimistix the first JAX library with true jittable second-order methods (I'm pretty sure optax doesn't have any, and jaxopt only has a non-accelerated wrapper of scipy.optimize.minimize).

API

This PR offers two main points of entry:

LineSearchNewton

Uses NewtonDescent and a BacktrackingArmijo line search with the exact Hessian operator, falling back to steepest descent in regions of negative curvature.
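
As a standalone illustration of that descent logic, the Newton step with backtracking Armijo and a steepest-descent fallback can be sketched in plain NumPy (newton_armijo_step is a hypothetical helper for this comment, not the PR's implementation):

```python
import numpy as np


def newton_armijo_step(f, grad, hess, y, c1=1e-4, beta=0.5):
    # Newton direction from a dense solve; fall back to steepest descent
    # when the Hessian indicates negative curvature along that direction.
    d = np.linalg.solve(hess, -grad)
    if d @ (hess @ d) <= 0.0:
        d = -grad
    # Backtracking Armijo: shrink the step until sufficient decrease holds.
    t = 1.0
    while f(y + t * d) > f(y) + c1 * t * (grad @ d):
        t *= beta
    return y + t * d


# Quadratic bowl f(y) = 0.5 * ||y||^2: one full Newton step hits the minimum.
f = lambda y: 0.5 * y @ y
y = np.array([3.0, -4.0])
y_new = newton_armijo_step(f, y, np.eye(2), y)  # grad of f at y is y itself
```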

TrustNewton

Uses ClassicalTrustRegion with either the new SteihaugCGDescent (to replicate scipy's trust-ncg) or IndirectDampedNewtonDescent (akin to scipy's trust-exact).

We also offer AbstractNewtonMinimiser to allow users to thread custom descents and searches (e.g. allowing a DoglegNewton like DoglegBFGS).

The PR also introduces a new linear solver, TruncatedCG, that detects negative curvature and is aware of trust regions. It can be used directly in LineSearchNewton (with linear_solver=... it comes close to scipy's newton-cg) or indirectly by TrustNewton with use_steihaug=True.

Summary of scipy mapping

Newton-CG -> LineSearchNewton(linear_solver=TruncatedCG()) (gap: scipy uses a Wolfe line search, we use backtracking Armijo)

trust-ncg -> TrustNewton(use_steihaug=True) (a pretty faithful implementation; works even for non-SPD Hessians)

trust-exact -> Needs further development to properly support Moré-Sorensen, but for SPD cases solver=TrustNewton(linear_solver=lx.Cholesky()), tags=frozenset({lx.positive_semidefinite_tag}) should work pretty well

trust-krylov -> Not supported yet; I believe it requires even more complex linear solvers and wiring
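
For context, the Steihaug truncated-CG scheme that trust-ncg uses can be sketched in plain NumPy (a textbook-style illustration; steihaug_cg and _to_boundary are hypothetical names, not the PR's TruncatedCG):

```python
import numpy as np


def _to_boundary(p, d, radius):
    # Positive tau such that ||p + tau * d|| == radius (quadratic formula).
    a, b, c = d @ d, 2 * (p @ d), p @ p - radius**2
    return (-b + np.sqrt(b * b - 4 * a * c)) / (2 * a)


def steihaug_cg(H, g, radius, tol=1e-8, maxiter=50):
    """Approximately minimise g @ p + 0.5 * p @ H @ p over ||p|| <= radius."""
    p = np.zeros_like(g)
    r, d = g.copy(), -g.copy()
    for _ in range(maxiter):
        Hd = H @ d
        curv = d @ Hd
        if curv <= 0.0:
            # Negative curvature: follow d all the way to the boundary.
            return p + _to_boundary(p, d, radius) * d
        alpha = (r @ r) / curv
        if np.linalg.norm(p + alpha * d) >= radius:
            # Step would leave the trust region: stop on the boundary.
            return p + _to_boundary(p, d, radius) * d
        p = p + alpha * d
        r_new = r + alpha * Hd
        if np.linalg.norm(r_new) < tol:
            return p
        d = -r_new + ((r_new @ r_new) / (r @ r)) * d
        r = r_new
    return p
```

On an SPD model with a generous radius this reduces to ordinary CG; with a small radius or negative curvature it returns a boundary point instead.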

Design

Created a new _AbstractNewtonBase, which acts as a parent for _AbstractQuasiNewton and the new _AbstractNewtonMinimiser; all the child classes need to do is define init and _prepare_step.

Performance

I have not benchmarked runtime/compile time, but iteration-wise these solvers beat all existing minimisers, though they typically lose against the least-squares solvers. For a quadratic bowl, all these solvers except those that use TruncatedCG complete in 3 iterations (it should be 1, but optimistix requires two steps to confirm Cauchy convergence and then only checks this at the beginning of the third step; we can experiment with pure gradient-based termination to improve this if you're interested), where the best quasi-Newton solver (LBFGS) takes 48. TruncatedCG takes 5 iterations. For Beale/Himmelblau they converge in 6/8 iterations (although Steihaug takes 10 for Himmelblau) against 14/13 with the best quasi-Newton solver.

For completeness, here are the minimiser comparisons:

Solver                          bowl  matyas   beale  himmelblau  sq_minus_one  glob_convex  glob_convex_far
------------------------------------------------------------------------------------------------------------
LineSearchNewton(Cholesky)         3       3       6           8             3            4               14
LineSearchNewton(TruncCG)          5       3       6           8             3            4               14
TrustNewton(Cholesky)              3       3       6           8             3            4               14
TrustNewton(Steihaug)              5       6       6          10             3            4               14
BFGS                              98       4      14          18             4            5               44
LBFGS                             48       4      15          13             4            5               26

And the least-squares comparisons:

Solver                        diag_bowl  rosenbrock_a  rosenbrock_b
-------------------------------------------------------------------
LineSearchNewton(Cholesky)           15            14            23
TrustNewton(Cholesky)                15            17            27
TrustNewton(Steihaug)                17            20            24
GaussNewton                          10             4             4
LevenbergMarquardt                  138            12            30
IndirectLevenbergMarquardt           10            14            19
BFGS                                 78            33            42
LBFGS                                31            30            43

Questions

Assuming we want to support second-order methods, these questions come to mind first as things to iron out to get this PR ready.

  1. What is your preferred API surface?

a) Make _AbstractNewtonMinimiser the only public API (with a better name); users can just plug and play with descents and searches.
b) (This PR) Concretise on search type only: LineSearchNewton and TrustNewton.
c) Concretise very specific solvers (e.g. NewtonCG, TrustNCG, TrustExact) and maybe add even more abstract classes.

If we go with b) there is currently a bit of a gotcha: use_steihaug=True ALWAYS uses TruncatedCG and requires linear_solver to be the sentinel _UNSET, erroring if it isn't. As such, having the linear_solver argument could be a bit confusing, and separating Steihaug out as TrustNCG (option c) might be sensible.

  2. What are your thoughts on including these as options in optimistix.compat.minimize (including those you already support)?
  3. Should we rename _QuasiNewtonState or create a parent class?
  4. Have I made any mistakes with regard to reducing compilation overhead? (While the total memory usage of test_minimize.py is now around 8 GB, I didn't notice any large jumps or slowdowns when running.)
  5. Would you rather have TruncatedCG in optimistix or lineax?
  6. Cholesky can sometimes succeed by luck in the non-SPD case; should we highlight that as a possibility to users, or is it too much of a footgun?
  7. I've added an iteration benchmark script for convenience; I'm intending to remove it when this PR is ready, but let me know if you'd prefer to keep it.

Tags

  • Should we EXPLICITLY declare the default tags for ALL _AbstractNewtonBase minimisers to be is_symmetric (which should generally be true), or just silently thread it in as I am doing currently in `_grad_hessian`?
  • damped_newton_step currently only tags the Hessian as SPD when the initial Hessian is tagged as SPD; however, if the root find is successful then it is guaranteed to be SPD in any case, because the damping is sufficient to make it SPD. Do you want to handle that case any differently?

Future Linear Solvers (out of scope for this PR)

LineSearchNewton would arguably benefit from a Moré-Sorensen approach such as ModifiedCholesky or ModifiedLDL; let me know how keen you are on having those in optimistix/lineax. I am currently working on exposing sytrf in jax to support LDL, but it performs pretty poorly on GPU compared to LU.

@jpbrodrick89
Author

Note I currently don't support backward-over-backward, because that would typically mean you need a custom_vjp of a custom_vjp, which I think would be fairly rare. However, if this would actually work for a recursive checkpointing diffrax sim then maybe it's worth supporting.

@patrick-kidger
Owner

Oh wow! This is fairly gigantic 😅 I have to be honest I think this is probably out-of-scope here. I think it's much larger than I can commit to maintaining. (Plus, you are currently sending PRs faster than I can find time to review them :p I am very aware of the stack of Lineax PRs I am slowly working through...)

This kind of thing might find a good home in some kind of 'optimistixtra' (made up a name). Or possibly we should introduce an optimistix.contrib? Though contrib namespaces have been a bit of a mess in other libraries, I'm not sure where I land on whether that should be a thing we introduce.

Over on the technical side, it's also not immediately clear to me how this differs from the existing optx.Newton, which when applied to the gradient-of-a-root becomes a second order optimizer that uses the Hessian + will be matrix-free if its provided linear solver is. Perhaps you're substituting the Hessian with some kind of positive-ish equivalent, to avoid finding non-minimal equilibria?

WDYT?

@jpbrodrick89
Author

I have to be honest I think this is probably out-of-scope here. I think it's much larger than I can commit to maintaining.

I think that's a pretty fair take 😅 I guess I got a bit carried away showing how much ground this can potentially cover rather than focussing on a digestible incremental approach. But in another sense it's probably healthy to see a preview of a minimal working API before committing to it blindly, only to later realise it's beyond maintenance appetite.

That said, if you do change your mind on this I'd like to allay any pressure on timelines, with my full understanding that this would be a >(>)6 month project that we would just chew slowly through as we have time. (I have no urgent need for this in my own work right now.)

it's also not immediately clear to me how this differs from the existing optx.Newton,

The idea of just passing a root finder to minimize is also what crossed my mind first, but it turns out optx.Newton is much too simple to handle the sophistication required to match industry standards. The key feature is the ability to handle line searches, trust regions and termination conditions based on the function rather than its gradient, which is already supported perfectly in AbstractQuasiNewton. Also, SteihaugCG needs awareness of the trust region. Essentially, I had to reinvent the wheel a lot less this way and could provide a very familiar API to users.

This kind of thing might find a good home in some kind of 'optimistixtra' (made up a name). Or possibly we should introduce an optimistix.contrib?

While I am very keen to eventually create a lineaxtra as I pull together enough useful features, I don't have a burning desire to add an optimistixtra as a separate repo to support all on my own right now (not least because writing this PR made me realise how little I know about optimisation algorithms).

One alternative to .contrib is .experimental, meaning you have arguably more say over what goes in there but less commitment to making sure it works perfectly. But arguably in such cases it might just be better to have more worked examples.

As such, before we move on it's probably worthwhile exploring "what would be the MVP to allow power users to use the existing API to more easily roll their own custom second order optimisers and provide a minimal example of how to do this that minimises maintenance" rather than "what this could potentially be given many months of work and added maintenance".

These are potential options in order of increasing complexity/maintenance. Each change to the codebase could be done incrementally in a separate digestible PR.

  1. Just add the negative curvature fallback to NewtonDescent (lines 146-148), and then (not benefitting from the former) show how to pass a gradient function to optx.Newton with a globally convex function, warning about the limitations and pitfalls of this and pointing to this (to-be-closed) PR for a more extensive example of what is required.
  2. Create AbstractNewtonBase and refactor AbstractQuasiNewton, then give an example how to implement TrustNewton with a direct solver.
  3. Create AbstractNewtonMinimiser meaning that users just need to create a concrete version of the class, allowing for a very minimal example for implementing TrustNewton.
  4. Add the TruncatedCG linear solver and SteihaugCGDescent, and give an example of implementing TrustNCG.
  5. This PR more or less as-is, with all public API in concrete classes.

Personally, I'd probably lean towards 2 being a sensible stopping point, but I'd still be very content with 1. Either of these would mean I could theoretically add TruncatedCG to lineaxtra and then add another example to optimistix on how to implement SteihaugCGDescent (basically 3 lines of algorithm, all the rest boilerplate/wiring).

And of course if you don't think it's valuable to add such an example at all and would rather keep things as they are, I am also fully sympathetic to that viewpoint and would not be overly disappointed.

Eisenstat-Walker

Note how TruncatedCG allows rtol and atol to be updated in the linear_solve call through options, so that they can change each step. What do you think of that as a pattern, rather than reconstructing the linear solver each step as suggested in #208? Should we mirror it in other linear solvers in lineax? Hypothetically we could give Descents the ability to customise tolerance schemes through an options updater or similar, but maybe just providing an example with solver.step is the more maintainable approach.

@patrick-kidger
Owner

"what would be the MVP to allow power users to use the existing API to more easily roll their own custom second order optimisers and provide a minimal example of how to do this that minimises maintenance"

I like this characterisation a lot!

On your hierarchy of options: as a nit, I confess I don't love the name `AbstractNewtonBase`, what with 'Base' actually just being a synonym for 'Abstract'. :p

More seriously, the current iterate-over-steps + accept/reject logic is something we have kind of copy-pasted between three different solvers at the moment (AbstractQuasiNewton, AbstractBFGS, AbstractGradientDescent). I'd lean towards either copy-pasting this again for Newton, or finding a way to unify these so the code only appears once. If possible, that is... I might be missing something in why I did it this way originally.

(NB, I'm also conscious of one other mistake I made here — our 'latest' iterate is actually held in state.y_eval and not y, and this means that we don't compose well with stuff like BestSoFar solvers.)


On rtol/atol, I lean towards not making these dynamic, as the dominant use case is actually the opposite way around I think: see for example how Diffrax allows you to put rtol/atol just on the step size controller, and letting everything else inherit them from there. Fiddling with these is kind of annoying. So if dynamic stuff is needed I'd lean towards using eqx.tree_at instead.
