Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Removed

- Removed `nest_asyncio` dependency and its usage. Orbax no longer supports
nested `asyncio.run` calls by default.

## [0.11.32] - 2026-01-20

### Changed
Expand Down
8 changes: 0 additions & 8 deletions checkpoint/orbax/checkpoint/_src/asyncio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,13 @@

import asyncio
from typing import Any, Coroutine, TypeVar
import nest_asyncio


_T = TypeVar('_T')


def run_sync(
coro: Coroutine[Any, Any, _T],
enable_nest_asyncio: bool = True, # For testing.
) -> _T:
"""Runs a coroutine and returns the result."""
try:
asyncio.get_running_loop() # no event loop: ~0.001s, otherwise: ~0.182s
if enable_nest_asyncio:
nest_asyncio.apply() # patch asyncio globally in a runtime (idempotent).
except RuntimeError:
pass
return asyncio.run(coro)
97 changes: 3 additions & 94 deletions checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ async def raise_error():
raise ValueError("test error")


async def with_run_sync(a_coro_fn):
x = asyncio_utils.run_sync(a_coro_fn())
y = asyncio_utils.run_sync(a_coro_fn())
z = asyncio_utils.run_sync(a_coro_fn())
return f"{x}{y}{z}"


class AsyncioUtilsTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -156,82 +149,6 @@ def test_run_sync_raising_error(self, coro_fn):
with self.assertRaisesRegex(ValueError, "test error"):
asyncio_utils.run_sync(coro_fn())

@parameterized.named_parameters(
["basic", partial(with_run_sync, one)],
["nested_1_level", partial(with_run_sync, partial(with_run_sync, one))],
[
"nested_2_level",
partial(
with_run_sync, partial(with_run_sync, partial(with_run_sync, one))
),
],
[
"nested_3_level",
partial(
with_run_sync,
partial(
with_run_sync,
partial(with_run_sync, partial(with_run_sync, one)),
),
),
],
[
"nested_4_level",
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(with_run_sync, partial(with_run_sync, one)),
),
),
),
],
[
"nested_5_level",
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(with_run_sync, partial(with_run_sync, one)),
),
),
),
),
],
[
"nested_6_level",
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync,
partial(
with_run_sync, partial(with_run_sync, one)
),
),
),
),
),
),
],
)
def test_run_sync_nested(self, coro_fn):
self.assertEqual(
asyncio.run(coro_fn()),
asyncio_utils.run_sync(coro_fn()),
)

def test_run_nested(self):
async def with_run(a_coro_fn):
return asyncio.run(a_coro_fn())
Expand Down Expand Up @@ -285,21 +202,13 @@ async def _test():
)()

number = 10000
# First run with enable_nest_asyncio=False, because nest_asyncio.apply
# patches asyncio globally in a runtime. There is no way to unpatch it.
run_time = timeit.timeit(
lambda: asyncio_utils.run_sync(_test(), enable_nest_asyncio=False),
number=number,
) # ~1.604s
run_sync_time = timeit.timeit(
lambda: asyncio_utils.run_sync(_test(), enable_nest_asyncio=True),
lambda: asyncio_utils.run_sync(_test()),
number=number,
) # ~1.5503s
)
logging.info(
"time: run_sync_time=%s, run_time=%s, ratio=%s",
"time: run_sync_time=%s",
run_sync_time,
run_time,
run_sync_time / run_time,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def test_composite_handler_metadata(self):
)
self.assertIsNone(metadata.item_metadata['other'])

async def test_save_with_failing_array_metadata_store_in_finalize(self):
def test_save_with_failing_array_metadata_store_in_finalize(self):
"""ArrayMetadata validation in CheckpointHandler.finalize()."""

class IncompleteArrayMetadataSerializer(
Expand Down Expand Up @@ -662,7 +662,9 @@ def maybe_raise_error(self):
'ArrayMetadata Store contains different number of params',
):
array_metadata_validator.maybe_raise_error()
self.assertIsNotNone(await array_metadata_store.read(self.directory))
self.assertIsNotNone(
asyncio_utils.run_sync(array_metadata_store.read(self.directory))
)

def test_partial_restore_with_placeholder(self):
"""Basic save and restore test."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@

"""Helpers for asyncio usage."""

import nest_asyncio


def maybe_apply_nest_asyncio():
try:
nest_asyncio.apply()
except RuntimeError:
pass
pass
1 change: 0 additions & 1 deletion checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies = [
'numpy',
'pyyaml',
'tensorstore >= 0.1.74',
'nest_asyncio',
'aiofiles',
'protobuf',
'humanize',
Expand Down
Loading