Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/api/progress_meter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Progress meters

Progress meters display how far a solve has progressed. They are passed as the `progress_meter` argument to [`optimistix.minimise`][], [`optimistix.least_squares`][], [`optimistix.root_find`][], and [`optimistix.fixed_point`][].

??? abstract "`optimistix.AbstractProgressMeter`"

::: optimistix.AbstractProgressMeter
options:
members:
- init
- step
- close

::: optimistix.NoProgressMeter
options:
members:
- __init__

::: optimistix.TextProgressMeter
options:
members:
- __init__

::: optimistix.TqdmProgressMeter
options:
members:
- __init__
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ nav:
- Advanced API:
- 'api/norms.md'
- 'api/adjoints.md'
- 'api/progress_meter.md'
- Searches and descents:
- 'api/searches/introduction.md'
- 'api/searches/searches.md'
Expand Down
6 changes: 6 additions & 0 deletions optimistix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
rms_norm as rms_norm,
two_norm as two_norm,
)
from ._progress_meter import (
AbstractProgressMeter as AbstractProgressMeter,
NoProgressMeter as NoProgressMeter,
TextProgressMeter as TextProgressMeter,
TqdmProgressMeter as TqdmProgressMeter,
)
from ._root_find import (
AbstractRootFinder as AbstractRootFinder,
root_find as root_find,
Expand Down
16 changes: 15 additions & 1 deletion optimistix/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections.abc import Callable
from typing import Any, TypeAlias, TypeVar
from typing import Any, TYPE_CHECKING, TypeAlias, TypeVar

import equinox.internal as eqxi
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Float, Int, Real


Args: TypeAlias = Any
Expand All @@ -18,3 +20,15 @@
MaybeAuxFn: TypeAlias = Fn[Y, Out, Aux] | NoAuxFn[Y, Out]

sentinel: Any = eqxi.doc_repr(object(), "sentinel")


if TYPE_CHECKING:
BoolScalarLike = bool | Array | np.ndarray
FloatScalarLike = float | Array | np.ndarray
IntScalarLike = int | Array | np.ndarray
RealScalarLike = bool | int | float | Array | np.ndarray
else:
BoolScalarLike = Bool[ArrayLike, ""]
FloatScalarLike = Float[ArrayLike, ""]
IntScalarLike = Int[ArrayLike, ""]
RealScalarLike = Real[ArrayLike, ""]
6 changes: 6 additions & 0 deletions optimistix/_fixed_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._least_squares import AbstractLeastSquaresSolver
from ._minimise import AbstractMinimiser
from ._misc import inexact_asarray, NoneAux, OutAsArray
from ._progress_meter import AbstractProgressMeter, NoProgressMeter
from ._root_find import AbstractRootFinder, root_find
from ._solution import Solution

Expand Down Expand Up @@ -63,6 +64,7 @@ def fixed_point(
adjoint: AbstractAdjoint = ImplicitAdjoint(),
throw: bool = True,
tags: frozenset[object] = frozenset(),
progress_meter: AbstractProgressMeter = NoProgressMeter(),
) -> Solution[Y, Aux]:
"""Find a fixed-point of a function.

Expand Down Expand Up @@ -102,6 +104,8 @@ def fixed_point(
is, the structure of the matrix `dfn/dy - I`.) Used with
[`optimistix.ImplicitAdjoint`][] to implement the implicit function theorem as
efficiently as possible. Keyword only argument.
- `progress_meter`: A progress meter to display the progress of the solve. Defaults
to [`optimistix.NoProgressMeter`][]. Keyword only argument.

**Returns:**

Expand All @@ -126,6 +130,7 @@ def fixed_point(
max_steps=max_steps,
adjoint=adjoint,
throw=throw,
progress_meter=progress_meter,
)
else:
y0 = jtu.tree_map(inexact_asarray, y0)
Expand All @@ -151,4 +156,5 @@ def fixed_point(
f_struct=f_struct,
aux_struct=aux_struct,
rewrite_fn=_rewrite_fn,
progress_meter=progress_meter,
)
46 changes: 42 additions & 4 deletions optimistix/_iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._adjoint import AbstractAdjoint
from ._custom_types import Aux, Fn, Out, SolverState, Y
from ._misc import tree_allfinite, unwrap_jaxpr, wrap_jaxpr
from ._progress_meter import AbstractProgressMeter, NoProgressMeter
from ._solution import RESULTS, Solution


Expand Down Expand Up @@ -202,6 +203,7 @@ def _iterate(inputs):
f_struct,
aux_struct,
tags,
progress_meter,
while_loop,
) = inputs
del inputs
Expand All @@ -218,27 +220,45 @@ def terminate_and_result(_y, _state):

init_terminate, init_result = terminate_and_result(y0, init_state)
dynamic_init_state, static_state = eqx.partition(init_state, eqx.is_array)
progress_meter_state = progress_meter.init()
init_carry = (
y0,
jnp.array(0),
dynamic_init_state,
init_aux,
init_terminate,
init_result,
progress_meter_state,
)

def cond_fun(carry):
_, _, _, _, terminate, result = carry
_, _, _, _, terminate, result, _ = carry
return jnp.invert(terminate) & (result == RESULTS.successful)

def body_fun(carry):
y, num_steps, dynamic_state, _, _, _ = carry
y, num_steps, dynamic_state, _, _, _, progress_meter_state = carry
state = eqx.combine(static_state, dynamic_state)
new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
new_terminate, new_result = terminate_and_result(y, new_state)
new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)
assert eqx.tree_equal(static_state, new_static_state) is True
return new_y, num_steps + 1, new_dynamic_state, aux, new_terminate, new_result

# Update progress meter
if max_steps is not None:
progress = (num_steps + 1) / max_steps
else:
progress = jnp.zeros(())
new_progress_meter_state = progress_meter.step(progress_meter_state, progress)

return (
new_y,
num_steps + 1,
new_dynamic_state,
aux,
new_terminate,
new_result,
new_progress_meter_state,
)

final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
(
Expand All @@ -248,6 +268,7 @@ def body_fun(carry):
final_aux,
terminate,
result,
final_progress_meter_state,
) = final_carry
final_state = eqx.combine(static_state, dynamic_final_state)
result = RESULTS.where(
Expand All @@ -258,6 +279,7 @@ def body_fun(carry):
final_y, final_aux, stats = solver.postprocess(
fn, final_y, final_aux, args, options, final_state, tags, result
)
progress_meter.close(final_progress_meter_state)
return final_y, (
num_steps,
result,
Expand Down Expand Up @@ -288,6 +310,7 @@ def iterative_solve(
f_struct: PyTree[jax.ShapeDtypeStruct],
aux_struct: PyTree[jax.ShapeDtypeStruct],
rewrite_fn: Callable,
progress_meter: AbstractProgressMeter = NoProgressMeter(),
) -> Solution[Y, Aux]:
"""Compute the iterates of an iterative numerical method.

Expand Down Expand Up @@ -337,9 +360,24 @@ def iterative_solve(
"imaginary parts, so that Optimistix sees only real numbers."
)

# Validate that max_steps is set when using a progress meter
if not isinstance(progress_meter, NoProgressMeter) and max_steps is None:
raise ValueError("Progress meters require max_steps to be set")

f_struct = jtu.tree_map(eqxi.Static, f_struct)
aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
inputs = (
fn,
solver,
y0,
args,
options,
max_steps,
f_struct,
aux_struct,
tags,
progress_meter,
)
(
out,
(
Expand Down
6 changes: 6 additions & 0 deletions optimistix/_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ._iterate import AbstractIterativeSolver, iterative_solve
from ._minimise import AbstractMinimiser, minimise
from ._misc import inexact_asarray, NoneAux, OutAsArray, sum_squares
from ._progress_meter import AbstractProgressMeter, NoProgressMeter
from ._solution import Solution


Expand Down Expand Up @@ -55,6 +56,7 @@ def least_squares(
adjoint: AbstractAdjoint = ImplicitAdjoint(),
throw: bool = True,
tags: frozenset[object] = frozenset(),
progress_meter: AbstractProgressMeter = NoProgressMeter(),
) -> Solution[Y, Aux]:
r"""Solve a nonlinear least-squares problem.

Expand Down Expand Up @@ -89,6 +91,8 @@ def least_squares(
any structure of the Hessian of `y -> sum(fn(y, args)**2)` with respect to y.
Used with [`optimistix.ImplicitAdjoint`][] to implement the implicit function
theorem as efficiently as possible. Keyword only argument.
- `progress_meter`: A progress meter to display the progress of the solve. Defaults
to [`optimistix.NoProgressMeter`][]. Keyword only argument.

**Returns:**

Expand All @@ -111,6 +115,7 @@ def least_squares(
max_steps=max_steps,
adjoint=adjoint,
throw=throw,
progress_meter=progress_meter,
)
else:
y0 = jtu.tree_map(inexact_asarray, y0)
Expand All @@ -132,4 +137,5 @@ def least_squares(
f_struct=f_struct,
aux_struct=aux_struct,
rewrite_fn=_rewrite_fn,
progress_meter=progress_meter,
)
5 changes: 5 additions & 0 deletions optimistix/_minimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ._custom_types import Aux, Fn, MaybeAuxFn, SolverState, Y
from ._iterate import AbstractIterativeSolver, iterative_solve
from ._misc import inexact_asarray, NoneAux, OutAsArray
from ._progress_meter import AbstractProgressMeter, NoProgressMeter
from ._solution import Solution


Expand Down Expand Up @@ -48,6 +49,7 @@ def minimise(
adjoint: AbstractAdjoint = ImplicitAdjoint(),
throw: bool = True,
tags: frozenset[object] = frozenset(),
progress_meter: AbstractProgressMeter = NoProgressMeter(),
) -> Solution[Y, Aux]:
"""Minimise a function.

Expand Down Expand Up @@ -78,6 +80,8 @@ def minimise(
any structure of the Hessian of `fn` with respect to `y`. Used with
[`optimistix.ImplicitAdjoint`][] to implement the implicit function theorem as
efficiently as possible. Keyword only argument.
- `progress_meter`: A progress meter to display the progress of the solve. Defaults
to [`optimistix.NoProgressMeter`][]. Keyword only argument.

**Returns:**

Expand Down Expand Up @@ -116,4 +120,5 @@ def minimise(
aux_struct=aux_struct,
f_struct=f_struct,
rewrite_fn=_rewrite_fn,
progress_meter=progress_meter,
)
Loading
Loading