Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions export/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from orbax.export import obm_configs
from orbax.export import typing
from orbax.export import utils
from orbax.export.data_processors.tf_data_processor import TfDataProcessor
from orbax.export.dtensor_utils import dtensor_initialized
from orbax.export.dtensor_utils import get_current_dtensor_mesh
from orbax.export.dtensor_utils import get_current_mesh
Expand All @@ -29,6 +30,7 @@
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager
from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
Expand Down
2 changes: 2 additions & 0 deletions export/orbax/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from orbax.export import obm_configs
from orbax.export import typing
from orbax.export import utils
from orbax.export.data_processors.tf_data_processor import TfDataProcessor
from orbax.export.dtensor_utils import dtensor_initialized
from orbax.export.dtensor_utils import get_current_dtensor_mesh
from orbax.export.dtensor_utils import get_current_mesh
Expand All @@ -29,6 +30,7 @@
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager
from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ name_to_pipeline {
named_signature {
inputs {
named_tensor_types {
name: "inputs_model_output"
name: "model_output"
tensor_type {
shape {
shape_with_known_rank {
Expand Down
35 changes: 35 additions & 0 deletions model/orbax/experimental/model/jax2obm/jax_specific_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,38 @@ def unzip2(
return tuple(xs), tuple(ys)


def _name_leaf(path, leaf):
"""Assigns a name to a leaf node based on its PyTree path.

The name is a dot-separated string representation of the path keys.
This function is designed to be used with `jax.tree_util.tree_map_with_path`.
It modifies the `leaf` in-place by setting its `name` attribute.

Args:
path: A tuple of path elements (e.g., DictKey, SequenceKey) representing the
path to the leaf in a PyTree.
leaf: The leaf node to be named. It must have a `name` attribute that can be
set.

Returns:
The leaf node with its `name` attribute set to the path string.
"""
path_str_parts = []
for key in path:
if isinstance(key, jax.tree_util.DictKey):
path_str_parts.append(str(key.key))
elif isinstance(key, jax.tree_util.SequenceKey):
path_str_parts.append(str(key.idx))
elif isinstance(key, jax.tree_util.GetAttrKey):
path_str_parts.append(str(key.name))
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
path_str_parts.append(str(key.idx))
else:
raise TypeError(f"Unknown key type: {type(key)}")
leaf.name = ".".join(path_str_parts)
return leaf


def _serialize_effect(eff: jax.core.Effect) -> str:
"""Serializes a JAX Effect to a string.

Expand Down Expand Up @@ -311,6 +343,7 @@ def _to_shlo_spec_tree_and_refinement_tuple(
avals: Sequence[jax.core.AbstractValue],
shardings: Sequence[Any],
tree_def: Optional[jax.tree_util.PyTreeDef],
name_leaves: bool = False,
) -> Tuple[
obm.Tree[obm.ShloTensorSpec], Tuple[ShapeDTypeRefinementPair, ...] | None
]:
Expand All @@ -329,4 +362,6 @@ def assert_leaf(x: Any) -> None:
)
obm.tree_util.assert_tree(assert_leaf, jax_tree)
jax_tree: obm.Tree[obm.ShloTensorSpec]
if name_leaves:
jax_tree = jax.tree_util.tree_map_with_path(_name_leaf, jax_tree)
return jax_tree, refinements
1 change: 1 addition & 0 deletions model/orbax/experimental/model/jax2obm/main_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def jax_exported_to_shlo_fn(
exported.out_avals,
out_shardings_hlo,
exported.out_tree,
name_leaves=True,
)
)
supplemental_info_ = {}
Expand Down
12 changes: 3 additions & 9 deletions model/orbax/experimental/model/tf2obm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,11 @@ def tf_dtype_to_obm(t: tf.DType) -> obm.ShloDType:
return obm.np_dtype_to_shlo_dtype(np_dtype)


def tf_tensor_spec_to_obm(spec: Any) -> obm.ShloTensorSpec:
def tf_tensor_spec_to_obm(
spec: tf.TensorSpec | tf.Tensor,
) -> obm.ShloTensorSpec:
"""Converts a tf.TensorSpec or tf.SymbolicTensor to ShloTensorSpec."""

# ConcreteFunction.structured_outputs returns `SymbolicTensor`s, not
# `TensorSpec`s, so we need to also check for `SymbolicTensor`.
if not (isinstance(spec, tf.TensorSpec) or tf.is_symbolic_tensor(spec)):
raise ValueError(
f'Expected a tf.TensorSpec or a SymbolicTensor, got {spec} of type'
f' {type(spec)}'
)

if spec.shape.rank is None:
obm_shape = None
else:
Expand Down
1 change: 1 addition & 0 deletions model/orbax/experimental/model/voxel2obm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _voxel_to_obm_dtype(t) -> obm.ShloDType:
return obm.np_dtype_to_shlo_dtype(t)


# TODO: b/476448823 - Add name to the output ShloTensorSpec.
def voxel_signature_to_obm_spec(
signature: jd.VoxelSchemaTree,
) -> obm.Tree[obm.ShloTensorSpec]:
Expand Down
Loading