Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions optimistix/_solver/newton_chord.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ def init(
init_later_state = self.linear_solver.init(jac, options={})
init_later_state = lax.stop_gradient(init_later_state)
linear_state = (jac, init_later_state)
if self.cauchy_termination:
if self._is_newton and self.cauchy_termination:
# For Newton Cauchy convergence y usually bottleneck so no point syncing f
f_val = tree_full_like(f_struct, jnp.inf)
else:
f_val = None
# In all other cases syncing f can allow earlier exit
f_val, _ = fn(y, args)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this implies another compilation of fn, which is generally undesirable.

Staring at the suggested implementation of step it looks like fn can actually be compiled again then too; this might be quite a large pessimization?

(Don't forget that in some cases, fn is not a small piece of algebra, but actually a whole diffrax.diffeqsolve or similar. Compilation times matter!)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand this. Aren't y and args the same shape/dtype/tree structure as when called in step with new_y and args. If so and fn is the same jitted function, shouldn't it already be compiled? What triggers the additional compilation?

Also, for a diffrax sim in particular, isn't compilation largely made up of the time for compiling a single step? The vector field is guaranteed to always be the same function, and diffrax errors if the pytree structure changes and causes recompilation.


dtype = tree_dtype(f_struct)
return _NewtonChordState(
f=f_val,
Expand All @@ -112,20 +115,27 @@ def step(
lower = options.get("lower")
upper = options.get("upper")
del options
if self._is_newton:
if self._is_newton and self.cauchy_termination:
fx, lin_fn, aux = jax.linearize(lambda _y: fn(_y, args), y, has_aux=True)
jac = lx.FunctionLinearOperator(
lin_fn, jax.eval_shape(lambda: y), tags=tags
)
sol = lx.linear_solve(jac, fx, self.linear_solver, throw=False)
else:
fx, aux = fn(y, args)
jac, linear_state = state.linear_state # pyright: ignore
linear_state = lax.stop_gradient(linear_state)
sol = lx.linear_solve(
jac, fx, self.linear_solver, state=linear_state, throw=False
)
fx = state.f
if self._is_newton:
# No need to use jax.linearize as we already have f from state
jac = lx.JacobianLinearOperator(_NoAux(fn), y, args, tags=tags)
sol = lx.linear_solve(jac, fx, self.linear_solver, throw=False)
else:
jac, linear_state = state.linear_state # pyright: ignore
linear_state = lax.stop_gradient(linear_state)
sol = lx.linear_solve(
jac, fx, self.linear_solver, state=linear_state, throw=False
)

diff = sol.value

new_y = (y**ω - diff**ω).ω
if lower is not None:
new_y = jtu.tree_map(lambda a, b: jnp.clip(a, min=b), new_y, lower)
Expand All @@ -136,10 +146,26 @@ def step(
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
with jax.numpy_dtype_promotion("standard"):
diffsize = self.norm((diff**ω / scale**ω).ω)
if self.cauchy_termination:

if self._is_newton and self.cauchy_termination:
# Cauchy Newton: store lagged f, syncing fx is unlikely to reduce iterations
f_val = fx
elif not self.cauchy_termination:
# Non-Cauchy (Newton or Chord): capped at 2 iterations, only cache after first step
_, aux_shape = jax.eval_shape(fn, new_y, args)

def _skip(y):
return fx, jtu.tree_map(
lambda s: jnp.zeros(s.shape, s.dtype), aux_shape
)

f_val, aux = lax.cond(
state.step == 0, lambda new_y: fn(new_y, args), _skip, new_y
)
else:
f_val = None
# Cauchy Chord: sync fx
f_val, aux = fn(new_y, args)

new_state = _NewtonChordState(
f=f_val,
linear_state=state.linear_state,
Expand Down Expand Up @@ -177,16 +203,25 @@ def terminate(
)
terminate_result = RESULTS.successful
else:
# TODO(kidger): perform only one iteration when solving a linear system!
at_least_two = state.step >= 2
two = state.step >= 2
rate = state.diffsize / state.diffsize_prev
factor = state.diffsize * rate / (1 - rate)
small = _small(state.diffsize)
diverged = _diverged(rate)
converged = _converged(factor, self.kappa)
terminate = at_least_two & (small | diverged | converged)

# allow 1-step exit for linear problems
def _check_f_small(f):
with jax.numpy_dtype_promotion("standard"):
return self.norm((ω(f).call(jnp.abs) / self.atol).ω) < 1

f_small = lax.cond(
state.step == 1, _check_f_small, lambda _: jnp.array(False), state.f
)

terminate = f_small | two
terminate_result = RESULTS.where(
at_least_two & jnp.invert(small) & (diverged | jnp.invert(converged)),
two & jnp.invert(small) & (diverged | jnp.invert(converged)),
RESULTS.nonlinear_divergence,
RESULTS.successful,
)
Expand Down