@@ -7817,12 +7817,27 @@ def test_basic_unused(self):
78177817 with self .assertRaisesRegex (Exception , "not used by the backward pass: x" ):
78187818 _ = api .si_vjp (f , [True ], * primals , allow_unused = False )
78197819
7820+ def test_basic_unused_vjp3 (self ):
7821+ f = jnp .sin
7822+ primals = 3. ,
7823+ y , f_vjp = api .vjp3 (f , * primals )
7824+ x_ct , = f_vjp (1. )
7825+ self .assertAllClose (y , jnp .sin (3. ))
7826+ self .assertAllClose (x_ct , jnp .cos (3. ))
7827+ self .assertIsInstance (f_vjp .args_res [0 ], api .NotNeeded ) # can check if unused
7828+
78207829 def test_basic_opaque (self ):
78217830 f = jnp .sin
78227831 primals = 3. ,
78237832 with self .assertRaisesRegex (Exception , "the backward pass requires opaque" ):
78247833 _ = api .si_vjp (f , [True ], * primals , allow_opaque = False )
78257834
7835+ def test_basic_opaque_vjp3 (self ):
7836+ f = jnp .sin
7837+ primals = 3. ,
7838+ _ , f_vjp = api .vjp3 (f , * primals )
7839+ assert f_vjp .opaque_residuals # can detect if opaque res are used
7840+
78267841 def test_basic_pytree_error (self ):
78277842 def f (x ):
78287843 return [x ['hi' ] * x ['bye' ]]
@@ -7835,6 +7850,20 @@ def f(x):
78357850 with self .assertRaisesRegex (ValueError , "but the structures differ" ):
78367851 f_vjp (1. , {'hi' : 2. })
78377852
7853+ # TODO(mattjj): improve this vjp3 error message
7854+ # def test_basic_pytree_error_vjp3(self):
7855+ # def f(x):
7856+ # return [x['hi'] * x['bye']]
7857+
7858+ # y, f_vjp = api.vjp3(f, {'hi': 2., 'bye': 3.})
7859+ # arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.})
7860+ # self.assertAllClose(y, [6.])
7861+ # self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.})
7862+
7863+ # f_vjp.args_res[0] = {'hi': 2.}
7864+ # with self.assertRaisesRegex(ValueError, "but the structures differ"):
7865+ # f_vjp(1.)
7866+
78387867 def test_fsdp (self ):
78397868 # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
78407869 def f2 (x , w ):
@@ -7849,6 +7878,24 @@ def f2(x, w):
78497878 y_grad = jnp .ones_like (y )
78507879 x_grad , w_grad = f2_sivjp (y_grad , w )
78517880 self .assertAllClose (x_grad , 2. * y_grad @ w .T )
7881+
7882+ def test_fsdp_vjp3 (self ):
7883+ # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
7884+ def f2 (x , w ):
7885+ x = 1. * x
7886+ x = x @ w
7887+ x = 2. * x
7888+ return x
7889+
7890+ x = jnp .ones ((3 , 4 ))
7891+ w = jnp .ones ((4 , 4 ))
7892+ y , f2_vjp = api .vjp3 (f2 , x , w )
7893+ f2_vjp .args_res [1 ] = None
7894+ y_grad = jnp .ones_like (y )
7895+ f2_vjp .args_res [1 ] = w
7896+ x_grad , w_grad = f2_vjp (y_grad )
7897+ self .assertAllClose (x_grad , 2. * y_grad @ w .T )
7898+ self .assertAllClose (w_grad , 2. * x .T @ y_grad )
78527899 self .assertAllClose (w_grad , 2. * x .T @ y_grad )
78537900
78547901 def test_doesnt_leak_symbolic_zeros (self ):
0 commit comments