Skip to content

Commit cb69cab

Browse files
mattjjyashk2810
authored andcommitted
[mutable-arrays] in vjp3, don't stuff GradValue sentinels into user pytree nodes
Co-authored-by: Yash Katariya <yashkatariya@google.com> PiperOrigin-RevId: 823694237
1 parent e4292e6 commit cb69cab

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

jax/_src/api.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,7 +2292,7 @@ def _vjp3(fun, *primals, has_aux=False):
22922292
RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore
22932293
for r in residuals]
22942294
args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(),
2295-
in_tree, primals)
2295+
in_tree, primals_flat)
22962296
out_primal_avals = [typeof(x) for x in out_primals_flat]
22972297
f_vjp = VJP(partial(_vjp3_callable, spec, out_known, jaxpr, out_primal_avals),
22982298
in_tree, out_tree, list(args_res), opaque_residuals)
@@ -2305,15 +2305,19 @@ def _vjp3(fun, *primals, has_aux=False):
23052305
def tuptree_map(f, treedef, x):
23062306
return treedef.walk(lambda xs, _: tuple(xs), f, x)
23072307

2308+
23082309
def _is_ref(x):
23092310
from jax._src.state.types import AbstractRef
23102311
try: return isinstance(typeof(x), AbstractRef)
23112312
except: return False
23122313

23132314
def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
23142315
args_res, opaque_res, *maybe_ct_refs):
2315-
maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
2316-
if in_tree != in_tree_: raise Exception
2316+
if not maybe_ct_refs:
2317+
maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves
2318+
else:
2319+
maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
2320+
if in_tree != in_tree_: raise Exception # TODO accept isomorph tuple tree
23172321
args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
23182322
residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
23192323
maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval)
@@ -2407,9 +2411,8 @@ def __call__(self, out_ct, *extra_args):
24072411
if extra_args:
24082412
name, *_ = self.jaxpr.debug_info.func_src_info.split(' ')
24092413
raise TypeError(_vjp_too_many_args(name, len(extra_args)))
2410-
dums = tree_unflatten(self.in_tree, [GradValue()] * self.in_tree.num_leaves)
24112414
return self.fun(self.in_tree, self.out_tree, self.args_res,
2412-
self.opaque_residuals, *dums)(out_ct)
2415+
self.opaque_residuals)(out_ct)
24132416

24142417
def with_refs(self, *maybe_ct_refs):
24152418
return self.fun(self.in_tree, self.out_tree, self.args_res,

0 commit comments

Comments
 (0)