1414"""K-FAC utilities for classes with staged methods."""
1515
1616import functools
17+ import inspect
1718import numbers
1819import operator
1920from typing import Any , Callable , Sequence
@@ -165,8 +166,15 @@ def staged(
165166 This decorator **should** only be applied to instance methods of classes that
166167 inherit from the `WithStagedMethods` class. The decorator makes the decorated
167168 method staged, which is equivalent to `jax.jit` if `instance.multi_device` is
168- `False` and to `jax.pmap` otherwise. When specifying static and donated
169- argunms, the `self` reference **must not** be counted. Example:
169+ `False` and to `jax.pmap` otherwise.
170+
171+ Note that the point of this abstraction around JAX's compilation is to make
172+ sure that jitting/pmapping is only done once, so that if we are already in a
173+ compiled/staged method, we won't initiate a second nested compilation when
174+ calling into a second staged method.
175+
176+ Note that when specifying static and donated argnums, the `self` reference
177+ **must not** be counted. Example:
170178
171179 @functools.partial(staged, donate_argnums=0)
172180 def try(self, x):
@@ -210,12 +218,22 @@ def try(self, x):
210218 donate_argnums = donate_argnums )
211219
212220 @functools .wraps (method )
213- def decorated (instance : "WithStagedMethods" , * args : Any ) -> TArrayTree :
221+ def decorated (
222+ instance : "WithStagedMethods" ,
223+ * args : Any ,
224+ ** kwargs : Any
225+ ) -> TArrayTree :
226+
227+ sig = inspect .signature (method )
228+ bound_args = sig .bind (instance , * args , ** kwargs )
229+ bound_args .apply_defaults ()
230+ args , kwargs = bound_args .args [1 :], bound_args .kwargs
214231
215232 if instance .in_staging :
216- return method (instance , * args )
233+ return method (instance , * args , ** kwargs )
217234
218235 with instance .staging_context ():
236+
219237 if instance .multi_device and instance .debug :
220238 # In this case we want to call `method` once for each device index.
221239 # Note that this might not always produce sensible behavior, and will
@@ -241,14 +259,16 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
241259 for j in range (len (args ))
242260 ]
243261
262+ kwargs_i = jax .tree_util .tree_map (operator .itemgetter (i ), kwargs )
263+
244264 with jax .disable_jit ():
245- outs .append (method (instance , * args_i ))
265+ outs .append (method (instance , * args_i , ** kwargs_i ))
246266
247267 outs = jax .tree_util .tree_map (lambda * args_ : jnp .stack (args_ ), * outs )
248268
249269 elif instance .debug :
250270 with jax .disable_jit ():
251- outs = method (instance , * args )
271+ outs = method (instance , * args , ** kwargs )
252272
253273 elif instance .multi_device :
254274
@@ -274,10 +294,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
274294 )
275295 pmap_funcs [key ] = func
276296
277- outs = func (instance , * args )
297+ outs = func (instance , * args , ** kwargs )
278298
279299 else :
280- outs = jitted_func (instance , * args )
300+ outs = jitted_func (instance , * args , ** kwargs )
281301
282302 return outs
283303
0 commit comments