Flexible termination criteria: an AbstractTermination interface#214
Flexible termination criteria: an AbstractTermination interface#214michael-0brien wants to merge 4 commits intopatrick-kidger:mainfrom
AbstractTermination interface#214Conversation
|
Thanks for this! So:
(Don't forget that you can always already override the |
The idea is to replace lines like: terminate = cauchy_termination(
self.rtol, self.atol, self.norm, state.y_eval, y_diff, f_eval, f_diff
)with: terminate = self.termination(state.y_eval, y_diff, f_eval, f_diff
)I keep API the same, so for example the self.atol = atol
self.rtol = rtol
self.norm = normwith: self.termination = CauchyTermination(rtol, atol, norm)The essence of this is to indeed factor out the terminate method, but in practice it is slightly challenging to do this. Consider the AbstractGaussNewton.terminate. Here the termination was evaluated during the step, and the terminate method forwards along the value of the termination stored in the solver state. There are a few examples like this.
The issue is the AbstractIterativeSolver interface declares rtol, atol, and norm as AbstractVars. I think it should be relatively straightforward to remove this, but there are still challenges (for example, I still need to fix the code where root finders are converted to other problems). I wouldn’t be surprised if it were necessary to replace it with another AbstractVar or two somewhere, but I haven’t quite sorted this out yet. Note that for this change, minimisers and least squares solvers would get the new interface, but I haven’t messed with things like the FixedPointIteration class, which still has the rtol, atol, and norm directly as fields. There are a few other classes that would maintain these as fields, but it would no longer be abstract API. Philosophically, this makes sense to me; I think it is too hard to unite all problems with this API. There are a few other things to potentially discuss as well—for example, I could see there being a better naming pattern for the new interface. |
|
Just wanted to weigh in that this does sound really useful in principle as could help address #165 and allow for super cheap termination routines that just always return False so I can deterministically run for a certain number of steps. |
|
Just took an every so slightly closer look at this, I really like the idea, I would lean towards trying to support this with root finders too, particularly NewtonChord, and offer the HairerWanner termination as its own termination class. Offering one more standard termination class where it only checks for f convergence and not y might also be desirable in some cases where performance matters more than accuracy/robustness (of course just leaving it to user's to implement themselves to prevent offering a footgun seems sensible). If you need a hand with Newton Chord I'd be happy to help as I've just been getting very deep with that this week. |
|
@jpbrodrick89 this sounds reasonable to me! If I had to guess I think this PR is probably most likely to be pushed through if a comprehensive alternative is offered compared to the current model for terminations, so extending this to root finders sounds great. I could use the help as I don’t work on root finders. If you also want to take a whack at fixing where root finds are converted to minimizations and least squares solves please go for it as well! |
Keep in mind with this it can be tricky what ends up being “cheap” termination at the end of the day. This PR would only change termination criteria, not the JAX control flow in a given iterative solver (which is expensive on GPU). Unless the compiler figures out it should remove that control flow as well, I suspect there won’t necessarily be performance increases. |
|
@patrick-kidger @johannahaffner when you have a moment could you help me come to a solution on this? I can work on this but want to make sure it’s the right direction, and if not what would be. There are two approaches to resolution in my mind:
I would lean towards 1. because at the least optimistix should support the ability to have different norms and tolerances for y and f as this is quickly broken by many problems, but curious if there is a feasible path for 2. as well. |
|
To present an alternative to the current version of things, it would be possible to preserve class AbstractTerminationCriteria(eqx.Module):
atol: eqx.AbstractVar[float] # these are forwarded to the `AbstractIterativeSolver` API, when appropriate
rtol: eqx.AbstractVar[float]
norm: eqx.AbstractVar[Callable]
@abc.abstractmethod
def __call__(...):
...
class CustomTerminationCriteria1(AbstractTerminationCriteria):
atol: float # atol and rtol are used for function tolerance
rtol: float
yatol: PyTree[float] # pytree-valued atol and rtol for parameter tolerance
yrtol: PyTree[float]
norm: Callable
def __call__(...):
...
class CustomTerminationCriteria2(AbstractTerminationCriteria):
atol: float
rtol: float
norm: Callable
def __call__(...):
# atol and rtol only used for function tolerance; no parameter tolerance considered
...
Under this idea I am still not sure how to handle the issue of problem conversion (i.e. using a minimiser to solve a least squares problem). In either case this is the only thing that breaks the proposed API. I personally think this is a bit messier, but it is less API breaking in case users are accessing |
|
Apologies for the lack of update on my side, the reason I've not attempted this for root finders yet is we're still trying to determine the correct non-Cauchy termination conditions for Newton. |
|
@jpbrodrick89 no problem! There's no rush, I mostly want to discuss direction so I can start using this PR in my own fork. I also quickly want to correct myself:
Looking through the code a little bit more closely, actually root finder problem conversion (i.e. using least squares solver or a minimiser to solve a root find) is the only case where this PR breaks. This is good news actually! I bet we can find a simple way to address this. Further, the current method of handling terminations here looks slightly hacky so I bet there's an elegant solution. |
Welp, brace yourself: here's an initial attempt for addressing #182. I propose creating an interface for handling terminations:
I've put together a very rough draft with this proposal, where tests are running for minimizers and least squares solvers. Halfway into working on this I realized that
rtolandatoland pretty baked into the code and it's a pretty large refactor. However, I think this would be a significant enhancement tooptimistixas I don't think that changing termination criteria are edge cases (for example,scipy.optimize.least_squareshas different tolerances and norms for function vs. parameter convergence). This is particularly needed for messy problems on noisy scientific data. There are a few current issues implementing this in practice:atol,rtol, andnormfrom theAbstractIterativeSolverinterface. This could potentially be replaced by theAbstractTerminationinterface, but I didn't want to do this without discussion (also, it is a little out of my depth).optimistixwill be able to take advantage of the new interface. It is particularly easy for users to do this for cases where implementing a new solver requires only writing an__init__(i.e. the mixing and matching approach), but it is not-so-elegant to take advantage of the new changes for cases where theAbstractMinimiser, etc are subclassed directly to a concrete class (e.g. in the case of theOptaxMinimiser). In these cases, the user can resort to a tree_at call. This isn’t so bad but would be good to talk about what the right approach is.