Skip to content

Commit 1e8cb9c

Browse files
Merge pull request #32735 from mattjj:vjp3-args-res-tuple-trees
PiperOrigin-RevId: 820932861
2 parents (8065f96 + 9fda7f7), commit 1e8cb9c

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

jax/_src/api.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2292,7 +2292,8 @@ def _vjp3(fun, *primals, has_aux=False):
22922292
spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else # type: ignore
22932293
RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore
22942294
for r in residuals]
2295-
args_res = tree_map(lambda x: x if id(x) in used else NotNeeded(), primals)
2295+
args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(),
2296+
in_tree, primals)
22962297
out_primal_avals = [typeof(x) for x in out_primals_flat]
22972298
f_vjp = VJP(partial(_vjp3_callable, spec, out_known, jaxpr, out_primal_avals),
22982299
in_tree, out_tree, list(args_res), opaque_residuals)
@@ -2302,6 +2303,9 @@ def _vjp3(fun, *primals, has_aux=False):
23022303
else:
23032304
return out_primals, f_vjp, tree_unflatten(aux_tree, aux)
23042305

2306+
def tuptree_map(f, treedef, x):
2307+
return treedef.walk(lambda xs, _: tuple(xs), f, x)
2308+
23052309
def _is_ref(x):
23062310
from jax._src.state.types import AbstractRef
23072311
try: return isinstance(typeof(x), AbstractRef)
@@ -2311,10 +2315,8 @@ def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
23112315
args_res, opaque_res, *maybe_ct_refs):
23122316
maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
23132317
if in_tree != in_tree_: raise Exception
2314-
args_res_flat, in_tree_ = tree_flatten(
2315-
tuple(args_res), is_leaf=lambda x: isinstance(x, NotNeeded))
2316-
if in_tree != in_tree_: raise Exception
2317-
residuals = [args_res_flat[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
2318+
args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
2319+
residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
23182320
maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval)
23192321
for v, x in zip(jaxpr.invars, maybe_ct_refs_flat)]
23202322
return Partial(partial(_vjp3_bwd, in_tree, out_tree, out_known, jaxpr,

tests/api_test.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7817,12 +7817,27 @@ def test_basic_unused(self):
78177817
with self.assertRaisesRegex(Exception, "not used by the backward pass: x"):
78187818
_ = api.si_vjp(f, [True], *primals, allow_unused=False)
78197819

7820+
def test_basic_unused_vjp3(self):
7821+
f = jnp.sin
7822+
primals = 3.,
7823+
y, f_vjp = api.vjp3(f, *primals)
7824+
x_ct, = f_vjp(1.)
7825+
self.assertAllClose(y, jnp.sin(3.))
7826+
self.assertAllClose(x_ct, jnp.cos(3.))
7827+
self.assertIsInstance(f_vjp.args_res[0], api.NotNeeded) # can check if unused
7828+
78207829
def test_basic_opaque(self):
78217830
f = jnp.sin
78227831
primals = 3.,
78237832
with self.assertRaisesRegex(Exception, "the backward pass requires opaque"):
78247833
_ = api.si_vjp(f, [True], *primals, allow_opaque=False)
78257834

7835+
def test_basic_opaque_vjp3(self):
7836+
f = jnp.sin
7837+
primals = 3.,
7838+
_, f_vjp = api.vjp3(f, *primals)
7839+
assert f_vjp.opaque_residuals # can detect if opaque res are used
7840+
78267841
def test_basic_pytree_error(self):
78277842
def f(x):
78287843
return [x['hi'] * x['bye']]
@@ -7835,6 +7850,20 @@ def f(x):
78357850
with self.assertRaisesRegex(ValueError, "but the structures differ"):
78367851
f_vjp(1., {'hi': 2.})
78377852

7853+
# TODO(mattjj): improve this vjp3 error message
7854+
# def test_basic_pytree_error_vjp3(self):
7855+
# def f(x):
7856+
# return [x['hi'] * x['bye']]
7857+
7858+
# y, f_vjp = api.vjp3(f, {'hi': 2., 'bye': 3.})
7859+
# arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
7860+
# self.assertAllClose(y, [6.])
7861+
# self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
7862+
7863+
# f_vjp.args_res[0] = {'hi': 2.}
7864+
# with self.assertRaisesRegex(ValueError, "but the structures differ"):
7865+
# f_vjp(1.)
7866+
78387867
def test_fsdp(self):
78397868
# see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
78407869
def f2(x, w):
@@ -7849,6 +7878,24 @@ def f2(x, w):
78497878
y_grad = jnp.ones_like(y)
78507879
x_grad, w_grad = f2_sivjp(y_grad, w)
78517880
self.assertAllClose(x_grad, 2. * y_grad @ w.T)
7881+
7882+
def test_fsdp_vjp3(self):
7883+
# see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
7884+
def f2(x, w):
7885+
x = 1. * x
7886+
x = x @ w
7887+
x = 2. * x
7888+
return x
7889+
7890+
x = jnp.ones((3, 4))
7891+
w = jnp.ones((4, 4))
7892+
y, f2_vjp = api.vjp3(f2, x, w)
7893+
f2_vjp.args_res[1] = None
7894+
y_grad = jnp.ones_like(y)
7895+
f2_vjp.args_res[1] = w
7896+
x_grad, w_grad = f2_vjp(y_grad)
7897+
self.assertAllClose(x_grad, 2. * y_grad @ w.T)
7898+
self.assertAllClose(w_grad, 2. * x.T @ y_grad)
78527899
self.assertAllClose(w_grad, 2. * x.T @ y_grad)
78537900

78547901
def test_doesnt_leak_symbolic_zeros(self):

0 commit comments

Comments
 (0)