diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/args.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/args.py new file mode 100644 index 000000000..d52dc848c --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/args.py @@ -0,0 +1,39 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""P2P composite checkpoint argument.""" + +from typing import Any, final +from orbax.checkpoint import args as args_lib +from orbax.checkpoint.experimental.emergency.p2p import constants + + +@final +class Composite(args_lib.Composite): + """Composite argument that only supports 'state' key.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if constants.STATE_SUBDIR not in self or len(self) > 1: + raise ValueError( + f'Composite must contain "{constants.STATE_SUBDIR}" key and no other' + f' keys: {list(self.keys())}' + ) + + def __setitem__(self, key: str, value: Any): + if key != constants.STATE_SUBDIR: + raise KeyError( + f'Invalid key: {key}. Only "{constants.STATE_SUBDIR}" is supported.' + ) + self[key] = value diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py new file mode 100644 index 000000000..d0cc16cb6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py @@ -0,0 +1,142 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal checkpoint manager for local P2P storage logic.""" + +from typing import Any, Sequence, final + +from etils import epath +import jax +import orbax.checkpoint as ocp +from orbax.checkpoint import checkpoint_manager +from orbax.checkpoint import type_handlers +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.serialization import type_handler_registry +from orbax.checkpoint.experimental.emergency.checkpoint_manager import CheckpointManagerOptions +from orbax.checkpoint.experimental.emergency.p2p import utils +from orbax.checkpoint.experimental.emergency.p2p.args import Composite + + +@final +class LocalCheckpointManager: + """Wrapper around Orbax CheckpointManager for local P2P shards.""" + + def __init__( + self, + directory: epath.PathLike, + global_mesh: jax.sharding.Mesh, + *, + options: CheckpointManagerOptions, + ): + self._directory = epath.Path(directory) + self._global_mesh = global_mesh + self._process_index = multihost.process_index() + + barrier_sync_key_prefix = f'p2p_shard_{self._process_index}' + mp_options = ocp.options.MultiprocessingOptions( + primary_host=None, # Symmetric read/write + active_processes={self._process_index}, # Only I write to my shard + barrier_sync_key_prefix=barrier_sync_key_prefix, + ) + + p2p_specific_options = checkpoint_manager.CheckpointManagerOptions( + step_name_format=options.step_name_format, + save_interval_steps=options.local.save_interval_steps, + max_to_keep=options.local.max_to_keep, + should_save_fn=options.local.should_save_fn, + multiprocessing_options=mp_options, + create=False, + cleanup_tmp_directories=False, + enable_background_delete=True, + enable_per_process_directory_creation=True, + ) + + local_registry = type_handler_registry.create_type_handler_registry(( + jax.Array, + type_handlers.ArrayHandler( + primary_host=None, replica_id=None, use_replica_parallel=False + ), + )) + + handler = ocp.PyTreeCheckpointHandler( + use_ocdbt=True, + use_zarr3=True, + multiprocessing_options=mp_options, + type_handler_registry=local_registry, + ) + + self._manager = checkpoint_manager.CheckpointManager( + self._directory, + options=p2p_specific_options, + item_handlers=dict(state=handler), + ) + + @property + def directory(self) -> epath.Path: + return self._directory + + def scan_stored_steps(self) -> tuple[int | None, Sequence[int]]: + """Identifies available steps and the stored process index (from latest).""" + if not self._directory.exists(): + return None, [] + + steps = self._manager.all_steps() + if not steps: + return None, [] + + latest = steps[-1] + detected_index = utils.detect_process_index(self._directory, latest) + + if detected_index is None: + raise ValueError( + f'Failed to detect process index for step {latest} in' + f' {self._directory}. Checkpoint may be malformed.' + ) + + return detected_index, steps + + def save(self, step: int, args: Composite, *, force: bool = False) -> bool: + """Saves the checkpoint.""" + return self._manager.save(step, args=args, force=force) + + def restore( + self, + step: int, + *, + directory: epath.PathLike | None = None, + ) -> Composite: + """Restores the checkpoint, enforcing process identity check.""" + # No need to check for P2P restore directory + if directory is None: + # 1. Fast Fail: Verify Process Identity + stored_index = utils.detect_process_index(self._directory, step) + + if stored_index != self._process_index: + error_msg = ( + f'Process Mismatch: Local checkpoint at step {step} belongs to' + f' Process {stored_index}, but current process is' + f' {self._process_index}. Aborting local restore to trigger' + ' P2P/Persistent fallback.' + ) + raise ValueError(error_msg) + + # 2. Delegate to Orbax + restored = self._manager.restore(step, directory=directory) + return Composite(**restored) + + def __getattr__(self, name: str) -> Any: + return getattr(self._manager, name) + + def close(self): + self._manager.close() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_test.py new file mode 100644 index 000000000..299a6c8a8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_test.py @@ -0,0 +1,112 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import args as args_lib +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint.experimental.emergency import checkpoint_manager as emergency_checkpoint_manager +from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib +from orbax.checkpoint.experimental.emergency.p2p import local + +Mesh = jax.sharding.Mesh +P = jax.sharding.PartitionSpec + + +class LocalCheckpointManagerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.enter_context( + mock.patch.object(multihost, 'get_jax_distributed_client') + ) + if not multihost.is_runtime_to_distributed_ids_initialized(): + multihost.initialize_runtime_to_distributed_ids() + self.directory = epath.Path(self.create_tempdir().full_path) + self.mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) + self.options = emergency_checkpoint_manager.CheckpointManagerOptions() + + @mock.patch( + 'orbax.checkpoint._src.multihost.multihost.process_index', return_value=0 + ) + def test_init(self, unused_process_index): + manager = local.LocalCheckpointManager( + self.directory, self.mesh, options=self.options + ) + self.assertEqual(manager.directory, self.directory) + self.assertEqual(manager._process_index, 0) + self.assertIsNotNone(manager._manager) + manager.close() + + @mock.patch( + 'orbax.checkpoint._src.multihost.multihost.process_index', return_value=0 + ) + def test_scan_stored_steps_empty(self, unused_process_index): + manager = local.LocalCheckpointManager( + self.directory, self.mesh, options=self.options + ) + detected_index, steps = manager.scan_stored_steps() + self.assertIsNone(detected_index) + self.assertEmpty(steps) + manager.close() + + @mock.patch( + 'orbax.checkpoint._src.multihost.multihost.process_index', return_value=0 + ) + def test_restore_process_mismatch_raises_error(self, unused_process_index): + manager = local.LocalCheckpointManager( + self.directory, self.mesh, options=self.options + ) + step_dir = self.directory / '1' + step_dir.mkdir() + (step_dir / 'state' / 'ocdbt.process_1').mkdir( + parents=True + ) # Stored by process 1 + + with self.assertRaisesRegex(ValueError, 'Process Mismatch'): + manager.restore(1) + manager.close() + + @mock.patch( + 'orbax.checkpoint._src.multihost.multihost.process_index', return_value=0 + ) + def test_save_restore(self, unused_process_index): + manager = local.LocalCheckpointManager( + self.directory, self.mesh, options=self.options + ) + sharding = jax.sharding.NamedSharding(self.mesh, P('x')) + arr = jax.device_put(np.arange(self.mesh.size, dtype=np.int32), sharding) + state = { + 'a': arr, + 'b': jax.device_put( + np.arange(self.mesh.size, dtype=np.int32), sharding + ), + } + manager.save( + 1, args=p2p_args_lib.Composite(state=args_lib.PyTreeSave(state)) + ) + manager.wait_until_finished() + + restored = manager.restore(1) + + jax.tree_util.tree_map(np.testing.assert_array_equal, state, restored.state) + manager.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py new file mode 100644 index 000000000..dd25a16c6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py @@ -0,0 +1,174 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Handles persistent storage logic (GCS/S3) for P2P checkpointing.""" + +from typing import Any, final + +from absl import logging +from etils import epath +import jax +import orbax.checkpoint as ocp +from orbax.checkpoint import args as args_lib +from orbax.checkpoint import checkpoint_manager +from orbax.checkpoint import checkpoint_utils +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint._src.serialization import type_handler_registry +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint.experimental.emergency.checkpoint_manager import CheckpointManagerOptions +from orbax.checkpoint.experimental.emergency.p2p.args import Composite + +_PRIMARY_REPLICA_ID = 0 +PyTree = Any + + +def _create_persistent_handler( + mp_options: checkpoint_manager.MultiprocessingOptions, +) -> ocp.PyTreeCheckpointHandler: + """Creates a PyTreeCheckpointHandler for persistent storage. + + Args: + mp_options: Multiprocessing options for the checkpoint handler. + + Returns: + A PyTreeCheckpointHandler configured for persistent storage. + """ + registry = type_handler_registry.create_type_handler_registry(( + jax.Array, + type_handlers.ArrayHandler( + primary_host=mp_options.primary_host, + replica_id=_PRIMARY_REPLICA_ID, + use_replica_parallel=False, + ), + )) + return ocp.PyTreeCheckpointHandler( + use_ocdbt=True, + use_zarr3=True, + multiprocessing_options=mp_options, + type_handler_registry=registry, + ) + + +@final +class PersistentCheckpointManager: + """Manages saving/restoring from slow persistent storage.""" + + def __init__( + self, + directory: epath.PathLike, + global_mesh: jax.sharding.Mesh, + *, + replica_axis_index: int, + options: CheckpointManagerOptions, + ): + self._directory = epath.Path(directory) + self._global_mesh = global_mesh + self._replica_axis_index = replica_axis_index + self._process_index = multihost.process_index() + self._replica_id = multislice.process_replica_id( + self._process_index, + self._global_mesh, + replica_axis_index=self._replica_axis_index, + ) + self._in_primary_slice = multislice.in_replica( + self._process_index, + global_mesh, + replica_axis_index=self._replica_axis_index, + replica_id=_PRIMARY_REPLICA_ID, + ) + + replica_devices = multislice.replica_devices( + self._global_mesh, + replica_axis_index=self._replica_axis_index, + replica_id=self._replica_id, + ) + primary_host = multislice.primary_process_in_replica( + self._global_mesh, + replica_axis_index=self._replica_axis_index, + replica_id=self._replica_id, + ) + active_processes = multihost.unique_processes_from_devices( + replica_devices + ) + mp_options = checkpoint_manager.MultiprocessingOptions( + primary_host=primary_host, + active_processes=active_processes, + barrier_sync_key_prefix=f'persistent_fallback_{self._replica_id}', + ) + + internal_options = checkpoint_manager.CheckpointManagerOptions( + create=False, + multiprocessing_options=mp_options, + step_name_format=options.step_name_format, + save_interval_steps=options.persistent.save_interval_steps, + max_to_keep=options.persistent.max_to_keep, + enable_async_checkpointing=True, + ) + + self._manager = checkpoint_manager.CheckpointManager( + self._directory, + options=internal_options, + item_handlers=dict(state=_create_persistent_handler(mp_options)), + ) + + @property + def directory(self) -> epath.Path: + return self._directory + + def save(self, step: int, args: Composite, force: bool = False) -> bool: + if self._in_primary_slice: + return self._manager.save(step, args=args, force=force) + return True + + def restore(self, step: int, args: Composite) -> Composite: + """Restores a checkpoint from persistent storage. + + Args: + step: The step number to restore. + args: A Composite object containing the abstract state to restore. + + Returns: + The restored state as a Composite object. + """ + assert self._manager is not None + logging.info( + 'Restoring step %s from persistent storage on slice %d...', + step, + self._replica_id, + ) + abstract_state = args.state + + sharding_tree = jax.tree.map(lambda x: x.sharding, abstract_state) + # TODO(exlin): Enable SingleReplicaRestore. + restore_args_obj = args_lib.PyTreeRestore( + item=abstract_state, + restore_args=checkpoint_utils.construct_restore_args( + abstract_state, sharding_tree + ), + ) + return self._manager.restore(step, args=Composite(state=restore_args_obj)) + + def delete(self, step: int): + if self._in_primary_slice: + self._manager.delete(step) + + def wait_until_finished(self): + self._manager.wait_until_finished() + + def check_for_errors(self): + self._manager.check_for_errors() + + def close(self): + self._manager.close() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py new file mode 100644 index 000000000..f2697c346 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py @@ -0,0 +1,216 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import args as args_lib +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint.experimental.emergency import checkpoint_manager as emergency_checkpoint_manager +from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib +from orbax.checkpoint.experimental.emergency.p2p import persistent + +Mesh = jax.sharding.Mesh + + +class MockDevice: + + def __init__(self, process_index, slice_index): + self.process_index = process_index + self.slice_index = slice_index + self.client = mock.Mock() + self.client.process_index.return_value = process_index + + def __repr__(self): + return f'MockDevice(pi={self.process_index}, si={self.slice_index})' + + +class PersistentCheckpointManagerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.enter_context( + mock.patch.object(multihost, 'get_jax_distributed_client') + ) + self.enter_context(mock.patch.object(jax, 'process_count', return_value=4)) + if not multihost.is_runtime_to_distributed_ids_initialized(): + multihost.initialize_runtime_to_distributed_ids() + if not multihost.is_distributed_to_device_ids_initialized(): + multihost.initialize_distributed_to_device_ids() + self.directory = epath.Path(self.create_tempdir().full_path) + devices = np.array([ + [MockDevice(0, 0), MockDevice(1, 0)], + [MockDevice(2, 1), MockDevice(3, 1)], + ]) + self.mesh = mock.Mock( + spec=jax.sharding.Mesh, + devices=devices, + axis_names=('replica', 'data'), + shape={'replica': 2, 'data': 2}, + shape_tuple=devices.shape, + size=devices.size, + ) + self.options = emergency_checkpoint_manager.CheckpointManagerOptions() + + def _patch_process_index(self, process_index): + self.enter_context( + mock.patch( + 'orbax.checkpoint._src.multihost.multihost.process_index', + return_value=process_index, + ) + ) + + def test_init_in_primary_slice(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + self.assertTrue(manager._in_primary_slice) + self.assertIsNotNone(manager._manager) + manager.close() + + def test_init_not_in_primary_slice(self): + self._patch_process_index(process_index=2) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + self.assertFalse(manager._in_primary_slice) + self.assertIsNotNone(manager._manager) + manager.close() + + def test_save_in_primary_slice_saves(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + args = p2p_args_lib.Composite( + state=args_lib.PyTreeSave({'a': jax.device_put(1)}) + ) + manager.save(1, args) + manager._manager.save.assert_called_once() + manager.close() + + def test_save_not_in_primary_slice_does_not_save(self): + self._patch_process_index(process_index=2) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + args = p2p_args_lib.Composite( + state=args_lib.PyTreeSave({'a': jax.device_put(1)}) + ) + manager.save(1, args) + manager._manager.save.assert_not_called() + manager.close() + + def test_save_and_restore(self): + self._patch_process_index(process_index=0) + # persistent checkpoint manager with multiprocessing only works with a + # unified storage. + self.enter_context(mock.patch.object(jax, 'process_count', return_value=1)) + devices = np.array([ + [MockDevice(0, 0)], + ]) + mesh = mock.Mock( + spec=jax.sharding.Mesh, + devices=devices, + axis_names=('replica', 'data'), + shape={'replica': 1, 'data': 1}, + shape_tuple=devices.shape, + size=devices.size, + ) + manager = persistent.PersistentCheckpointManager( + self.directory, mesh, replica_axis_index=0, options=self.options + ) + + arr = jax.device_put(np.arange(self.mesh.size, dtype=np.int32)) + state = {'a': arr, 'b': jax.device_put(1)} + args = p2p_args_lib.Composite(state=args_lib.PyTreeSave(state)) + + self.assertTrue(manager.save(1, args)) + manager.wait_until_finished() + + self.assertFalse((self.directory / '1' / 'default').exists()) + self.assertTrue((self.directory / '1' / 'state').exists()) + + def _to_abstract(x): + if isinstance(x, jax.Array): + return jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding) + return x + + abstract_state = jax.tree.map(_to_abstract, state) + restored = manager.restore( + 1, args=p2p_args_lib.Composite(state=abstract_state) + ) + restored_state = restored.state + test_utils.assert_tree_equal(self, state, restored_state) + manager.close() + + def test_delete_in_primary_slice_deletes(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + manager.delete(1) + manager._manager.delete.assert_called_once_with(1) + manager.close() + + def test_delete_not_in_primary_slice_does_not_delete(self): + self._patch_process_index(process_index=2) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + manager.delete(1) + manager._manager.delete.assert_not_called() + manager.close() + + def test_wait_until_finished_calls_manager(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + manager.wait_until_finished() + manager._manager.wait_until_finished.assert_called_once() + manager.close() + + def test_check_for_errors_calls_manager(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + manager.check_for_errors() + manager._manager.check_for_errors.assert_called_once() + manager.close() + + def test_close_calls_manager(self): + self._patch_process_index(process_index=0) + manager = persistent.PersistentCheckpointManager( + self.directory, self.mesh, replica_axis_index=0, options=self.options + ) + manager._manager = mock.MagicMock() + manager.close() + manager._manager.close.assert_called_once() + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py index cc5a15cd6..a980e3a05 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py @@ -27,6 +27,7 @@ from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint.experimental.emergency.p2p import constants from orbax.checkpoint.experimental.emergency.p2p import protocol +from orbax.checkpoint.experimental.emergency.p2p import utils class _ThreadingTCPServer(socketserver.ThreadingTCPServer): @@ -119,16 +120,6 @@ def stop(self): self._thread.join(timeout=2.0) self._thread = None - def _get_stored_process_index(self, step_path: epath.Path) -> int | None: - """Returns the process index of the shard stored in the given step path.""" - item_path = step_path / constants.STATE_SUBDIR - if item_path.exists(): - for path in item_path.glob(f'{constants.PROCESS_SUBDIR_PREFIX}*'): - if path.is_dir(): - # Format: ocdbt.process_0, ocdbt.process_12, etc. - return int(path.name.split('_')[-1]) - return None - def handle_get_manifest( self, payload: dict[str, Any] ) -> list[dict[str, Any]]: @@ -149,7 +140,7 @@ def handle_get_manifest( if not step_dir.exists(): return [] - stored_process_index = self._get_stored_process_index(step_dir) + stored_process_index = utils.detect_process_index(self.directory, step) # If process_index is specified, only return manifest if it matches. if req_process_index != stored_process_index: diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils.py new file mode 100644 index 000000000..de91f929c --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils.py @@ -0,0 +1,40 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for P2P checkpointing.""" + +from absl import logging +from etils import epath +from orbax.checkpoint.experimental.emergency.p2p import constants + + +def detect_process_index(directory: epath.Path, step: int) -> int | None: + """Inspects the disk to find which process index created this step.""" + step_path = directory / str(step) + if not step_path.exists(): + return None + + # Check for standard Orbax/OCDBT structure + # P2P checkpoint requires 'state' item in CompositeArgs + try: + item_path = step_path / constants.STATE_SUBDIR + if item_path.exists(): + for path in item_path.glob(f'{constants.PROCESS_SUBDIR_PREFIX}*'): + if path.is_dir(): + # Format: ocdbt.process_0, ocdbt.process_12, etc. + return int(path.name.split('_')[-1]) + except (ValueError, IndexError, OSError) as e: + logging.warning('Could not detect process index for step %d: %s', step, e) + + return None diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils_test.py new file mode 100644 index 000000000..7fe46f793 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils_test.py @@ -0,0 +1,36 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from etils import epath +from orbax.checkpoint.experimental.emergency.p2p import utils + + +class UtilsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.directory = epath.Path(self.create_tempdir().full_path) + + def test_detect_process_index(self): + step_dir = self.directory / '1' + step_dir.mkdir() + (step_dir / 'state' / 'ocdbt.process_42').mkdir(parents=True) + + self.assertEqual(utils.detect_process_index(self.directory, 1), 42) + self.assertIsNone(utils.detect_process_index(self.directory, 2)) + + +if __name__ == '__main__': + absltest.main()