Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ def _grad_general(

if not graph:
if any(isinstance(x, DiffState) for x in jax.tree.leaves(argnums)):
raise ValueError('`argnums` cannot contain `DiffState` objects when `graph=False`')
raise ValueError(
'`argnums` cannot contain `DiffState` objects '
'when `graph=False`'
)

gradded_fn = transform(
TreeGradFn(f, has_aux),
Expand Down
23 changes: 23 additions & 0 deletions flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ def jit(
if was_bound:
_raise_bound_method_error('jit')

if not graph:
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)):
raise ValueError(
'`in_shardings` cannot contain `StateSharding` objects '
'when `graph=False`'
)
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)):
raise ValueError(
'`out_shardings` cannot contain `StateSharding` objects '
'when `graph=False`'
)

wrapped_cls = JitWrapped if graph else TreeJitWrapped
return wrapped_cls(
fun_unbound,
Expand Down Expand Up @@ -1243,6 +1255,17 @@ def f(m, x):
_raise_bound_method_error('shard_map')

if not graph:
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_specs)):
raise ValueError(
'`in_specs` cannot contain `StateSharding` objects '
'when `graph=False`'
)
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_specs)):
raise ValueError(
'`out_specs` cannot contain `StateSharding` objects '
'when `graph=False`'
)

tree_shard_map_fn = jax.shard_map(
TreeShardMapFn(f_unbound),
mesh=mesh,
Expand Down
11 changes: 11 additions & 0 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,17 @@ def vmap(
_raise_bound_method_error('vmap')

if not graph:
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
raise ValueError(
'`in_axes` cannot contain `StateAxes` objects '
'when `graph=False`'
)
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
raise ValueError(
'`out_axes` cannot contain `StateAxes` objects '
'when `graph=False`'
)

vmapped_fn = jax.vmap(
TreeVmapFn(f_unbound),
in_axes=in_axes,
Expand Down
Loading