Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions export/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from orbax.export import obm_configs
from orbax.export import typing
from orbax.export import utils
from orbax.export.data_processors.tf_data_processor import TfDataProcessor
from orbax.export.dtensor_utils import dtensor_initialized
from orbax.export.dtensor_utils import get_current_dtensor_mesh
from orbax.export.dtensor_utils import get_current_mesh
Expand All @@ -29,6 +30,7 @@
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager
from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
Expand Down
2 changes: 2 additions & 0 deletions export/orbax/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from orbax.export import obm_configs
from orbax.export import typing
from orbax.export import utils
from orbax.export.data_processors.tf_data_processor import TfDataProcessor
from orbax.export.dtensor_utils import dtensor_initialized
from orbax.export.dtensor_utils import get_current_dtensor_mesh
from orbax.export.dtensor_utils import get_current_mesh
Expand All @@ -29,6 +30,7 @@
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager
from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ name_to_pipeline {
named_signature {
inputs {
named_tensor_types {
name: "inputs_model_output"
name: "model_output"
tensor_type {
shape {
shape_with_known_rank {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# proto-file: orbax/export/protos/oex_orchestration.proto
# proto-message: Pipelines

name_to_pipeline {
key: "__SIGNATURE_KEY__"
value {
signature {
input {
tuple {
elements {
tuple {
elements {
leaf {
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
}
}
}
elements {
dict {
}
}
}
}
output {
dict {
string_to_type {
key: "a"
value {
leaf {
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
}
}
string_to_type {
key: "b"
value {
dict {
string_to_type {
key: "c"
value {
leaf {
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
}
}
}
}
}
}
}
}
named_signature {
outputs {
named_tensor_types {
name: "a"
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
named_tensor_types {
name: "b.c"
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
}
}
model_functions {
model_function_name: "__MODEL_FUNCTION_NAME__"
}
components {
function_name: "__MODEL_FUNCTION_NAME__"
role: ROLE_MODEL
named_signature {
outputs {
named_tensor_types {
name: "a"
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
named_tensor_types {
name: "b.c"
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
size: 2
}
dimension_sizes {
size: 2
}
}
}
dtype: f32
}
}
}
}
}
}
}
48 changes: 46 additions & 2 deletions model/orbax/experimental/model/jax2obm/jax_specific_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# pylint: disable=g-importing-member
import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Iterable, List, Optional, Sequence, Tuple, TypeVar

import jax
# Somehow JAX requires this import to make `jax.export` available.
Expand All @@ -30,7 +30,6 @@
from orbax.experimental.model.jax2obm.jax_supplemental_pb2 import ShapeDTypeRefinements
from orbax.experimental.model.jax2obm.jax_supplemental_pb2 import ShapeRefinement


T1 = TypeVar("T1")
T2 = TypeVar("T2")

Expand All @@ -47,6 +46,48 @@ def unzip2(
return tuple(xs), tuple(ys)


def _name_leaf(
path: jax.tree_util.KeyPath, leaf: obm.ShloTensorSpec
) -> obm.ShloTensorSpec:
"""Assigns a name to a leaf node in-place based on its PyTree path.

The name is a dot-separated string representation of the path keys, extracted
from dictionary keys, sequence indices, or attribute names. This function is
designed to be used as the transform function in
`jax.tree_util.tree_map_with_path`.

The `leaf` argument's `name` attribute is modified in-place.

Args:
path: A tuple of path elements (e.g., DictKey, SequenceKey, GetAttrKey)
provided by `tree_map_with_path`, representing the path to the leaf in a
PyTree.
leaf: The leaf node to be named. It must have a `name` attribute that can be
set.

Returns:
The leaf node with its `name` attribute set to the dot-separated path
string.

Raises:
TypeError: If an unknown key type is encountered in the path.
"""
path_str_parts = []
for key in path:
if isinstance(key, jax.tree_util.DictKey):
path_str_parts.append(str(key.key))
elif isinstance(key, jax.tree_util.SequenceKey):
path_str_parts.append(str(key.idx))
elif isinstance(key, jax.tree_util.GetAttrKey):
path_str_parts.append(str(key.name))
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
path_str_parts.append(str(key.key))
else:
raise TypeError(f"Unknown key type: {type(key)}")
leaf.name = ".".join(path_str_parts)
return leaf


def _serialize_effect(eff: jax.core.Effect) -> str:
"""Serializes a JAX Effect to a string.

Expand Down Expand Up @@ -311,6 +352,7 @@ def _to_shlo_spec_tree_and_refinement_tuple(
avals: Sequence[jax.core.AbstractValue],
shardings: Sequence[Any],
tree_def: Optional[jax.tree_util.PyTreeDef],
name_leaves: bool = False,
) -> Tuple[
obm.Tree[obm.ShloTensorSpec], Tuple[ShapeDTypeRefinementPair, ...] | None
]:
Expand All @@ -329,4 +371,6 @@ def assert_leaf(x: Any) -> None:
)
obm.tree_util.assert_tree(assert_leaf, jax_tree)
jax_tree: obm.Tree[obm.ShloTensorSpec]
if name_leaves:
jax_tree = jax.tree_util.tree_map_with_path(_name_leaf, jax_tree)
return jax_tree, refinements
65 changes: 64 additions & 1 deletion model/orbax/experimental/model/jax2obm/jax_specific_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
Expand All @@ -21,7 +22,25 @@
from orbax.experimental.model.jax2obm import jax_supplemental_pb2
from tensorflow.python.util.protobuf import compare
from google.protobuf import text_format
from absl.testing import absltest


def _get_spec():
"""Returns a dummy spec, creates a new instance each time."""
return obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.f32)


class _CustomNode:

def __init__(self, x, y):
self.x = x
self.y = y


jax.tree_util.register_pytree_node(
_CustomNode,
lambda node: ((node.x, node.y), None),
lambda _, children: _CustomNode(*children),
)


class JaxSpecificInfoTest(parameterized.TestCase):
Expand Down Expand Up @@ -170,6 +189,50 @@ def test_to_shlo_dtype_and_refinement_wrong_type(self, jax_dtype):
):
jax_specific_info._to_shlo_dtype_and_refinement(jax_dtype)

@parameterized.named_parameters(
dict(
testcase_name='dict_and_list',
tree={'a': [_get_spec(), _get_spec()], 'b': {'c': _get_spec()}},
expected_names=['a.0', 'a.1', 'b.c'],
),
dict(
testcase_name='tuple_and_nested_dict',
tree=(_get_spec(), {'x': _get_spec()}),
expected_names=['0', '1.x'],
),
dict(
testcase_name='list_of_tuples',
tree=[(_get_spec(),), (_get_spec(), _get_spec())],
expected_names=['0.0', '1.0', '1.1'],
),
dict(
testcase_name='custom_node',
tree={'node': _CustomNode(_get_spec(), _get_spec())},
expected_names=['node.0', 'node.1'],
),
dict(
testcase_name='single_spec',
tree=_get_spec(),
expected_names=[''],
),
dict(
testcase_name='list_with_one_spec',
tree=[_get_spec()],
expected_names=['0'],
),
)
def test_name_leaf(self, tree, expected_names):
def _name_leaf_wrapper(tree):
return jax.tree_util.tree_map_with_path(
jax_specific_info._name_leaf, tree
)

named_tree = _name_leaf_wrapper(tree)
leaves, treedef = jax.tree_util.tree_flatten(named_tree)
self.assertEqual(treedef, jax.tree_util.tree_structure(tree))
self.assertLen(leaves, len(expected_names))
self.assertEqual([leaf.name for leaf in leaves], expected_names)


if __name__ == '__main__':
absltest.main()
Loading
Loading