diff --git a/export/orbax/__init__.py b/export/orbax/__init__.py index 8f45ab4ac..c144925e8 100644 --- a/export/orbax/__init__.py +++ b/export/orbax/__init__.py @@ -20,6 +20,7 @@ from orbax.export import obm_configs from orbax.export import typing from orbax.export import utils +from orbax.export.data_processors.tf_data_processor import TfDataProcessor from orbax.export.dtensor_utils import dtensor_initialized from orbax.export.dtensor_utils import get_current_dtensor_mesh from orbax.export.dtensor_utils import get_current_mesh @@ -29,6 +30,7 @@ from orbax.export.dtensor_utils import shutdown_dtensor from orbax.export.export_manager import ExportManager from orbax.export.jax_module import JaxModule +from orbax.export.obm_configs import Jax2ObmOptions from orbax.export.serving_config import ServingConfig # TODO(dinghua): remove them after we change all references to # utils.remove_signature_defaults. diff --git a/export/orbax/export/__init__.py b/export/orbax/export/__init__.py index 8f45ab4ac..c144925e8 100644 --- a/export/orbax/export/__init__.py +++ b/export/orbax/export/__init__.py @@ -20,6 +20,7 @@ from orbax.export import obm_configs from orbax.export import typing from orbax.export import utils +from orbax.export.data_processors.tf_data_processor import TfDataProcessor from orbax.export.dtensor_utils import dtensor_initialized from orbax.export.dtensor_utils import get_current_dtensor_mesh from orbax.export.dtensor_utils import get_current_mesh @@ -29,6 +30,7 @@ from orbax.export.dtensor_utils import shutdown_dtensor from orbax.export.export_manager import ExportManager from orbax.export.jax_module import JaxModule +from orbax.export.obm_configs import Jax2ObmOptions from orbax.export.serving_config import ServingConfig # TODO(dinghua): remove them after we change all references to # utils.remove_signature_defaults. diff --git a/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto b/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto index fed6fe910..f49c506f6 100644 --- a/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto +++ b/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto @@ -158,7 +158,7 @@ name_to_pipeline { named_signature { inputs { named_tensor_types { - name: "inputs_model_output" + name: "model_output" tensor_type { shape { shape_with_known_rank { diff --git a/model/orbax/experimental/model/jax2obm/jax_specific_info.py b/model/orbax/experimental/model/jax2obm/jax_specific_info.py index bcfdf0a2d..0b16cc7ff 100644 --- a/model/orbax/experimental/model/jax2obm/jax_specific_info.py +++ b/model/orbax/experimental/model/jax2obm/jax_specific_info.py @@ -47,6 +47,38 @@ def unzip2( return tuple(xs), tuple(ys) +def _name_leaf(path, leaf): + """Assigns a name to a leaf node based on its PyTree path. + + The name is a dot-separated string representation of the path keys. + This function is designed to be used with `jax.tree_util.tree_map_with_path`. + It modifies the `leaf` in-place by setting its `name` attribute. + + Args: + path: A tuple of path elements (e.g., DictKey, SequenceKey) representing the + path to the leaf in a PyTree. + leaf: The leaf node to be named. It must have a `name` attribute that can be + set. + + Returns: + The leaf node with its `name` attribute set to the path string. + """ + path_str_parts = [] + for key in path: + if isinstance(key, jax.tree_util.DictKey): + path_str_parts.append(str(key.key)) + elif isinstance(key, jax.tree_util.SequenceKey): + path_str_parts.append(str(key.idx)) + elif isinstance(key, jax.tree_util.GetAttrKey): + path_str_parts.append(str(key.name)) + elif isinstance(key, jax.tree_util.FlattenedIndexKey): + path_str_parts.append(str(key.idx)) + else: + raise TypeError(f"Unknown key type: {type(key)}") + leaf.name = ".".join(path_str_parts) + return leaf + + def _serialize_effect(eff: jax.core.Effect) -> str: """Serializes a JAX Effect to a string. @@ -311,6 +343,7 @@ def _to_shlo_spec_tree_and_refinement_tuple( avals: Sequence[jax.core.AbstractValue], shardings: Sequence[Any], tree_def: Optional[jax.tree_util.PyTreeDef], + name_leaves: bool = False, ) -> Tuple[ obm.Tree[obm.ShloTensorSpec], Tuple[ShapeDTypeRefinementPair, ...] | None ]: @@ -329,4 +362,6 @@ def assert_leaf(x: Any) -> None: ) obm.tree_util.assert_tree(assert_leaf, jax_tree) jax_tree: obm.Tree[obm.ShloTensorSpec] + if name_leaves: + jax_tree = jax.tree_util.tree_map_with_path(_name_leaf, jax_tree) return jax_tree, refinements diff --git a/model/orbax/experimental/model/jax2obm/main_lib.py b/model/orbax/experimental/model/jax2obm/main_lib.py index f16a54ab3..330af8d44 100644 --- a/model/orbax/experimental/model/jax2obm/main_lib.py +++ b/model/orbax/experimental/model/jax2obm/main_lib.py @@ -65,6 +65,7 @@ def jax_exported_to_shlo_fn( exported.out_avals, out_shardings_hlo, exported.out_tree, + name_leaves=True, ) ) supplemental_info_ = {} diff --git a/model/orbax/experimental/model/tf2obm/utils.py b/model/orbax/experimental/model/tf2obm/utils.py index b77f535df..4b8274aaa 100644 --- a/model/orbax/experimental/model/tf2obm/utils.py +++ b/model/orbax/experimental/model/tf2obm/utils.py @@ -53,17 +53,11 @@ def tf_dtype_to_obm(t: tf.DType) -> obm.ShloDType: return obm.np_dtype_to_shlo_dtype(np_dtype) -def tf_tensor_spec_to_obm(spec: Any) -> obm.ShloTensorSpec: +def tf_tensor_spec_to_obm( + spec: tf.TensorSpec | tf.Tensor, +) -> obm.ShloTensorSpec: """Converts a tf.TensorSpec or tf.SymbolicTensor to ShloTensorSpec.""" - # ConcreteFunction.structured_outputs returns `SymbolicTensor`s, not - # `TensorSpec`s, so we need to also check for `SymbolicTensor`. - if not (isinstance(spec, tf.TensorSpec) or tf.is_symbolic_tensor(spec)): - raise ValueError( - f'Expected a tf.TensorSpec or a SymbolicTensor, got {spec} of type' - f' {type(spec)}' - ) - if spec.shape.rank is None: obm_shape = None else: diff --git a/model/orbax/experimental/model/voxel2obm/utils.py b/model/orbax/experimental/model/voxel2obm/utils.py index 20a4b8df3..3c3604634 100644 --- a/model/orbax/experimental/model/voxel2obm/utils.py +++ b/model/orbax/experimental/model/voxel2obm/utils.py @@ -50,6 +50,7 @@ def _voxel_to_obm_dtype(t) -> obm.ShloDType: return obm.np_dtype_to_shlo_dtype(t) +# TODO: b/476448823 - Add name to the output ShloTensorSpec. def voxel_signature_to_obm_spec( signature: jd.VoxelSchemaTree, ) -> obm.Tree[obm.ShloTensorSpec]: