diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py
index 727320f24..361da4b8a 100644
--- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py
+++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py
@@ -97,7 +97,7 @@ class PyTreeCheckpointBenchmark(benchmarks_core.BenchmarksGenerator):
   PyTreeCheckpointHandler with various configurations.
   """
 
-  def _clear_pytree(self, pytree: Any) -> Any:
+  def clear_pytree(self, pytree: Any) -> Any:
     """Clears the pytree to free up memory."""
     return jax.tree.map(
         lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
@@ -174,7 +174,7 @@ def test_fn(
       assert hasattr(checkpointer, "wait_until_finished")
       checkpointer.wait_until_finished()
 
-    context.pytree = self._clear_pytree(context.pytree)
+    context.pytree = self.clear_pytree(context.pytree)
 
     with metrics.measure("restore", metrics_to_measure):
       restored_pytree = checkpointer.restore(
@@ -185,7 +185,7 @@ def test_fn(
           ),
       )
 
-    self._clear_pytree(restored_pytree)
+    self.clear_pytree(restored_pytree)
     checkpointer.close()
 
     return benchmarks_core.TestResult(metrics=metrics)
diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark.py
new file mode 100644
index 000000000..96f51624b
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark.py
@@ -0,0 +1,165 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Benchmarks for V1 free functions."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+import dataclasses
+import pprint
+from typing import Any
+
+from absl import logging
+import jax
+from orbax.checkpoint import v1 as ocp
+from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
+from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
+
+
+def _metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
+  """Returns the list of metrics to measure."""
+  metrics = ["time", "rss", "io"]
+  if options.metric_tracemalloc_enabled:
+    metrics.append("tracemalloc")
+  if options.metric_tensorstore_enabled:
+    metrics.append("tensorstore")
+  return metrics
+
+
+# ==============================================================================
+# 1. Define the Options Dataclass for this specific benchmark
+# ==============================================================================
+@dataclasses.dataclass(frozen=True)
+class V1BenchmarkOptions(benchmarks_core.BenchmarkOptions):
+  """Configuration options for benchmarks targeting the V1 free functions.
+
+  Each attribute can be a single value or a list of values to create
+  a parameter sweep.
+
+  Attributes:
+    async_enabled: If True, saves with `ocp.save_pytree_async`; otherwise uses
+      the blocking `ocp.save_pytree`.
+    use_ocdbt: Whether to write arrays in the OCDBT format.
+    use_zarr3: Whether to use Zarr v3.
+    use_compression: Whether to compress array data.
+    save_concurrent_gb: Optional cap, in GiB, on concurrently written bytes.
+    restore_concurrent_gb: Optional cap, in GiB, on concurrently read bytes.
+    metric_tracemalloc_enabled: Whether to also record tracemalloc metrics.
+    metric_tensorstore_enabled: Whether to also record TensorStore metrics.
+    use_replica_parallel: Whether to parallelize saving across replicas.
+    enable_replica_parallel_separate_folder: Whether replica-parallel saving
+      writes to a separate folder; requires `use_replica_parallel` and
+      `use_ocdbt`.
+    enable_trace: Whether to capture JAX profiler traces around save and load.
+  """
+
+  async_enabled: bool | Sequence[bool] = True
+  use_ocdbt: bool | Sequence[bool] = True
+  use_zarr3: bool | Sequence[bool] = False
+  use_compression: bool | Sequence[bool] = True
+  save_concurrent_gb: int | None | Sequence[int | None] = None
+  restore_concurrent_gb: int | None | Sequence[int | None] = None
+  metric_tracemalloc_enabled: bool = False
+  metric_tensorstore_enabled: bool = False
+  use_replica_parallel: bool | Sequence[bool] = False
+  enable_replica_parallel_separate_folder: bool | Sequence[bool] = False
+  enable_trace: bool = False
+
+  def is_valid(self):
+    """Returns False for unsupported option combinations (post-expansion)."""
+    assert isinstance(self.use_replica_parallel, bool)
+    assert isinstance(self.enable_replica_parallel_separate_folder, bool)
+    if self.enable_replica_parallel_separate_folder and (
+        not self.use_replica_parallel or not self.use_ocdbt
+    ):
+      return False
+    return True
+
+  @property
+  def context(self) -> ocp.Context:
+    return ocp.Context(
+        array_options=ocp.options.ArrayOptions(
+            saving=ocp.options.ArrayOptions.Saving(
+                use_ocdbt=self.use_ocdbt,
+                use_zarr3=self.use_zarr3,
+                use_replica_parallel=self.use_replica_parallel,
+                use_compression=self.use_compression,
+                enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder,
+                concurrent_bytes=self.save_concurrent_gb * 1024**3
+                if self.save_concurrent_gb is not None
+                else None,
+            ),
+            loading=ocp.options.ArrayOptions.Loading(
+                concurrent_bytes=self.restore_concurrent_gb * 1024**3
+                if self.restore_concurrent_gb is not None
+                else None,
+            ),
+        ),
+    )
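+
+
+# Note: passing a list for any field above yields a parameter sweep over all
+# combinations, e.g. (values illustrative, taken from v1_benchmark_test.py):
+#   V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=[False, True])
+# expands to four configurations; `is_valid` filters out unsupported ones.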
+
+
+# ==============================================================================
+# 2. Implement the Benchmark Generator
+# ==============================================================================
+@benchmarks_core.benchmark_options(V1BenchmarkOptions)
+class V1Benchmark(benchmarks_core.BenchmarksGenerator):
+  """A concrete generator benchmarking the V1 free-function API.
+
+  This class provides the specific test logic for benchmarking
+  `ocp.save_pytree`, `ocp.save_pytree_async`, and `ocp.load_pytree` with
+  various configurations.
+  """
+
+  def _clear_pytree(self, pytree: Any) -> Any:
+    """Clears the pytree to free up memory."""
+    return jax.tree.map(
+        lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
+    )
+
+  def test_fn(
+      self, context: benchmarks_core.TestContext
+  ) -> benchmarks_core.TestResult:
+    """The core test logic for a single save/restore cycle.
+
+    This function is called for each combination of options generated by the
+    framework. It uses `context.options` to configure the V1 context
+    dynamically for each run.
+
+    Args:
+      context: The test context containing the pytree, path, and options.
+
+    Returns:
+      The test result containing the metrics.
+    """
+    metrics = metric_lib.Metrics()
+    pytree = context.pytree
+    abstract_pytree = jax.tree.map(ocp.arrays.to_shape_dtype_struct, pytree)
+    save_path = context.path / "ckpt"
+    options = context.options
+    assert isinstance(options, V1BenchmarkOptions)
+
+    logging.info("Benchmark options: %s", pprint.pformat(options))
+    metrics_to_measure = _metrics_to_measure(options)
+
+    with ocp.Context(context=options.context):
+      if options.enable_trace:
+        jax.profiler.start_trace(context.path / "trace_save")
+      if options.async_enabled:
+        with metrics.measure("save_blocking", metrics_to_measure):
+          f = ocp.save_pytree_async(save_path, pytree)
+        with metrics.measure("save_background", metrics_to_measure):
+          f.result()
+      else:
+        with metrics.measure("save_blocking", metrics_to_measure):
+          ocp.save_pytree(save_path, pytree)
+        # No background work in the synchronous path; measured anyway for
+        # parity with the async case.
+        with metrics.measure("save_background", metrics_to_measure):
+          pass
+      context.pytree = self._clear_pytree(context.pytree)
+      if options.enable_trace:
+        jax.profiler.stop_trace()
+
+      if options.enable_trace:
+        jax.profiler.start_trace(context.path / "trace_load")
+      with metrics.measure("load", metrics_to_measure):
+        restored_pytree = ocp.load_pytree(save_path, abstract_pytree)
+      self._clear_pytree(restored_pytree)
+      if options.enable_trace:
+        jax.profiler.stop_trace()
+
+    return benchmarks_core.TestResult(metrics=metrics)
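+
+
+# A minimal driving sketch, mirroring v1_benchmark_test.py (the empty config
+# spec is a placeholder):
+#
+#   generator = V1Benchmark(
+#       checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
+#       options=V1BenchmarkOptions(use_ocdbt=[False, True]),
+#   )
+#   benchmarks = generator.generate()  # One benchmark per valid combination.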
+ """ + metrics = metric_lib.Metrics() + pytree = context.pytree + abstract_pytree = jax.tree.map(ocp.arrays.to_shape_dtype_struct, pytree) + save_path = context.path / "ckpt" + options = context.options + assert isinstance(options, V1BenchmarkOptions) + + logging.info("Benchmark options: %s", pprint.pformat(options)) + metrics_to_measure = _metrics_to_measure(options) + + with ocp.Context(context=options.context): + if options.enable_trace: + jax.profiler.start_trace(context.path / "trace_save") + if options.async_enabled: + with metrics.measure("save_blocking", metrics_to_measure): + f = ocp.save_pytree_async(save_path, pytree) + with metrics.measure("save_background", metrics_to_measure): + f.result() + else: + with metrics.measure("save_blocking", metrics_to_measure): + ocp.save_pytree(save_path, pytree) + with metrics.measure("save_background", metrics_to_measure): + pass + context.pytree = self._clear_pytree(context.pytree) + if options.enable_trace: + jax.profiler.stop_trace() + + if options.enable_trace: + jax.profiler.start_trace(context.path / "trace_load") + with metrics.measure("load", metrics_to_measure): + restored_pytree = ocp.load_pytree(save_path, abstract_pytree) + self._clear_pytree(restored_pytree) + if options.enable_trace: + jax.profiler.stop_trace() + + return benchmarks_core.TestResult(metrics=metrics) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark_test.py new file mode 100644 index 000000000..8fe7191d6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark_test.py @@ -0,0 +1,135 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from etils import epath
+import jax.numpy as jnp
+from orbax.checkpoint._src.testing.benchmarks import v1_benchmark
+from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
+from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
+
+
+V1BenchmarkOptions = v1_benchmark.V1BenchmarkOptions
+V1Benchmark = v1_benchmark.V1Benchmark
+
+
+class V1BenchmarkTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.directory = epath.Path(self.create_tempdir().full_path)
+
+  @parameterized.parameters(
+      dict(
+          options=V1BenchmarkOptions(use_ocdbt=False, use_zarr3=True),
+          expected_len=1,
+      ),
+      dict(
+          options=V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=True),
+          expected_len=2,
+      ),
+      dict(
+          options=V1BenchmarkOptions(
+              use_ocdbt=[False, True], use_zarr3=[False, True]
+          ),
+          expected_len=4,
+      ),
+  )
+  def test_generate_benchmarks(self, options, expected_len):
+    generator = V1Benchmark(
+        checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
+        options=options,
+    )
+    benchmarks = generator.generate()
+    self.assertLen(benchmarks, expected_len)
+    for benchmark in benchmarks:
+      self.assertIsInstance(benchmark.options, V1BenchmarkOptions)
+
+  @parameterized.product(
+      use_ocdbt=(False, True),
+      use_zarr3=(False, True),
+      use_compression=(False, True),
+      save_concurrent_gb=(None, 1),
+      restore_concurrent_gb=(None, 2),
+      use_replica_parallel=(False,),  # Simplified for a single-process test.
+      enable_replica_parallel_separate_folder=(False,),
+  )
+  def test_benchmark_test_fn(
+      self,
+      use_ocdbt,
+      use_zarr3,
+      use_compression,
+      save_concurrent_gb,
+      restore_concurrent_gb,
+      use_replica_parallel,
+      enable_replica_parallel_separate_folder,
+  ):
+    # Skip invalid combinations.
+    if enable_replica_parallel_separate_folder and (
+        not use_replica_parallel or not use_ocdbt
+    ):
+      return
+
+    generator = V1Benchmark(
+        checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
+        options=V1BenchmarkOptions(),
+    )
+
+    pytree = {
+        'a': jnp.arange(10),
+        'b': {'c': jnp.ones((5, 5))},
+    }
+
+    test_options = V1BenchmarkOptions(
+        use_ocdbt=use_ocdbt,
+        use_zarr3=use_zarr3,
+        use_compression=use_compression,
+        save_concurrent_gb=save_concurrent_gb,
+        restore_concurrent_gb=restore_concurrent_gb,
+        use_replica_parallel=use_replica_parallel,
+        enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder,
+    )
+
+    # Create a unique path for each parameter set.
+    test_subdir = (
+        self.directory / f'test_{use_ocdbt}_{use_zarr3}_{use_compression}'
+    )
+    test_subdir.mkdir(exist_ok=True, parents=True)
+
+    context = benchmarks_core.TestContext(
+        pytree=pytree, path=test_subdir, options=test_options
+    )
+
+    result = generator.test_fn(context)
+
+    self.assertIsInstance(result, benchmarks_core.TestResult)
+    # Check for expected metric keys based on `_metrics_to_measure` in
+    # v1_benchmark.py and the `metrics.measure` calls there: the benchmark
+    # records the "save_blocking", "save_background", and "load" phases, and
+    # the "time" metric is always measured, so each phase should yield a
+    # `<phase>_time_duration` entry.
+    metrics = result.metrics.results
+    self.assertIn('save_blocking_time_duration', metrics)
+    self.assertIn('save_background_time_duration', metrics)
+    self.assertIn('load_time_duration', metrics)
+
+
+if __name__ == '__main__':
+  absltest.main()