Skip to content

Commit cf3acc0

Browse files
james-martensKfacJaxDev
authored andcommitted
Adding support for keyword arguments to staged methods.
PiperOrigin-RevId: 702342993
1 parent bc000c6 commit cf3acc0

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

kfac_jax/_src/utils/staging.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""K-FAC utilities for classes with staged methods."""
1515

1616
import functools
17+
import inspect
1718
import numbers
1819
import operator
1920
from typing import Any, Callable, Sequence
@@ -165,8 +166,15 @@ def staged(
165166
This decorator **should** only be applied to instance methods of classes that
166167
inherit from the `WithStagedMethods` class. The decorator makes the decorated
167168
method staged, which is equivalent to `jax.jit` if `instance.multi_device` is
168-
`False` and to `jax.pmap` otherwise. When specifying static and donated
169-
argunms, the `self` reference **must not** be counted. Example:
169+
`False` and to `jax.pmap` otherwise.
170+
171+
Note that the point of this abstraction around JAX's compilation is to make
172+
sure that jitting/pmapping is only done once, so that if we are already in a
173+
compiled/staged method, we won't initiate a second nested compilation when
174+
calling into second staged method.
175+
176+
Note that when specifying static and donated argunms, the `self` reference
177+
**must not** be counted. Example:
170178
171179
@functools.partial(staged, donate_argunms=0)
172180
def try(self, x):
@@ -210,12 +218,22 @@ def try(self, x):
210218
donate_argnums=donate_argnums)
211219

212220
@functools.wraps(method)
213-
def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
221+
def decorated(
222+
instance: "WithStagedMethods",
223+
*args: Any,
224+
**kwargs: Any
225+
) -> TArrayTree:
226+
227+
sig = inspect.signature(method)
228+
bound_args = sig.bind(instance, *args, **kwargs)
229+
bound_args.apply_defaults()
230+
args, kwargs = bound_args.args[1:], bound_args.kwargs
214231

215232
if instance.in_staging:
216-
return method(instance, *args)
233+
return method(instance, *args, **kwargs)
217234

218235
with instance.staging_context():
236+
219237
if instance.multi_device and instance.debug:
220238
# In this case we want to call `method` once for each device index.
221239
# Note that this might not always produce sensible behavior, and will
@@ -241,14 +259,16 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
241259
for j in range(len(args))
242260
]
243261

262+
kwargs_i = jax.tree_util.tree_map(operator.itemgetter(i), kwargs)
263+
244264
with jax.disable_jit():
245-
outs.append(method(instance, *args_i))
265+
outs.append(method(instance, *args_i, **kwargs_i))
246266

247267
outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs)
248268

249269
elif instance.debug:
250270
with jax.disable_jit():
251-
outs = method(instance, *args)
271+
outs = method(instance, *args, **kwargs)
252272

253273
elif instance.multi_device:
254274

@@ -274,10 +294,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
274294
)
275295
pmap_funcs[key] = func
276296

277-
outs = func(instance, *args)
297+
outs = func(instance, *args, **kwargs)
278298

279299
else:
280-
outs = jitted_func(instance, *args)
300+
outs = jitted_func(instance, *args, **kwargs)
281301

282302
return outs
283303

0 commit comments

Comments
 (0)