From 79a82b3d35fa5fc1fbf484e0b15657bbb48857df Mon Sep 17 00:00:00 2001 From: Justin Pan Date: Thu, 22 Jan 2026 15:30:22 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 859801395 --- checkpoint/CHANGELOG.md | 5 + .../orbax/checkpoint/_src/asyncio_utils.py | 8 -- .../checkpoint/_src/asyncio_utils_test.py | 97 +------------------ .../checkpointers/checkpointer_test_utils.py | 6 +- .../v1/_src/synchronization/asyncio_utils.py | 7 +- checkpoint/pyproject.toml | 1 - 6 files changed, 13 insertions(+), 111 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 1002b3e3a..939f6c1a4 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Removed + +- Removed `nest_asyncio` dependency and its usage. Orbax no longer supports + nested `asyncio.run` calls by default. + ## [0.11.32] - 2026-01-20 ### Changed diff --git a/checkpoint/orbax/checkpoint/_src/asyncio_utils.py b/checkpoint/orbax/checkpoint/_src/asyncio_utils.py index 7086d37f6..a66d55a3b 100644 --- a/checkpoint/orbax/checkpoint/_src/asyncio_utils.py +++ b/checkpoint/orbax/checkpoint/_src/asyncio_utils.py @@ -16,7 +16,6 @@ import asyncio from typing import Any, Coroutine, TypeVar -import nest_asyncio _T = TypeVar('_T') @@ -24,13 +23,6 @@ def run_sync( coro: Coroutine[Any, Any, _T], - enable_nest_asyncio: bool = True, # For testing. ) -> _T: """Runs a coroutine and returns the result.""" - try: - asyncio.get_running_loop() # no event loop: ~0.001s, otherwise: ~0.182s - if enable_nest_asyncio: - nest_asyncio.apply() # patch asyncio globally in a runtime (idempotent). - except RuntimeError: - pass return asyncio.run(coro) diff --git a/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py b/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py index 03fbe092d..110db2312 100644 --- a/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py @@ -45,13 +45,6 @@ async def raise_error(): raise ValueError("test error") -async def with_run_sync(a_coro_fn): - x = asyncio_utils.run_sync(a_coro_fn()) - y = asyncio_utils.run_sync(a_coro_fn()) - z = asyncio_utils.run_sync(a_coro_fn()) - return f"{x}{y}{z}" - - class AsyncioUtilsTest(parameterized.TestCase): @parameterized.named_parameters( @@ -156,82 +149,6 @@ def test_run_sync_raising_error(self, coro_fn): with self.assertRaisesRegex(ValueError, "test error"): asyncio_utils.run_sync(coro_fn()) - @parameterized.named_parameters( - ["basic", partial(with_run_sync, one)], - ["nested_1_level", partial(with_run_sync, partial(with_run_sync, one))], - [ - "nested_2_level", - partial( - with_run_sync, partial(with_run_sync, partial(with_run_sync, one)) - ), - ], - [ - "nested_3_level", - partial( - with_run_sync, - partial( - with_run_sync, - partial(with_run_sync, partial(with_run_sync, one)), - ), - ), - ], - [ - "nested_4_level", - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial(with_run_sync, partial(with_run_sync, one)), - ), - ), - ), - ], - [ - "nested_5_level", - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial(with_run_sync, partial(with_run_sync, one)), - ), - ), - ), - ), - ], - [ - "nested_6_level", - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, - partial( - with_run_sync, partial(with_run_sync, one) - ), - ), - ), - ), - ), - ), - ], - ) - def test_run_sync_nested(self, coro_fn): - self.assertEqual( - asyncio.run(coro_fn()), - asyncio_utils.run_sync(coro_fn()), - ) - def test_run_nested(self): async def with_run(a_coro_fn): return asyncio.run(a_coro_fn()) @@ -285,21 +202,13 @@ async def _test(): )() number = 10000 - # First run with enable_nest_asyncio=False, because nest_asyncio.apply - # patches asyncio globally in a runtime. There is no way to unpatch it. - run_time = timeit.timeit( - lambda: asyncio_utils.run_sync(_test(), enable_nest_asyncio=False), - number=number, - ) # ~1.604s run_sync_time = timeit.timeit( - lambda: asyncio_utils.run_sync(_test(), enable_nest_asyncio=True), + lambda: asyncio_utils.run_sync(_test()), number=number, - ) # ~1.5503s + ) logging.info( - "time: run_sync_time=%s, run_time=%s, ratio=%s", + "time: run_sync_time=%s", run_sync_time, - run_time, - run_sync_time / run_time, ) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py index 296a92b75..8c6683653 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py @@ -599,7 +599,7 @@ def test_composite_handler_metadata(self): ) self.assertIsNone(metadata.item_metadata['other']) - async def test_save_with_failing_array_metadata_store_in_finalize(self): + def test_save_with_failing_array_metadata_store_in_finalize(self): """ArrayMetadata validation in CheckpointHandler.finalize().""" class IncompleteArrayMetadataSerializer( @@ -662,7 +662,9 @@ def maybe_raise_error(self): 'ArrayMetadata Store contains different number of params', ): array_metadata_validator.maybe_raise_error() - self.assertIsNotNone(await array_metadata_store.read(self.directory)) + self.assertIsNotNone( + asyncio_utils.run_sync(array_metadata_store.read(self.directory)) + ) def test_partial_restore_with_placeholder(self): """Basic save and restore test.""" diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py index f6bff8ddf..db930a996 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/asyncio_utils.py @@ -14,11 +14,6 @@ """Helpers for asyncio usage.""" -import nest_asyncio - def maybe_apply_nest_asyncio(): - try: - nest_asyncio.apply() - except RuntimeError: - pass + pass diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml index be98db552..29ba20d0c 100644 --- a/checkpoint/pyproject.toml +++ b/checkpoint/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ 'numpy', 'pyyaml', 'tensorstore >= 0.1.74', - 'nest_asyncio', 'aiofiles', 'protobuf', 'humanize',