Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimistix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@
polak_ribiere as polak_ribiere,
SteepestDescent as SteepestDescent,
)
from ._termination import (
AbstractTermination as AbstractTermination,
CauchyTermination as CauchyTermination,
)


__version__ = importlib.metadata.version("optimistix")
7 changes: 1 addition & 6 deletions optimistix/_iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox import AbstractVar
from jaxtyping import Array, Bool, PyTree, Scalar
from jaxtyping import Array, Bool, PyTree

from ._adjoint import AbstractAdjoint
from ._custom_types import Aux, Fn, Out, SolverState, Y
Expand All @@ -26,10 +25,6 @@
class AbstractIterativeSolver(eqx.Module, Generic[Y, Out, Aux, SolverState]):
"""Abstract base class for all iterative solvers."""

rtol: AbstractVar[float]
atol: AbstractVar[float]
norm: AbstractVar[Callable[[PyTree], Scalar]]

@abc.abstractmethod
def init(
self,
Expand Down
36 changes: 3 additions & 33 deletions optimistix/_misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from collections.abc import Callable
from typing import Any, Literal, overload, TypeVar
from typing import Any, Literal, overload

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -10,7 +10,6 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Bool, PyTree, Scalar, ScalarLike
from lineax.internal import (
default_floating_dtype as _default_floating_dtype,
Expand All @@ -21,8 +20,6 @@
two_norm as _two_norm,
)

from ._custom_types import Y


# Make the wrapped function a genuine member of this module.
def _wrap(fn):
Expand Down Expand Up @@ -74,10 +71,8 @@ def tree_full_like(struct: PyTree, fill_value: ArrayLike, allow_static: bool = F
fn = lambda x: jnp.ones(x.shape, x.dtype)
if allow_static:
_fn = fn
fn = (
lambda x: _fn(x)
if eqx.is_array(x) or isinstance(x, jax.ShapeDtypeStruct)
else x
fn = lambda x: (
_fn(x) if eqx.is_array(x) or isinstance(x, jax.ShapeDtypeStruct) else x
)
return jtu.tree_map(fn, struct)

Expand Down Expand Up @@ -285,31 +280,6 @@ def inexact_asarray(x):
return _asarray(dtype, x)


_F = TypeVar("_F")


def cauchy_termination(
    rtol: float,
    atol: float,
    norm: Callable[[PyTree], Scalar],
    y: Y,
    y_diff: Y,
    f: _F,
    f_diff: _F,
) -> Bool[Array, ""]:
    """Terminate if there is a small difference in both `y` space and `f` space, as
    determined by `rtol` and `atol`.

    Specifically, this checks that `norm(|y_diff| / (atol + rtol * |y|)) < 1` and
    `norm(|f_diff| / (atol + rtol * |f|)) < 1`, terminating only if both are true.
    """
    # Per-leaf tolerance scales: how large a difference is acceptable at each
    # leaf, combining the absolute and relative tolerances.
    y_scale = (atol + rtol * ω(y).call(jnp.abs)).ω
    f_scale = (atol + rtol * ω(f).call(jnp.abs)).ω
    # `ω(x)` / `x**ω` is equinox's wrapper for elementwise PyTree arithmetic.
    # Dividing the absolute differences by the scales and applying `norm`
    # collapses each space to a single scalar convergence measure.
    y_converged = norm((ω(y_diff).call(jnp.abs) / y_scale**ω).ω) < 1
    f_converged = norm((ω(f_diff).call(jnp.abs) / f_scale**ω).ω) < 1
    return y_converged & f_converged


class _JaxprEqual:
def __init__(self, jaxpr: jex.core.Jaxpr):
self.jaxpr = jaxpr
Expand Down
68 changes: 0 additions & 68 deletions optimistix/_solver/best_so_far.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,6 @@ class _AbstractBestSoFarSolver(AbstractIterativeSolver, Generic[Y, Out, Aux]):
@abc.abstractmethod
def _to_loss(self, y: Y, f: Out) -> Scalar: ...

@property
def rtol(self):
return self.solver.rtol

@property
def atol(self):
return self.solver.atol

@property
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.solver.norm

def init(
self,
fn: Fn[Y, Out, Aux],
Expand Down Expand Up @@ -132,20 +120,6 @@ def __init__(self, solver: AbstractMinimiser[Y, tuple[Scalar, Aux], Any]):
def _to_loss(self, y: Y, f: Scalar) -> Scalar:
return f

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property
def rtol(self):
return self.solver.rtol

@property
def atol(self):
return self.solver.atol

@property
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.solver.norm


BestSoFarMinimiser.__init__.__doc__ = """**Arguments:**

Expand All @@ -172,20 +146,6 @@ def __init__(
def _to_loss(self, y: Y, f: Out) -> Scalar:
return sum_squares(f)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property
def rtol(self):
return self.solver.rtol

@property
def atol(self):
return self.solver.atol

@property
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.solver.norm


BestSoFarLeastSquares.__init__.__doc__ = """**Arguments:**

Expand All @@ -210,20 +170,6 @@ def __init__(self, solver: AbstractRootFinder[Y, Out, tuple[Out, Aux], Any]):
def _to_loss(self, y: Y, f: Out) -> Scalar:
return sum_squares(f)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property
def rtol(self):
return self.solver.rtol

@property
def atol(self):
return self.solver.atol

@property
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.solver.norm


BestSoFarRootFinder.__init__.__doc__ = """**Arguments:**

Expand All @@ -248,20 +194,6 @@ def __init__(self, solver: AbstractFixedPointSolver[Y, tuple[Y, Aux], Any]):
def _to_loss(self, y: Y, f: Y) -> Scalar:
return sum_squares((y**ω - f**ω).ω)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property
def rtol(self):
return self.solver.rtol

@property
def atol(self):
return self.solver.atol

@property
def norm(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.solver.norm


BestSoFarFixedPoint.__init__.__doc__ = """**Arguments:**

Expand Down
9 changes: 3 additions & 6 deletions optimistix/_solver/dogleg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .._root_find import AbstractRootFinder, root_find
from .._search import AbstractDescent, FunctionInfo
from .._solution import RESULTS
from .._termination import CauchyTermination
from .bisection import Bisection
from .gauss_newton import AbstractGaussNewton, newton_step
from .trust_region import ClassicalTrustRegion
Expand Down Expand Up @@ -235,11 +236,9 @@ class Dogleg(AbstractGaussNewton[Y, Out, Aux]):
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
atol: float
norm: Callable[[PyTree], Scalar]
descent: DoglegDescent[Y]
search: ClassicalTrustRegion[Y]
termination: CauchyTermination[Y]
verbose: Callable[..., None]

def __init__(
Expand All @@ -253,9 +252,7 @@ def __init__(
# We don't expose root_finder to the default API for Dogleg because
# we assume the `trust_region_norm` norm is `two_norm`, which has
# an analytic formula for the intersection with the dogleg path.
self.rtol = rtol
self.atol = atol
self.norm = norm
self.termination = CauchyTermination(rtol=rtol, atol=atol, norm=norm)
self.descent = DoglegDescent(linear_solver=linear_solver)
self.search = ClassicalTrustRegion()
self.verbose = default_verbose(verbose)
Expand Down
35 changes: 15 additions & 20 deletions optimistix/_solver/gauss_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .._custom_types import Args, Aux, DescentState, Fn, Out, SearchState, Y
from .._least_squares import AbstractLeastSquaresSolver
from .._misc import (
cauchy_termination,
default_verbose,
filter_cond,
max_norm,
Expand All @@ -26,13 +25,16 @@
FunctionInfo,
)
from .._solution import RESULTS
from .._termination import AbstractTermination, CauchyTermination
from .learning_rate import LearningRate


def newton_step(
f_info: FunctionInfo.EvalGradHessian
| FunctionInfo.EvalGradHessianInv
| FunctionInfo.ResidualJac,
f_info: (
FunctionInfo.EvalGradHessian
| FunctionInfo.EvalGradHessianInv
| FunctionInfo.ResidualJac
),
linear_solver: lx.AbstractLinearSolver,
) -> tuple[PyTree[Array], RESULTS]:
"""Compute a Newton step.
Expand Down Expand Up @@ -112,9 +114,11 @@ def init(self, y: Y, f_info_struct: FunctionInfo) -> _NewtonDescentState:
def query(
self,
y: Y,
f_info: FunctionInfo.EvalGradHessian
| FunctionInfo.EvalGradHessianInv
| FunctionInfo.ResidualJac,
f_info: (
FunctionInfo.EvalGradHessian
| FunctionInfo.EvalGradHessianInv
| FunctionInfo.ResidualJac
),
state: _NewtonDescentState,
) -> _NewtonDescentState:
del state
Expand Down Expand Up @@ -204,9 +208,7 @@ class AbstractGaussNewton(AbstractLeastSquaresSolver[Y, Out, Aux, _GaussNewtonSt
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: AbstractVar[float]
atol: AbstractVar[float]
norm: AbstractVar[Callable[[PyTree], Scalar]]
termination: AbstractVar[AbstractTermination[Y]]
descent: AbstractVar[AbstractDescent[Y, FunctionInfo.ResidualJac, Any]]
search: AbstractVar[
AbstractSearch[Y, FunctionInfo.ResidualJac, FunctionInfo.ResidualJac, Any]
Expand Down Expand Up @@ -277,10 +279,7 @@ def accepted(descent_state):
descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
y_diff = (state.y_eval**ω - y**ω).ω
f_diff = (f_eval_info.residual**ω - state.f_info.residual**ω).ω
terminate = cauchy_termination(
self.rtol,
self.atol,
self.norm,
terminate = self.termination(
state.y_eval,
y_diff,
f_eval_info.residual,
Expand Down Expand Up @@ -372,11 +371,9 @@ class GaussNewton(AbstractGaussNewton[Y, Out, Aux]):
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
atol: float
norm: Callable[[PyTree], Scalar]
descent: NewtonDescent[Y]
search: LearningRate[Y]
termination: CauchyTermination[Y]
verbose: Callable[..., None]

def __init__(
Expand All @@ -387,11 +384,9 @@ def __init__(
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None),
verbose: bool | Callable[..., None] = False,
):
self.rtol = rtol
self.atol = atol
self.norm = norm
self.descent = NewtonDescent(linear_solver=linear_solver)
self.search = LearningRate(1.0)
self.termination = CauchyTermination(rtol=rtol, atol=atol, norm=norm)
self.verbose = default_verbose(verbose)


Expand Down
20 changes: 9 additions & 11 deletions optimistix/_solver/golden.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
from collections.abc import Callable
from typing import Any, ClassVar
from typing import Any

import equinox as eqx
import jax
Expand All @@ -10,8 +9,9 @@

from .._custom_types import Aux, Fn
from .._minimise import AbstractMinimiser
from .._misc import cauchy_termination, tree_where
from .._misc import tree_where
from .._solution import RESULTS
from .._termination import CauchyTermination


class _GoldenSearchState(eqx.Module):
Expand Down Expand Up @@ -47,10 +47,11 @@ class GoldenSearch(AbstractMinimiser[Float[Array, ""], Aux, _GoldenSearchState])
between interval segments is always maintained.
"""

rtol: float
atol: float
# All norms are the same for scalars.
norm: ClassVar[Callable[[PyTree], Float[Array, ""]]] = jnp.abs
termination: CauchyTermination

    def __init__(self, rtol: float, atol: float):
        # `rtol`/`atol` are forwarded to a Cauchy termination criterion.
        # All norms are the same for scalars, so `jnp.abs` serves as the norm.
        self.termination = CauchyTermination(rtol, atol, norm=jnp.abs)

def init(
self,
Expand Down Expand Up @@ -115,10 +116,7 @@ def step(
# since that is always the point closest to the current `y_`.
y_diff = state.middle - y_
f_diff = state.f_middle - f
terminate = cauchy_termination(
self.rtol,
self.atol,
jnp.abs,
terminate = self.termination(
state.middle,
y_diff,
state.f_middle,
Expand Down
Loading
Loading