Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,33 @@
PLACEHOLDER = type_handlers.PLACEHOLDER


def deep_namedtuple_to_dict(obj):
  """Returns a copy of a PyTree with namedtuples as dicts and tuples as lists.

  The conversion is applied recursively through namedtuples, dicts, lists and
  tuples. Any other value is treated as a leaf and returned unchanged.

  Args:
    obj: The object (PyTree) to convert.

  Returns:
    A new structure in which every object exposing `_asdict` (i.e. a
    namedtuple) becomes a dict keyed by field name, every dict keeps its keys
    with converted values, and every list or tuple becomes a list of converted
    items. Leaves are returned as-is.
  """
  # The `_asdict` check must come before the plain-tuple check: namedtuples
  # are tuples too, and would otherwise lose their field names.
  if hasattr(obj, '_asdict'):
    return {
        field: deep_namedtuple_to_dict(value)
        for field, value in obj._asdict().items()
    }
  if isinstance(obj, dict):
    return {
        key: deep_namedtuple_to_dict(value) for key, value in obj.items()
    }
  # Lists and plain tuples both come back as lists.
  if isinstance(obj, (list, tuple)):
    return [deep_namedtuple_to_dict(item) for item in obj]
  # Leaf value: nothing to convert.
  return obj


class CheckpointerTestBase:
"""Common tests for AbstractCheckpointer subclasses."""

Expand Down Expand Up @@ -495,9 +522,12 @@ def test_save_restore_named_tuple(
directory = self.directory / 'rich_typed_metadata'
checkpointer.save(directory, save_input_provider())
self.wait_if_async(checkpointer)
# TODO: b/365169723 - Update this test when restore is ready.
with self.assertRaises(NotImplementedError):
_ = checkpointer.restore(directory)
restored = checkpointer.restore(directory)
test_utils.assert_tree_equal(
self,
deep_namedtuple_to_dict(restored),
expected_restored_provider(),
)

def test_save_step_metadata(self):
"""Basic save and restore test."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def _group_value(
if info.skip_deserialize:
return

if not isinstance(arg, (SaveArgs, RestoreArgs)):
if tree_utils.is_empty_node(arg):
return

if isinstance(arg, RestoreArgs):
assert isinstance(value, tree_metadata.ValueMetadataEntry), type(value)
metadata_restore_type = value.value_type
Expand Down Expand Up @@ -1044,6 +1048,19 @@ class TrainState:
restore_args = tree_metadata.serialize_tree(
restore_args, self._pytree_metadata_options
)
if item is not None:
try:
value_metadata_tree_deserialized = tree_utils.deserialize_tree(
value_metadata_tree, item
)
restore_args_deserialized = tree_utils.deserialize_tree(
restore_args, item
)
value_metadata_tree = value_metadata_tree_deserialized
restore_args = restore_args_deserialized
except ValueError:
pass

param_infos = self._get_param_infos(
item=value_metadata_tree,
directory=directory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts


PyTree = Any
TupleKey = Tuple[str, ...]
RestoreArgs = type_handlers.RestoreArgs
Expand Down Expand Up @@ -815,11 +814,6 @@ class TrainState:
ValueError: `transforms` is provided without `item`.
ValueError: `transforms` contains elements with `multi_value_fn`.
"""
if self._pytree_metadata_options.support_rich_types:
raise NotImplementedError(
'Restore is not supported for rich typed metadata yet. Please set'
' PyTreeMetadataOptions.support_rich_types=False.'
)
if not directory.exists():
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}.'
Expand Down Expand Up @@ -871,7 +865,23 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
try:
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
except FileNotFoundError:
if item is None:
raise
# If the checkpoint doesn't have a structure file, use the item as the
# structure. This is for backward compatibility with checkpoints that
# don't have a structure file.
structure = jax.tree.map(
lambda x: tree_metadata.ValueMetadataEntry(
value_type=empty_values.RESTORE_TYPE_UNKNOWN,
skip_deserialize=False,
),
item,
)
use_zarr3_metadata = None

# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.

Expand Down
21 changes: 15 additions & 6 deletions checkpoint/orbax/checkpoint/_src/tree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,23 @@ def deserialize_tree(
) -> PyTree:
"""Deserializes a PyTree to the same structure as `target`."""

def _reconstruct_from_keypath(keypath, _):
def _reconstruct_from_keypath(keypath, x):
del x
result = serialized
for key in keypath:
key_name = get_key_name(key)
# Special case to support Pax.
if not isinstance(result, list) and key_name not in result:
key_name = str(key_name)
result = result[key_name]
if isinstance(key, jax.tree_util.SequenceKey):
result = result[key.idx]
elif isinstance(key, jax.tree_util.DictKey):
result = result[key.key]
elif isinstance(key, jax.tree_util.GetAttrKey):
if isinstance_of_namedtuple(result):
result = getattr(result, key.name)
else:
result = result[key.name]
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
result = result[key.key]
else:
raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"')
return result

return jax.tree_util.tree_map_with_path(
Expand Down
Loading