From e6435f0a2b7442236cd7bad6f5f9a8fbe6dd21db Mon Sep 17 00:00:00 2001 From: Justin Pan Date: Thu, 15 Jan 2026 15:07:07 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 856835907 --- .../checkpointers/checkpointer_test_utils.py | 36 +++++++++++++++++-- .../base_pytree_checkpoint_handler.py | 17 +++++++++ .../handlers/pytree_checkpoint_handler.py | 24 +++++++++---- .../orbax/checkpoint/_src/tree/utils.py | 21 +++++++---- 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py index 296a92b75..1fd996f62 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py @@ -67,6 +67,33 @@ PLACEHOLDER = type_handlers.PLACEHOLDER +def deep_namedtuple_to_dict(obj): + """Recursively converts namedtuples and tuples within a PyTree to dicts and lists. + + Args: + obj: The object (PyTree) to convert. + + Returns: + A new object with namedtuples converted to dicts and tuples converted to + lists, recursively. Other types are preserved. + """ + if hasattr(obj, '_asdict'): # Check if it's a namedtuple + # Convert namedtuple to dict and recurse on its values + return {k: deep_namedtuple_to_dict(v) for k, v in obj._asdict().items()} + elif isinstance(obj, dict): + # Recurse on dictionary values + return {k: deep_namedtuple_to_dict(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recurse on list items + return [deep_namedtuple_to_dict(elem) for elem in obj] + elif isinstance(obj, tuple): + # Convert tuple to list and recurse on its items + return [deep_namedtuple_to_dict(elem) for elem in obj] + else: + # Base case: not a namedtuple, dict, list, or tuple + return obj + + class CheckpointerTestBase: """Common tests for AbstractCheckpointer subclasses.""" @@ -495,9 +522,12 @@ def test_save_restore_named_tuple( directory = self.directory / 'rich_typed_metadata' checkpointer.save(directory, save_input_provider()) self.wait_if_async(checkpointer) - # TODO: b/365169723 - Update this test when restore is ready. - with self.assertRaises(NotImplementedError): - _ = checkpointer.restore(directory) + restored = checkpointer.restore(directory) + test_utils.assert_tree_equal( + self, + deep_namedtuple_to_dict(restored), + expected_restored_provider(), + ) def test_save_step_metadata(self): """Basic save and restore test.""" diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index b0ffeed39..9e4a59726 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -209,6 +209,10 @@ def _group_value( if info.skip_deserialize: return + if not isinstance(arg, (SaveArgs, RestoreArgs)): + if tree_utils.is_empty_node(arg): + return + if isinstance(arg, RestoreArgs): assert isinstance(value, tree_metadata.ValueMetadataEntry), type(value) metadata_restore_type = value.value_type @@ -1044,6 +1048,19 @@ class TrainState: restore_args = tree_metadata.serialize_tree( restore_args, self._pytree_metadata_options ) + if item is not None: + try: + value_metadata_tree_deserialized = tree_utils.deserialize_tree( + value_metadata_tree, item + ) + restore_args_deserialized = tree_utils.deserialize_tree( + restore_args, item + ) + value_metadata_tree = value_metadata_tree_deserialized + restore_args = restore_args_deserialized + except ValueError: + pass + param_infos = self._get_param_infos( item=value_metadata_tree, directory=directory, diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index a5c4239dd..f4a38d6ae 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -53,7 +53,6 @@ from orbax.checkpoint._src.tree import utils as tree_utils import tensorstore as ts - PyTree = Any TupleKey = Tuple[str, ...] RestoreArgs = type_handlers.RestoreArgs @@ -815,11 +814,6 @@ class TrainState: ValueError: `transforms` is provided without `item`. ValueError: `transforms` contains elements with `multi_value_fn`. """ - if self._pytree_metadata_options.support_rich_types: - raise NotImplementedError( - 'Restore is not supported for rich typed metadata yet. Please set' - ' PyTreeMetadataOptions.support_rich_types=False.' - ) if not directory.exists(): raise FileNotFoundError( f'Requested directory for restore does not exist at {directory}.' @@ -871,7 +865,23 @@ class TrainState: raise FileNotFoundError( f'Requested directory for restore does not exist at {directory}' ) - structure, use_zarr3_metadata = self._get_internal_metadata(directory) + try: + structure, use_zarr3_metadata = self._get_internal_metadata(directory) + except FileNotFoundError: + if item is None: + raise + # If the checkpoint doesn't have a structure file, use the item as the + # structure. This is for backward compatibility with checkpoints that + # don't have a structure file. + structure = jax.tree.map( + lambda x: tree_metadata.ValueMetadataEntry( + value_type=empty_values.RESTORE_TYPE_UNKNOWN, + skip_deserialize=False, + ), + item, + ) + use_zarr3_metadata = None + # `checkpoint_restore_args` has a structure relative to the checkpoint, # while `restore_args` remains structured relative to the output. diff --git a/checkpoint/orbax/checkpoint/_src/tree/utils.py b/checkpoint/orbax/checkpoint/_src/tree/utils.py index f770cd30c..73ec5a168 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/utils.py +++ b/checkpoint/orbax/checkpoint/_src/tree/utils.py @@ -231,14 +231,23 @@ def deserialize_tree( ) -> PyTree: """Deserializes a PyTree to the same structure as `target`.""" - def _reconstruct_from_keypath(keypath, _): + def _reconstruct_from_keypath(keypath, x): + del x result = serialized for key in keypath: - key_name = get_key_name(key) - # Special case to support Pax. - if not isinstance(result, list) and key_name not in result: - key_name = str(key_name) - result = result[key_name] + if isinstance(key, jax.tree_util.SequenceKey): + result = result[key.idx] + elif isinstance(key, jax.tree_util.DictKey): + result = result[key.key] + elif isinstance(key, jax.tree_util.GetAttrKey): + if isinstance_of_namedtuple(result): + result = getattr(result, key.name) + else: + result = result[key.name] + elif isinstance(key, jax.tree_util.FlattenedIndexKey): + result = result[key.key] + else: + raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"') return result return jax.tree_util.tree_map_with_path(