
Commit b930ea6

Switched from reverse mode to forward mode where possible.
This commit switches some functions that unnecessarily used reverse-mode autodiff over to forward-mode autodiff. In particular, this is to fix #51 (comment). Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve it; it seems something is still awry!
1 parent d9b7ba6 commit b930ea6
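
For context (this snippet is not part of the commit): the underlying trick is that a gradient-direction product `grad^T v = residual^T (J v)` needs only a single Jacobian-vector product, i.e. forward-mode autodiff, whereas materialising `grad = J^T residual` first requires a vector-Jacobian product, i.e. reverse mode. A minimal sketch in plain JAX, with an illustrative toy residual:

import jax
import jax.numpy as jnp


def residual(y):
    return jnp.stack([y[0] ** 2 - y[1], jnp.sin(y[1])])


y = jnp.array([1.0, 2.0])
v = jnp.array([0.1, -0.3])

# Forward mode: a single JVP gives `J @ v`, from which `grad^T v = r^T (J v)`.
r, jv = jax.jvp(residual, (y,), (v,))
grad_dot_fwd = jnp.vdot(r, jv)

# Reverse mode: build the VJP, form `grad = J^T r` explicitly, then dot with `v`.
_, vjp_fn = jax.vjp(residual, y)
(grad,) = vjp_fn(r)
grad_dot_rev = jnp.vdot(grad, v)

assert jnp.allclose(grad_dot_fwd, grad_dot_rev)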

9 files changed

Lines changed: 153 additions & 31 deletions

optimistix/_search.py

Lines changed: 51 additions & 9 deletions
@@ -26,7 +26,10 @@
 from typing import ClassVar, Generic, Type, TypeVar

 import equinox as eqx
+import jax.numpy as jnp
+import jax.tree_util as jtu
 import lineax as lx
+from equinox.internal import ω
 from jaxtyping import Array, Bool, Scalar

 from ._custom_types import (
@@ -35,7 +38,7 @@
     SearchState,
     Y,
 )
-from ._misc import sum_squares
+from ._misc import sum_squares, tree_dot
 from ._solution import RESULTS


@@ -89,6 +92,9 @@ class EvalGrad(FunctionInfo, Generic[Y], strict=True):
     def as_min(self):
         return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

 # NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
 class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
@@ -104,6 +110,9 @@ class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
     def as_min(self):
         return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

 # NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
 class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
@@ -118,6 +127,9 @@ class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
     def as_min(self):
         return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

 # NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
 class Residual(FunctionInfo, Generic[Out], strict=True):
@@ -144,18 +156,48 @@ class ResidualJac(FunctionInfo, Generic[Y, Out], strict=True):

     residual: Out
     jac: lx.AbstractLinearOperator
-    grad: Y
-
-    def __init__(self, residual: Out, jac: lx.AbstractLinearOperator):
-        self.residual = residual
-        self.jac = jac
-        # The gradient is used ubiquitously, so compute it once here, so that it can be
-        # used without recomputation in both the descent and search.
-        self.grad = jac.transpose().mv(residual)

     def as_min(self):
         return 0.5 * sum_squares(self.residual)

+    def compute_grad(self):
+        # Not precomputed during `__init__`, as this may hit reverse-mode autodiff,
+        # which may not be valid.
+        if any(jnp.iscomplexobj(x) for x in jtu.tree_leaves(self.residual)):
+            conj_residual = jtu.tree_map(jnp.conj, self.residual)
+            conj_jac = lx.conj(self.jac)
+            return (
+                0.5
+                * (
+                    self.jac.transpose().mv(conj_residual) ** ω
+                    + conj_jac.transpose().mv(self.residual) ** ω
+                )
+            ).ω
+        else:
+            return self.jac.transpose().mv(self.residual)
+
+    def compute_grad_dot(self, y: Y):
+        # If `self.jac` is a `lx.JacobianLinearOperator` (or a
+        # `lx.FunctionLinearOperator` wrapping the result of `jax.linearize`), then
+        # `grad = jac^T residual`, so that what we want to compute is
+        # `residual^T jac y`. Doing the reduction in this order means we hit
+        # forward-mode rather than reverse-mode autodiff.
+        if any(jnp.iscomplexobj(x) for x in jtu.tree_leaves(self.residual)):
+            # In this case then actually
+            # `grad = 0.5 * (jac^T conj(residual) + conj(jac)^T residual)`, and the
+            # same reordering applies to all of this.
+            conj_residual = jtu.tree_map(jnp.conj, self.residual)
+            conj_jac = lx.conj(self.jac)
+            return (
+                0.5
+                * (
+                    tree_dot(conj_residual, self.jac.mv(y)) ** ω
+                    + tree_dot(self.residual, conj_jac.mv(y)) ** ω
+                )
+            ).ω
+        else:
+            return tree_dot(self.residual, self.jac.mv(y))
+

 Eval.__qualname__ = "FunctionInfo.Eval"
 EvalGrad.__qualname__ = "FunctionInfo.EvalGrad"
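
Not part of the commit, but a usage sketch of the new methods in the real-valued case, mirroring how the test added below constructs a `FunctionInfo.ResidualJac` directly:

import jax
import jax.numpy as jnp
import lineax as lx
import optimistix as optx


def residual(y):
    return jnp.array([y[0] ** 2 - 1.0, y[0] * y[1], y[1] - 2.0])


y = jnp.array([1.5, 0.5])
jac = lx.MatrixLinearOperator(jax.jacfwd(residual)(y))
f_info = optx.FunctionInfo.ResidualJac(residual(y), jac)

# For real residuals, `compute_grad` is just `J^T r`, i.e. the gradient of
# `as_min() = 0.5 * ||r||^2`.
grad = f_info.compute_grad()
assert jnp.allclose(grad, jac.transpose().mv(residual(y)))
assert jnp.allclose(grad, jax.grad(lambda y: 0.5 * jnp.sum(residual(y) ** 2))(y))

# `compute_grad_dot(v)` gives `grad^T v` without materialising `grad` at all.
v = jnp.array([1.0, -1.0])
assert jnp.allclose(f_info.compute_grad_dot(v), jnp.dot(grad, v))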

optimistix/_solver/backtracking.py

Lines changed: 3 additions & 8 deletions
@@ -7,9 +7,6 @@
 from jaxtyping import Array, Bool, Scalar, ScalarLike

 from .._custom_types import Y
-from .._misc import (
-    tree_dot,
-)
 from .._search import AbstractSearch, FunctionInfo
 from .._solution import RESULTS

@@ -55,7 +52,7 @@ def __post_init__(self):
         )

     def init(self, y: Y, f_info_struct: _FnInfo) -> _BacktrackingState:
-        del f_info_struct
+        del y, f_info_struct
         return _BacktrackingState(step_size=jnp.array(self.step_init))

     def step(
@@ -67,7 +64,7 @@ def step(
         f_eval_info: _FnEvalInfo,
         state: _BacktrackingState,
     ) -> tuple[Scalar, Bool[Array, ""], RESULTS, _BacktrackingState]:
-        if isinstance(
+        if not isinstance(
             f_info,
             (
                 FunctionInfo.EvalGrad,
@@ -76,16 +73,14 @@
                 FunctionInfo.ResidualJac,
             ),
         ):
-            grad = f_info.grad
-        else:
             raise ValueError(
                 "Cannot use `BacktrackingArmijo` with this solver. This is because "
                 "`BacktrackingArmijo` requires gradients of the target function, but "
                 "this solver does not evaluate such gradients."
             )

         y_diff = (y_eval**ω - y**ω).ω
-        predicted_reduction = tree_dot(grad, y_diff)
+        predicted_reduction = f_info.compute_grad_dot(y_diff)
         # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
         # must do better than its linear approximation:
         # `fn(y_eval) < fn(y) + grad•y_diff`
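
The predicted reduction above is the first-order Taylor model `grad^T y_diff` of the change in the objective. A rough sketch of the Armijo acceptance test this feeds into (simplified to flat arrays; `slope` is a hypothetical parameter name, not necessarily the one used by `BacktrackingArmijo`):

import jax.numpy as jnp


def armijo_accept(f_old, f_new, grad, y_diff, slope=0.1):
    # First-order prediction of the decrease along the step `y_diff`.
    predicted_reduction = jnp.vdot(grad, y_diff)
    # Accept if the actual value beats the scaled linear prediction.
    return f_new <= f_old + slope * predicted_reduction


print(armijo_accept(1.0, 0.9, jnp.array([0.5, -0.2]), jnp.array([-0.4, 0.1])))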

optimistix/_solver/dogleg.py

Lines changed: 7 additions & 5 deletions
@@ -72,13 +72,15 @@ def query(
         f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.ResidualJac],
         state: _DoglegDescentState,
     ) -> _DoglegDescentState:
-        del state
+        del y, state
         # Compute `denom = grad^T Hess grad.`
         if isinstance(f_info, FunctionInfo.EvalGradHessian):
-            denom = tree_dot(f_info.grad, f_info.hessian.mv(f_info.grad))
+            grad = f_info.grad
+            denom = tree_dot(f_info.grad, f_info.hessian.mv(grad))
         elif isinstance(f_info, FunctionInfo.ResidualJac):
             # Use Gauss--Newton approximation `Hess ~ J^T J`
-            denom = sum_squares(f_info.jac.mv(f_info.grad))
+            grad = f_info.compute_grad()
+            denom = sum_squares(f_info.jac.mv(grad))
         else:
             raise ValueError(
                 "`DoglegDescent` can only be used with least-squares solvers, or "
@@ -88,7 +90,7 @@ def query(
         denom_nonzero = denom > jnp.finfo(denom.dtype).eps
         safe_denom = jnp.where(denom_nonzero, denom, 1)
         # Compute `grad^T grad / (grad^T Hess grad)`
-        scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0)
+        scaling = jnp.where(denom_nonzero, sum_squares(grad) / safe_denom, 0.0)
         scaling = cast(Array, scaling)

         # Downhill towards the bottom of the quadratic basin.
@@ -97,7 +99,7 @@ def query(
         newton_norm = self.trust_region_norm(newton_sol)

         # Downhill steepest descent.
-        cauchy = (-scaling * f_info.grad**ω).ω
+        cauchy = (-scaling * grad**ω).ω
         cauchy_norm = self.trust_region_norm(cauchy)

         return _DoglegDescentState(
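
Illustration only (dense arrays instead of pytrees and linear operators): the Cauchy point referenced above scales the negative gradient by `grad^T grad / (grad^T Hess grad)`, which is the exact minimiser of the quadratic model along the steepest-descent direction.

import jax.numpy as jnp

hess = jnp.array([[3.0, 0.5], [0.5, 2.0]])
grad = jnp.array([1.0, -2.0])

denom = grad @ hess @ grad  # grad^T Hess grad
scaling = jnp.where(denom > 1e-12, jnp.sum(grad**2) / denom, 0.0)
cauchy = -scaling * grad  # the steepest-descent leg of the dogleg
print(cauchy)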

optimistix/_solver/gradient_methods.py

Lines changed: 2 additions & 1 deletion
@@ -58,10 +58,11 @@ def query(
                 FunctionInfo.EvalGrad,
                 FunctionInfo.EvalGradHessian,
                 FunctionInfo.EvalGradHessianInv,
-                FunctionInfo.ResidualJac,
             ),
         ):
             grad = f_info.grad
+        elif isinstance(f_info, FunctionInfo.ResidualJac):
+            grad = f_info.compute_grad()
         else:
             raise ValueError(
                 "Cannot use `SteepestDescent` with this solver. This is because "

optimistix/_solver/nonlinear_cg.py

Lines changed: 10 additions & 6 deletions
@@ -119,15 +119,19 @@ def query(
         ],
         state: _NonlinearCGDescentState,
     ) -> _NonlinearCGDescentState:
-        if not isinstance(
+        del y
+        if isinstance(
             f_info,
             (
                 FunctionInfo.EvalGrad,
                 FunctionInfo.EvalGradHessian,
                 FunctionInfo.EvalGradHessianInv,
-                FunctionInfo.ResidualJac,
             ),
         ):
+            grad = f_info.grad
+        elif isinstance(f_info, FunctionInfo.ResidualJac):
+            grad = f_info.compute_grad()
+        else:
             raise ValueError(
                 "Cannot use `NonlinearCGDescent` with this solver. This is because "
                 "`NonlinearCGDescent` requires gradients of the target function, but "
@@ -140,16 +144,16 @@
         # Furthermore, the same mechanism handles convergence: once
         # `state.{grad, y_diff} = 0`, i.e. our previous step hit a local minima, then
         # on this next step we'll again just use gradient descent, and stop.
-        beta = self.method(f_info.grad, state.grad, state.y_diff)
-        neg_grad = (-(f_info.grad**ω)).ω
+        beta = self.method(grad, state.grad, state.y_diff)
+        neg_grad = (-(grad**ω)).ω
         nonlinear_cg_direction = (neg_grad**ω + beta * state.y_diff**ω).ω
         # Check if this is a descent direction. Use gradient descent if it isn't.
         y_diff = tree_where(
-            tree_dot(f_info.grad, nonlinear_cg_direction) < 0,
+            tree_dot(grad, nonlinear_cg_direction) < 0,
             nonlinear_cg_direction,
             neg_grad,
         )
-        return _NonlinearCGDescentState(y_diff=y_diff, grad=f_info.grad)
+        return _NonlinearCGDescentState(y_diff=y_diff, grad=grad)

     def step(
         self, step_size: Scalar, state: _NonlinearCGDescentState
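
A simplified sketch of how the direction above is assembled, using flat arrays and a Polak-Ribiere-style formula for `beta` (the actual rule is whatever `self.method` is configured to be):

import jax.numpy as jnp

grad = jnp.array([0.8, -0.1])  # current gradient
prev_grad = jnp.array([1.0, 0.2])  # state.grad
prev_step = jnp.array([-0.5, 0.3])  # state.y_diff, the previous step

beta = jnp.dot(grad, grad - prev_grad) / jnp.dot(prev_grad, prev_grad)
direction = -grad + beta * prev_step
# Fall back to plain steepest descent if this is not a descent direction.
direction = jnp.where(jnp.dot(grad, direction) < 0, direction, -grad)
print(direction)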

optimistix/_solver/trust_region.py

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def predict_reduction(
                 FunctionInfo.ResidualJac,
             ),
         ):
-            return tree_dot(f_info.grad, y_diff)
+            return f_info.compute_grad_dot(y_diff)
         else:
             raise ValueError(
                 "Cannot use `LinearTrustRegion` with this solver. This is because "

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Mathematics",
 ]
 urls = {repository = "https://github.com/patrick-kidger/optimistix" }
-dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.23", "lineax>=0.0.4", "equinox>=0.11.1", "typing_extensions>=4.5.0"]
+dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.23", "lineax>=0.0.5", "equinox>=0.11.1", "typing_extensions>=4.5.0"]

 [build-system]
 requires = ["hatchling"]

tests/test_least_squares.py

Lines changed: 63 additions & 0 deletions
@@ -149,3 +149,66 @@ def f_bwd(sign, g):

     with pytest.raises(TypeError, match="forward-mode autodiff"):
         optx.least_squares(f, solver, y0, options=dict(jac="fwd"), max_steps=512)
+
+
+def test_residual_jac():
+    # First grab values as computed using the complex implementation. We compute the
+    # gradient both using the `.compute_grad` method (which uses a custom more-efficient
+    # approach using forward-mode autodiff) and the simple way using `jax.grad`.
+
+    def residual1(y1):
+        return y1**2
+
+    def compute1(y1):
+        r = residual1(y1)
+        jac = lx.MatrixLinearOperator(jax.jacfwd(residual1, holomorphic=True)(y1))
+        f_info = optx.FunctionInfo.ResidualJac(r, jac)
+        return f_info.as_min(), (f_info.compute_grad(), f_info.compute_grad_dot(z))
+
+    y1 = jnp.array([2 + 3j, 4 + 1j])
+    z = jnp.array([-1 + 0j, 2 - 5j])
+    true_min = 0.5 * jnp.sum(y1**2 * jnp.conj(y1**2))
+    (min1, (grad1, grad_dot1)), true_grad1 = jax.value_and_grad(compute1, has_aux=True)(
+        y1
+    )
+    true_grad_dot1 = jnp.sum(true_grad1 * jnp.conj(z))
+
+    # Next compute the same quantities using just the real implementation.
+
+    def residual2(y2):
+        real, imag = y2
+        return real**2 - imag**2, 2 * real * imag
+
+    def compute2(y2):
+        r = residual2(y2)
+        jac = lx.PyTreeLinearOperator(
+            jax.jacfwd(residual2)(y2), jax.eval_shape(lambda: y2)
+        )
+        f_info = optx.FunctionInfo.ResidualJac(r, jac)
+        return f_info.as_min(), (
+            f_info.compute_grad(),
+            f_info.compute_grad_dot((z.real, z.imag)),
+        )
+
+    y2 = (y1.real, y1.imag)
+    (min2, (grad2, grad_dot2)), true_grad2 = jax.value_and_grad(compute2, has_aux=True)(
+        y2
+    )
+    true_grad2_real, true_grad2_imag = true_grad2
+    true_grad_dot2 = (
+        true_grad2_real * z.real + true_grad2_imag * z.imag,
+        true_grad2_imag * z.real - true_grad2_real * z.imag,
+    )
+
+    # Now check consistency.
+
+    assert tree_allclose(min1, min2)
+    assert tree_allclose(min1.astype(jnp.complex128), true_min)
+
+    assert tree_allclose(grad2, true_grad2)
+    assert tree_allclose((grad1.real, grad1.imag), grad2)
+    assert tree_allclose(grad1, true_grad1)
+
+    assert tree_allclose(grad_dot2, true_grad_dot2)
+    assert tree_allclose((grad_dot1.real, grad_dot1.imag), grad_dot2)
+    assert tree_allclose(grad_dot1, true_grad_dot1)
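
As a sanity check on the test's construction (not part of the commit): `residual2` is exactly the real/imaginary-pair form of `residual1`, i.e. of complex squaring, since `(a + b*i)**2 = (a**2 - b**2) + (2*a*b)*i`.

import jax.numpy as jnp

y = jnp.array([2 + 3j, 4 + 1j])
real_part = y.real**2 - y.imag**2
imag_part = 2 * y.real * y.imag
assert jnp.allclose(real_part, (y**2).real)
assert jnp.allclose(imag_part, (y**2).imag)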

tests/test_solve.py

Lines changed: 15 additions & 0 deletions
@@ -1,4 +1,6 @@
+import equinox.internal as eqxi
 import jax
+import jax.numpy as jnp
 import optimistix as optx


@@ -48,3 +50,16 @@ def fn(x, _):
         return optx.fixed_point(fn, solver, 0.0).value

     f(0.0)
+
+
+def test_forward_mode():
+    def f(y, _):
+        return eqxi.nondifferentiable_backward(y)
+
+    sol = optx.least_squares(
+        f,
+        optx.LevenbergMarquardt(rtol=1e-4, atol=1e-4),
+        jnp.arange(3.0),
+        options=dict(jac="fwd"),
+    )
+    return sol.value
