@@ -2292,7 +2292,7 @@ def _vjp3(fun, *primals, has_aux=False):
22922292 RSpec (opaque_residuals .append (r ) or (len (opaque_residuals ) - 1 ), False ) # type: ignore
22932293 for r in residuals ]
22942294 args_res = tuptree_map (lambda x : x if id (x ) in used else NotNeeded (),
2295- in_tree , primals )
2295+ in_tree , primals_flat )
22962296 out_primal_avals = [typeof (x ) for x in out_primals_flat ]
22972297 f_vjp = VJP (partial (_vjp3_callable , spec , out_known , jaxpr , out_primal_avals ),
22982298 in_tree , out_tree , list (args_res ), opaque_residuals )
@@ -2305,15 +2305,19 @@ def _vjp3(fun, *primals, has_aux=False):
23052305def tuptree_map (f , treedef , x ):
23062306 return treedef .walk (lambda xs , _ : tuple (xs ), f , x )
23072307
2308+
23082309def _is_ref (x ):
23092310 from jax ._src .state .types import AbstractRef
23102311 try : return isinstance (typeof (x ), AbstractRef )
23112312 except : return False
23122313
23132314def _vjp3_callable (spec , out_known , jaxpr , out_primal_avals , in_tree , out_tree ,
23142315 args_res , opaque_res , * maybe_ct_refs ):
2315- maybe_ct_refs_flat , in_tree_ = tree_flatten (maybe_ct_refs )
2316- if in_tree != in_tree_ : raise Exception
2316+ if not maybe_ct_refs :
2317+ maybe_ct_refs_flat = [GradValue ()] * in_tree .num_leaves
2318+ else :
2319+ maybe_ct_refs_flat , in_tree_ = tree_flatten (maybe_ct_refs )
2320+ if in_tree != in_tree_ : raise Exception # TODO accept isomorph tuple tree
23172321 args_res_ = tree_leaves (args_res , is_leaf = lambda x : isinstance (x , NotNeeded ))
23182322 residuals = [args_res_ [i .idx ] if i .primal else opaque_res [i .idx ] for i in spec ]
23192323 maybe_refs = [ad .RefAccum (v .aval , x ) if _is_ref (x ) else ad .ValAccum (v .aval )
@@ -2407,9 +2411,8 @@ def __call__(self, out_ct, *extra_args):
24072411 if extra_args :
24082412 name , * _ = self .jaxpr .debug_info .func_src_info .split (' ' )
24092413 raise TypeError (_vjp_too_many_args (name , len (extra_args )))
2410- dums = tree_unflatten (self .in_tree , [GradValue ()] * self .in_tree .num_leaves )
24112414 return self .fun (self .in_tree , self .out_tree , self .args_res ,
2412- self .opaque_residuals , * dums )(out_ct )
2415+ self .opaque_residuals )(out_ct )
24132416
24142417 def with_refs (self , * maybe_ct_refs ):
24152418 return self .fun (self .in_tree , self .out_tree , self .args_res ,
0 commit comments