diff --git a/export/orbax/__init__.py b/export/orbax/__init__.py index 8f45ab4ac..c144925e8 100644 --- a/export/orbax/__init__.py +++ b/export/orbax/__init__.py @@ -20,6 +20,7 @@ from orbax.export import obm_configs from orbax.export import typing from orbax.export import utils +from orbax.export.data_processors.tf_data_processor import TfDataProcessor from orbax.export.dtensor_utils import dtensor_initialized from orbax.export.dtensor_utils import get_current_dtensor_mesh from orbax.export.dtensor_utils import get_current_mesh @@ -29,6 +30,7 @@ from orbax.export.dtensor_utils import shutdown_dtensor from orbax.export.export_manager import ExportManager from orbax.export.jax_module import JaxModule +from orbax.export.obm_configs import Jax2ObmOptions from orbax.export.serving_config import ServingConfig # TODO(dinghua): remove them after we change all references to # utils.remove_signature_defaults. diff --git a/export/orbax/export/__init__.py b/export/orbax/export/__init__.py index 8f45ab4ac..c144925e8 100644 --- a/export/orbax/export/__init__.py +++ b/export/orbax/export/__init__.py @@ -20,6 +20,7 @@ from orbax.export import obm_configs from orbax.export import typing from orbax.export import utils +from orbax.export.data_processors.tf_data_processor import TfDataProcessor from orbax.export.dtensor_utils import dtensor_initialized from orbax.export.dtensor_utils import get_current_dtensor_mesh from orbax.export.dtensor_utils import get_current_mesh @@ -29,6 +30,7 @@ from orbax.export.dtensor_utils import shutdown_dtensor from orbax.export.export_manager import ExportManager from orbax.export.jax_module import JaxModule +from orbax.export.obm_configs import Jax2ObmOptions from orbax.export.serving_config import ServingConfig # TODO(dinghua): remove them after we change all references to # utils.remove_signature_defaults. diff --git a/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto b/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto index fed6fe910..f49c506f6 100644 --- a/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto +++ b/export/orbax/export/testdata/expected_complex_graph_oex_orchestration_pipelines.textproto @@ -158,7 +158,7 @@ name_to_pipeline { named_signature { inputs { named_tensor_types { - name: "inputs_model_output" + name: "model_output" tensor_type { shape { shape_with_known_rank { diff --git a/export/orbax/export/testdata/expected_jax_only_oex_orchestration_pipelines.textproto b/export/orbax/export/testdata/expected_jax_only_oex_orchestration_pipelines.textproto new file mode 100644 index 000000000..cba16942e --- /dev/null +++ b/export/orbax/export/testdata/expected_jax_only_oex_orchestration_pipelines.textproto @@ -0,0 +1,169 @@ +# proto-file: orbax/export/protos/oex_orchestration.proto +# proto-message: Pipelines + +name_to_pipeline { + key: "__SIGNATURE_KEY__" + value { + signature { + input { + tuple { + elements { + tuple { + elements { + leaf { + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + } + } + } + elements { + dict { + } + } + } + } + output { + dict { + string_to_type { + key: "a" + value { + leaf { + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + } + } + string_to_type { + key: "b" + value { + dict { + string_to_type { + key: "c" + value { + leaf { + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + } + } + } + } + } + } + } + } + named_signature { + outputs { + named_tensor_types { + name: "a" + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + named_tensor_types { + name: "b.c" + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + } + } + model_functions { + model_function_name: "__MODEL_FUNCTION_NAME__" + } + components { + function_name: "__MODEL_FUNCTION_NAME__" + role: ROLE_MODEL + named_signature { + outputs { + named_tensor_types { + name: "a" + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + named_tensor_types { + name: "b.c" + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + size: 2 + } + dimension_sizes { + size: 2 + } + } + } + dtype: f32 + } + } + } + } + } + } +} diff --git a/model/orbax/experimental/model/jax2obm/jax_specific_info.py b/model/orbax/experimental/model/jax2obm/jax_specific_info.py index bcfdf0a2d..e46402217 100644 --- a/model/orbax/experimental/model/jax2obm/jax_specific_info.py +++ b/model/orbax/experimental/model/jax2obm/jax_specific_info.py @@ -16,7 +16,7 @@ # pylint: disable=g-importing-member import dataclasses -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar +from typing import Any, Iterable, List, Optional, Sequence, Tuple, TypeVar import jax # Somehow JAX requires this import to make `jax.export` available. @@ -30,7 +30,6 @@ from orbax.experimental.model.jax2obm.jax_supplemental_pb2 import ShapeDTypeRefinements from orbax.experimental.model.jax2obm.jax_supplemental_pb2 import ShapeRefinement - T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -47,6 +46,48 @@ def unzip2( return tuple(xs), tuple(ys) +def _name_leaf( + path: jax.tree_util.KeyPath, leaf: obm.ShloTensorSpec +) -> obm.ShloTensorSpec: + """Assigns a name to a leaf node in-place based on its PyTree path. + + The name is a dot-separated string representation of the path keys, extracted + from dictionary keys, sequence indices, or attribute names. This function is + designed to be used as the transform function in + `jax.tree_util.tree_map_with_path`. + + The `leaf` argument's `name` attribute is modified in-place. + + Args: + path: A tuple of path elements (e.g., DictKey, SequenceKey, GetAttrKey) + provided by `tree_map_with_path`, representing the path to the leaf in a + PyTree. + leaf: The leaf node to be named. It must have a `name` attribute that can be + set. + + Returns: + The leaf node with its `name` attribute set to the dot-separated path + string. + + Raises: + TypeError: If an unknown key type is encountered in the path. + """ + path_str_parts = [] + for key in path: + if isinstance(key, jax.tree_util.DictKey): + path_str_parts.append(str(key.key)) + elif isinstance(key, jax.tree_util.SequenceKey): + path_str_parts.append(str(key.idx)) + elif isinstance(key, jax.tree_util.GetAttrKey): + path_str_parts.append(str(key.name)) + elif isinstance(key, jax.tree_util.FlattenedIndexKey): + path_str_parts.append(str(key.key)) + else: + raise TypeError(f"Unknown key type: {type(key)}") + leaf.name = ".".join(path_str_parts) + return leaf + + def _serialize_effect(eff: jax.core.Effect) -> str: """Serializes a JAX Effect to a string. @@ -311,6 +352,7 @@ def _to_shlo_spec_tree_and_refinement_tuple( avals: Sequence[jax.core.AbstractValue], shardings: Sequence[Any], tree_def: Optional[jax.tree_util.PyTreeDef], + name_leaves: bool = False, ) -> Tuple[ obm.Tree[obm.ShloTensorSpec], Tuple[ShapeDTypeRefinementPair, ...] | None ]: @@ -329,4 +371,6 @@ def assert_leaf(x: Any) -> None: ) obm.tree_util.assert_tree(assert_leaf, jax_tree) jax_tree: obm.Tree[obm.ShloTensorSpec] + if name_leaves: + jax_tree = jax.tree_util.tree_map_with_path(_name_leaf, jax_tree) return jax_tree, refinements diff --git a/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py b/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py index 908c32680..b4a818715 100644 --- a/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py +++ b/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp @@ -21,7 +22,25 @@ from orbax.experimental.model.jax2obm import jax_supplemental_pb2 from tensorflow.python.util.protobuf import compare from google.protobuf import text_format -from absl.testing import absltest + + +def _get_spec(): + """Returns a dummy spec, creates a new instance each time.""" + return obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.f32) + + +class _CustomNode: + + def __init__(self, x, y): + self.x = x + self.y = y + + +jax.tree_util.register_pytree_node( + _CustomNode, + lambda node: ((node.x, node.y), None), + lambda _, children: _CustomNode(*children), +) class JaxSpecificInfoTest(parameterized.TestCase): @@ -170,6 +189,50 @@ def test_to_shlo_dtype_and_refinement_wrong_type(self, jax_dtype): ): jax_specific_info._to_shlo_dtype_and_refinement(jax_dtype) + @parameterized.named_parameters( + dict( + testcase_name='dict_and_list', + tree={'a': [_get_spec(), _get_spec()], 'b': {'c': _get_spec()}}, + expected_names=['a.0', 'a.1', 'b.c'], + ), + dict( + testcase_name='tuple_and_nested_dict', + tree=(_get_spec(), {'x': _get_spec()}), + expected_names=['0', '1.x'], + ), + dict( + testcase_name='list_of_tuples', + tree=[(_get_spec(),), (_get_spec(), _get_spec())], + expected_names=['0.0', '1.0', '1.1'], + ), + dict( + testcase_name='custom_node', + tree={'node': _CustomNode(_get_spec(), _get_spec())}, + expected_names=['node.0', 'node.1'], + ), + dict( + testcase_name='single_spec', + tree=_get_spec(), + expected_names=[''], + ), + dict( + testcase_name='list_with_one_spec', + tree=[_get_spec()], + expected_names=['0'], + ), + ) + def test_name_leaf(self, tree, expected_names): + def _name_leaf_wrapper(tree): + return jax.tree_util.tree_map_with_path( + jax_specific_info._name_leaf, tree + ) + + named_tree = _name_leaf_wrapper(tree) + leaves, treedef = jax.tree_util.tree_flatten(named_tree) + self.assertEqual(treedef, jax.tree_util.tree_structure(tree)) + self.assertLen(leaves, len(expected_names)) + self.assertEqual([leaf.name for leaf in leaves], expected_names) + if __name__ == '__main__': absltest.main() diff --git a/model/orbax/experimental/model/jax2obm/main_lib.py b/model/orbax/experimental/model/jax2obm/main_lib.py index f16a54ab3..324ac2a09 100644 --- a/model/orbax/experimental/model/jax2obm/main_lib.py +++ b/model/orbax/experimental/model/jax2obm/main_lib.py @@ -53,6 +53,7 @@ def jax_exported_to_shlo_fn( sharding.hlo_sharding_to_op_sharding(sd) for sd in exported.out_shardings_hlo ]) + # TODO: b/476448823 - properly get the name for the input signature. shlo_in_sig, jax_in_sig_refinements = ( jax_specific_info._to_shlo_spec_tree_and_refinement_tuple( exported.in_avals, @@ -60,11 +61,22 @@ def jax_exported_to_shlo_fn( exported.in_tree, ) ) + # Since jax.ShapeDtypeStruct does not have a name field, we assign + # names to output tensors specs when converting them to ShloTensorSpec by + # passing `name_leaves=True`. This ensures that the JAX model + # produces a NamedSignature for its output (e.g., {'results': + # ShloTensorSpec(...)}), allowing downstream components (e.g., TF data + # processors) to reference outputs by name in keyword-based pipelines. + # This will prevent signature mismatches that could otherwise occur, e.g., the + # JAX model has the output signature like "model_output: ShloTensorSpec(...)", + # while the following tf data processor has the input signature like + # "input_model_output: ShloTensorSpec(...)". shlo_out_sig, jax_out_sig_refinements = ( jax_specific_info._to_shlo_spec_tree_and_refinement_tuple( exported.out_avals, out_shardings_hlo, exported.out_tree, + name_leaves=True, ) ) supplemental_info_ = {} diff --git a/model/orbax/experimental/model/voxel2obm/utils.py b/model/orbax/experimental/model/voxel2obm/utils.py index 20a4b8df3..3c3604634 100644 --- a/model/orbax/experimental/model/voxel2obm/utils.py +++ b/model/orbax/experimental/model/voxel2obm/utils.py @@ -50,6 +50,7 @@ def _voxel_to_obm_dtype(t) -> obm.ShloDType: return obm.np_dtype_to_shlo_dtype(t) +# TODO: b/476448823 - Add name to the output ShloTensorSpec. def voxel_signature_to_obm_spec( signature: jd.VoxelSchemaTree, ) -> obm.Tree[obm.ShloTensorSpec]: