@@ -97,7 +97,7 @@ class PyTreeCheckpointBenchmark(benchmarks_core.BenchmarksGenerator):
PyTreeCheckpointHandler with various configurations.
"""

- def _clear_pytree(self, pytree: Any) -> Any:
+ def clear_pytree(self, pytree: Any) -> Any:
"""Clears the pytree to free up memory."""
return jax.tree.map(
lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
@@ -174,7 +174,7 @@ def test_fn(
assert hasattr(checkpointer, "wait_until_finished")
checkpointer.wait_until_finished()

- context.pytree = self._clear_pytree(context.pytree)
+ context.pytree = self.clear_pytree(context.pytree)

with metrics.measure("restore", metrics_to_measure):
restored_pytree = checkpointer.restore(
@@ -185,7 +185,7 @@
),
)

- self._clear_pytree(restored_pytree)
+ self.clear_pytree(restored_pytree)

checkpointer.close()
return benchmarks_core.TestResult(metrics=metrics)
165 changes: 165 additions & 0 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark.py
@@ -0,0 +1,165 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for V1 free functions."""

from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import pprint
from typing import Any

from absl import logging
import jax
from orbax.checkpoint import v1 as ocp
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib


def _metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
"""Returns the list of metrics to measure."""
metrics = ["time", "rss", "io"]
if options.metric_tracemalloc_enabled:
metrics.append("tracemalloc")
if options.metric_tensorstore_enabled:
metrics.append("tensorstore")
return metrics


# ==============================================================================
# 1. Define the Options Dataclass for this specific benchmark
# ==============================================================================
@dataclasses.dataclass(frozen=True)
class V1BenchmarkOptions(benchmarks_core.BenchmarkOptions):
"""Configuration options for benchmarks targeting V1BenchmarkHandler.

Each attribute can be a single value or a list of values to create
a parameter sweep.

Attributes:
"""

async_enabled: bool | Sequence[bool] = True
use_ocdbt: bool | Sequence[bool] = True
use_zarr3: bool | Sequence[bool] = False
use_compression: bool | Sequence[bool] = True
save_concurrent_gb: int | None | Sequence[int | None] = None
restore_concurrent_gb: int | None | Sequence[int | None] = None
metric_tracemalloc_enabled: bool = False
metric_tensorstore_enabled: bool = False
use_replica_parallel: bool | Sequence[bool] = False
enable_replica_parallel_separate_folder: bool | Sequence[bool] = False
enable_trace: bool = False

def is_valid(self) -> bool:
"""Returns False for option combinations that cannot be benchmarked."""
assert isinstance(self.use_replica_parallel, bool)
assert isinstance(self.enable_replica_parallel_separate_folder, bool)
if self.enable_replica_parallel_separate_folder and (
not self.use_replica_parallel or not self.use_ocdbt
):
return False
return True

@property
def context(self) -> ocp.Context:
return ocp.Context(
array_options=ocp.options.ArrayOptions(
saving=ocp.options.ArrayOptions.Saving(
use_ocdbt=self.use_ocdbt,
use_zarr3=self.use_zarr3,
use_replica_parallel=self.use_replica_parallel,
use_compression=self.use_compression,
enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder,
concurrent_bytes=self.save_concurrent_gb * 1024**3
if self.save_concurrent_gb is not None
else None,
),
loading=ocp.options.ArrayOptions.Loading(
concurrent_bytes=self.restore_concurrent_gb * 1024**3
if self.restore_concurrent_gb is not None
else None,
),
),
)
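
# Illustrative note on the sweep semantics described in the options docstring
# above (mirroring the cases exercised in the unit tests): a list-valued
# attribute expands into one benchmark per value, e.g.
#
#   V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=True)          -> 2 benchmarks
#   V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=[False, True]) -> 4 benchmarks
#
# `save_concurrent_gb` / `restore_concurrent_gb` are given in GiB and are
# converted to bytes (value * 1024**3) by the `context` property.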


# ==============================================================================
# 2. Implement the Benchmark Generator
# ==============================================================================
@benchmarks_core.benchmark_options(V1BenchmarkOptions)
class V1Benchmark(benchmarks_core.BenchmarksGenerator):
"""A concrete generator for `orbax.checkpoint.V1BenchmarkHandler`.

This class provides the specific test logic for benchmarking the
V1BenchmarkHandler with various configurations.
"""

def _clear_pytree(self, pytree: Any) -> Any:
"""Clears the pytree to free up memory."""
return jax.tree.map(
lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
)

def test_fn(
self, context: benchmarks_core.TestContext
) -> benchmarks_core.TestResult:
"""The core test logic for a single save/restore cycle.

This function is called for each combination of options generated by the
framework. It uses the `context.options` to configure the handler
dynamically for each run.

Args:
context: The test context containing the pytree, path, and options.

Returns:
The test result containing the metrics.
"""
metrics = metric_lib.Metrics()
pytree = context.pytree
abstract_pytree = jax.tree.map(ocp.arrays.to_shape_dtype_struct, pytree)
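# The abstract tree carries only array metadata (shape/dtype/sharding); it is
# passed to `load_pytree` below so restoration does not depend on the original
# arrays, which are deleted after saving.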
save_path = context.path / "ckpt"
options = context.options
assert isinstance(options, V1BenchmarkOptions)

logging.info("Benchmark options: %s", pprint.pformat(options))
metrics_to_measure = _metrics_to_measure(options)

with ocp.Context(context=options.context):
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_save")
if options.async_enabled:
with metrics.measure("save_blocking", metrics_to_measure):
f = ocp.save_pytree_async(save_path, pytree)
with metrics.measure("save_background", metrics_to_measure):
f.result()
else:
with metrics.measure("save_blocking", metrics_to_measure):
ocp.save_pytree(save_path, pytree)
with metrics.measure("save_background", metrics_to_measure):
pass
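# The empty "save_background" block keeps the recorded metric keys identical
# between the async and blocking paths; for blocking saves the background
# phase is effectively zero-duration.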
context.pytree = self._clear_pytree(context.pytree)
if options.enable_trace:
jax.profiler.stop_trace()

if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_load")
with metrics.measure("load", metrics_to_measure):
restored_pytree = ocp.load_pytree(save_path, abstract_pytree)
self._clear_pytree(restored_pytree)
if options.enable_trace:
jax.profiler.stop_trace()

return benchmarks_core.TestResult(metrics=metrics)
@@ -0,0 +1,135 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax.numpy as jnp
from orbax.checkpoint._src.testing.benchmarks import v1_benchmark
from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core


V1BenchmarkOptions = v1_benchmark.V1BenchmarkOptions
V1Benchmark = v1_benchmark.V1Benchmark


class V1BenchmarkTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = epath.Path(self.create_tempdir().full_path)

@parameterized.parameters(
dict(
options=V1BenchmarkOptions(use_ocdbt=False, use_zarr3=True),
expected_len=1,
),
dict(
options=V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=True),
expected_len=2,
),
dict(
options=V1BenchmarkOptions(
use_ocdbt=[False, True], use_zarr3=[False, True]
),
expected_len=4,
),
)
def test_generate_benchmarks(self, options, expected_len):
generator = V1Benchmark(
checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
options=options,
)
benchmarks = generator.generate()
self.assertLen(benchmarks, expected_len)
for benchmark in benchmarks:
self.assertIsInstance(benchmark.options, V1BenchmarkOptions)

@parameterized.product(
use_ocdbt=(False, True),
use_zarr3=(False, True),
use_compression=(False, True),
save_concurrent_gb=(None, 1),
restore_concurrent_gb=(None, 2),
use_replica_parallel=(False,), # Simplified for single process test
enable_replica_parallel_separate_folder=(False,),
)
def test_benchmark_test_fn(
self,
use_ocdbt,
use_zarr3,
use_compression,
save_concurrent_gb,
restore_concurrent_gb,
use_replica_parallel,
enable_replica_parallel_separate_folder,
):
# Skip invalid combinations
if enable_replica_parallel_separate_folder and (
not use_replica_parallel or not use_ocdbt
):
return

generator = V1Benchmark(
checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
options=V1BenchmarkOptions(),
)

pytree = {
'a': jnp.arange(10),
'b': {'c': jnp.ones((5, 5))},
}

test_options = V1BenchmarkOptions(
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
use_compression=use_compression,
save_concurrent_gb=save_concurrent_gb,
restore_concurrent_gb=restore_concurrent_gb,
use_replica_parallel=use_replica_parallel,
enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder,
)

# Create unique path for each parameter set
test_subdir = (
self.directory / f'test_{use_ocdbt}_{use_zarr3}_{use_compression}'
)
test_subdir.mkdir(exist_ok=True, parents=True)

context = benchmarks_core.TestContext(
pytree=pytree, path=test_subdir, options=test_options
)

result = generator.test_fn(context)

self.assertIsInstance(result, benchmarks_core.TestResult)
# Check for expected metrics keys based on _metrics_to_measure
# in v1_benchmark.py and the metric.measure calls.
# The benchmark records "save_blocking", "save_background", "load".
# Metric "time" is always added.

# We expect roughly:
# save_blocking_time_duration
# save_background_time_duration
# load_time_duration

metrics = result.metrics.results
self.assertIn('save_blocking_time_duration', metrics)
self.assertIn('save_background_time_duration', metrics)
self.assertIn('load_time_duration', metrics)


if __name__ == '__main__':
absltest.main()
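
For reference, a minimal sketch of driving the new generator directly, based only on the APIs exercised in the test above; the scratch directory is hypothetical and the production benchmark runner is not shown in this diff:

import jax.numpy as jnp
from etils import epath
from orbax.checkpoint._src.testing.benchmarks import v1_benchmark
from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core

# Sweep OCDBT on/off; every other option keeps its default value.
generator = v1_benchmark.V1Benchmark(
    checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
    options=v1_benchmark.V1BenchmarkOptions(use_ocdbt=[False, True]),
)

base_dir = epath.Path('/tmp/v1_benchmark_demo')  # hypothetical scratch path
for i, benchmark in enumerate(generator.generate()):
  run_dir = base_dir / f'run_{i}'
  run_dir.mkdir(parents=True, exist_ok=True)
  context = benchmarks_core.TestContext(
      pytree={'a': jnp.arange(10), 'b': {'c': jnp.ones((5, 5))}},
      path=run_dir,
      options=benchmark.options,
  )
  result = generator.test_fn(context)
  print(benchmark.options, result.metrics.results)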