from typing import ClassVar, Generic, Type, TypeVar

import equinox as eqx
+import jax.numpy as jnp
+import jax.tree_util as jtu
import lineax as lx
+from equinox.internal import ω
from jaxtyping import Array, Bool, Scalar

from ._custom_types import (
    SearchState,
    Y,
)
-from ._misc import sum_squares
+from ._misc import sum_squares, tree_dot
from ._solution import RESULTS


@@ -89,6 +92,9 @@ class EvalGrad(FunctionInfo, Generic[Y], strict=True):
    def as_min(self):
        return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
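
Each of the gradient-carrying `FunctionInfo` variants gains the same `compute_grad_dot` method, returning the dot product of `grad` with a direction `y`, i.e. the directional derivative of the objective along `y`. As a rough usage sketch (hypothetical; not Optimistix's actual search code), an Armijo-style sufficient-decrease test could consume it like so:

# Hypothetical sketch, not part of the library: an Armijo sufficient-decrease
# test driven by `compute_grad_dot`.
def sufficient_decrease(f_info, f_proposed, y_diff, slope=0.1):
    # Directional derivative of the objective along the proposed update `y_diff`.
    grad_dot = f_info.compute_grad_dot(y_diff)
    # Accept the step if the observed value beats the linear prediction,
    # damped by the usual Armijo constant `slope`.
    return f_proposed <= f_info.as_min() + slope * grad_dot
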
@@ -104,6 +110,9 @@ class EvalGradHessian(FunctionInfo, Generic[Y], strict=True):
    def as_min(self):
        return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
@@ -118,6 +127,9 @@ class EvalGradHessianInv(FunctionInfo, Generic[Y], strict=True):
    def as_min(self):
        return self.f

+    def compute_grad_dot(self, y: Y):
+        return tree_dot(self.grad, y)
+

# NOT PUBLIC, despite lacking an underscore. This is so pyright gets the name right.
class Residual(FunctionInfo, Generic[Out], strict=True):
@@ -144,18 +156,48 @@ class ResidualJac(FunctionInfo, Generic[Y, Out], strict=True):

    residual: Out
    jac: lx.AbstractLinearOperator
-    grad: Y
-
-    def __init__(self, residual: Out, jac: lx.AbstractLinearOperator):
-        self.residual = residual
-        self.jac = jac
-        # The gradient is used ubiquitously, so compute it once here, so that it can be
-        # used without recomputation in both the descent and search.
-        self.grad = jac.transpose().mv(residual)

    def as_min(self):
        return 0.5 * sum_squares(self.residual)

+    def compute_grad(self):
+        # Not precomputed during `__init__`, as this may hit reverse-mode autodiff,
+        # which may not be valid.
+        if any(jnp.iscomplexobj(x) for x in jtu.tree_leaves(self.residual)):
+            conj_residual = jtu.tree_map(jnp.conj, self.residual)
+            conj_jac = lx.conj(self.jac)
+            return (
+                0.5
+                * (
+                    self.jac.transpose().mv(conj_residual) ** ω
+                    + conj_jac.transpose().mv(self.residual) ** ω
+                )
+            ).ω
+        else:
+            return self.jac.transpose().mv(self.residual)
+
+    def compute_grad_dot(self, y: Y):
+        # If `self.jac` is a `lx.JacobianLinearOperator` (or a
+        # `lx.FunctionLinearOperator` wrapping the result of `jax.linearize`), then
+        # `grad = jac^T residual`, so that what we want to compute is
+        # `residual^T jac y`. Doing the reduction in this order means we hit
+        # forward-mode rather than reverse-mode autodiff.
+        if any(jnp.iscomplexobj(x) for x in jtu.tree_leaves(self.residual)):
+            # In this case the gradient is actually
+            # `grad = 0.5 * (jac^T residual^bar + jac^bar^T residual)`,
+            # so the same forward-mode-friendly reduction order is applied to each term.
+            conj_residual = jtu.tree_map(jnp.conj, self.residual)
+            conj_jac = lx.conj(self.jac)
+            return (
+                0.5
+                * (
+                    tree_dot(conj_residual, self.jac.mv(y)) ** ω
+                    + tree_dot(self.residual, conj_jac.mv(y)) ** ω
+                )
+            ).ω
+        else:
+            return tree_dot(self.residual, self.jac.mv(y))
+


Eval.__qualname__ = "FunctionInfo.Eval"
EvalGrad.__qualname__ = "FunctionInfo.EvalGrad"
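
For complex residuals, `compute_grad` above returns `0.5 * (jac^T conj(residual) + conj(jac)^T residual)` rather than plain `jac^T residual`, matching the gradient of the real objective `0.5 * sum(|residual|^2)` reported by `as_min`. A small self-contained sanity check of that formula against `jax.grad`, on a made-up residual (all names here are illustrative, not part of the library):

import jax
import jax.numpy as jnp

def complex_residual(y):
    # Toy complex-valued residual of a real parameter vector (invented for this check).
    return jnp.stack([y[0] + 1j * y[1], jnp.sin(y[0]) * jnp.exp(1j * y[1])])

y0 = jnp.array([0.7, -0.3])
r = complex_residual(y0)
jac = jax.jacfwd(complex_residual)(y0)  # shape (residuals, parameters), complex

# Gradient of 0.5 * sum(|residual|^2) via autodiff...
auto_grad = jax.grad(lambda y: 0.5 * jnp.sum(jnp.abs(complex_residual(y)) ** 2))(y0)
# ...and via the conjugate-symmetric formula used in `compute_grad`.
formula_grad = 0.5 * (jac.T @ jnp.conj(r) + jnp.conj(jac).T @ r)
assert jnp.allclose(auto_grad, jnp.real(formula_grad))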
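
The reduction order inside `compute_grad_dot` is what keeps things forward-mode: `tree_dot(residual, jac.mv(y))` only ever needs a Jacobian-vector product, whereas materialising `grad = jac.transpose().mv(residual)` first requires a vector-Jacobian (reverse-mode) product. A rough sketch of the two orderings in plain JAX, for the real-valued case (toy function and names invented for illustration):

import jax
import jax.numpy as jnp

def real_residual(y):
    # Toy real-valued residual (invented for this illustration).
    return jnp.array([y[0] ** 2 - y[1], jnp.tanh(y[1])])

y0 = jnp.array([1.0, 2.0])
direction = jnp.array([0.3, -0.5])

# Forward-mode only: `jax.linearize` gives the residual and a JVP closure `jac_mv`.
residual, jac_mv = jax.linearize(real_residual, y0)
grad_dot_fwd = jnp.dot(residual, jac_mv(direction))  # residual^T (J @ direction)

# Same value, but this ordering materialises grad = J^T residual first,
# which goes through reverse-mode autodiff (`jax.vjp`).
_, jac_transpose_mv = jax.vjp(real_residual, y0)
(grad,) = jac_transpose_mv(residual)
grad_dot_rev = jnp.dot(grad, direction)

assert jnp.allclose(grad_dot_fwd, grad_dot_rev)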