diff --git a/gcsfs/concurrency.py b/gcsfs/concurrency.py
new file mode 100644
index 00000000..10b32c9c
--- /dev/null
+++ b/gcsfs/concurrency.py
@@ -0,0 +1,40 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+
+@asynccontextmanager
+async def parallel_tasks_first_completed(coros):
+    """
+    Starts coroutines in parallel and enters the context as soon as
+    at least one task has completed. Automatically cancels pending tasks
+    when exiting the context and waits for all tasks to finish, so that
+    no task outlives the context.
+
+    Parameters
+    ----------
+    coros : iterable of coroutines
+        Coroutines to run concurrently as tasks on the running event loop.
+
+    Yields
+    ------
+    (tasks, done, pending)
+        ``tasks`` is the list of created tasks, in the order of ``coros``.
+        ``done`` and ``pending`` are the sets returned by ``asyncio.wait``
+        when the first task completed.
+    """
+    tasks = [asyncio.create_task(c) for c in coros]
+    try:
+        # Suspend until the first task finishes for maximum responsiveness
+        done, pending = await asyncio.wait(
+            set(tasks), return_when=asyncio.FIRST_COMPLETED
+        )
+        yield tasks, done, pending
+    finally:
+        # Ensure 'losing' tasks are cancelled immediately
+        for t in tasks:
+            if not t.done():
+                t.cancel()
+        # Wait for cancellation to take effect and retrieve the outcome of
+        # every task. Otherwise a task that failed without being awaited by
+        # the caller would log "Task exception was never retrieved".
+        await asyncio.gather(*tasks, return_exceptions=True)
diff --git a/gcsfs/core.py b/gcsfs/core.py
index e6ee5f5e..56d21048 100644
--- a/gcsfs/core.py
+++ b/gcsfs/core.py
@@ -27,6 +27,7 @@
 
 from . import __version__ as version
 from .checkers import get_consistency_checker
+from .concurrency import parallel_tasks_first_completed
 from .credentials import GoogleCredentials
 from .inventory_report import InventoryReport
 from .retry import errs, retry_request, validate_response
@@ -1080,15 +1081,26 @@ async def _info(self, path, generation=None, **kwargs):
                 "storageClass": "DIRECTORY",
                 "type": "directory",
             }
-        # Check exact file path
-        try:
-            exact = await self._get_object(path)
-            # this condition finds a "placeholder" - still need to check if it's a directory
-            if not _is_directory_marker(exact):
-                return exact
-        except FileNotFoundError:
-            pass
-        return await self._get_directory_info(path, bucket, key, generation)
+
+        # Check the exact file path and the directory listing concurrently;
+        # the exact object wins unless it is missing or a directory marker
+        async with parallel_tasks_first_completed(
+            [
+                self._get_object(path),
+                self._get_directory_info(path, bucket, key, generation),
+            ]
+        ) as (tasks, done, pending):
+            get_object_task, get_directory_info_task = tasks
+
+            try:
+                get_object_res = await get_object_task
+                # A "placeholder" (directory marker) object is not a final
+                # answer - fall through to the directory lookup instead
+                if not _is_directory_marker(get_object_res):
+                    return get_object_res
+            except FileNotFoundError:
+                pass
+            return await get_directory_info_task
 
     async def _get_directory_info(self, path, bucket, key, generation):
         """
diff --git a/gcsfs/tests/test_concurrency.py b/gcsfs/tests/test_concurrency.py
new file mode 100644
index 00000000..be47fd8f
--- /dev/null
+++ b/gcsfs/tests/test_concurrency.py
@@ -0,0 +1,78 @@
+import asyncio
+
+import pytest
+
+from gcsfs.concurrency import parallel_tasks_first_completed
+
+
+@pytest.mark.asyncio
+async def test_parallel_tasks_first_completed_basic():
+    async def slow_task():
+        await asyncio.sleep(1)
+        return "slow"
+
+    async def fast_task():
+        await asyncio.sleep(0.1)
+        return "fast"
+
+    async with parallel_tasks_first_completed([slow_task(), fast_task()]) as (
+        tasks,
+        done,
+        pending,
+    ):
+        assert len(done) == 1
+        assert len(pending) == 1
+        completed_task = done.pop()
+        assert completed_task.result() == "fast"
+        assert len(tasks) == 2
+
+
+@pytest.mark.asyncio
+async def test_parallel_tasks_first_completed_cancellation():
+    task_cancelled = False
+
+    async def slow_task():
+        nonlocal task_cancelled
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            task_cancelled = True
+            raise
+
+    async def fast_task():
+        await asyncio.sleep(0.1)
+        return "fast"
+
+    async with parallel_tasks_first_completed([slow_task(), fast_task()]) as (
+        tasks,
+        done,
+        pending,
+    ):
+        assert len(done) == 1
+        completed_task = done.pop()
+        assert completed_task.result() == "fast"
+
+    # After exiting context, slow_task has already been cancelled and awaited
+    assert task_cancelled
+    assert all(t.done() for t in tasks)
+
+
+@pytest.mark.asyncio
+async def test_parallel_tasks_first_completed_exception():
+    async def error_task():
+        await asyncio.sleep(0.1)
+        raise ValueError("error")
+
+    async def slow_task():
+        await asyncio.sleep(1)
+        return "slow"
+
+    async with parallel_tasks_first_completed([error_task(), slow_task()]) as (
+        tasks,
+        done,
+        pending,
+    ):
+        assert len(done) == 1
+        completed_task = done.pop()
+        with pytest.raises(ValueError, match="error"):
+            completed_task.result()
diff --git a/gcsfs/tests/test_core.py b/gcsfs/tests/test_core.py
index f984d388..23ffa128 100644
--- a/gcsfs/tests/test_core.py
+++ b/gcsfs/tests/test_core.py
@@ -2367,3 +2367,117 @@ def test_walk(gcs):
         exp_dirs, exp_files = expected_structure[root]
         assert set(d_list) == exp_dirs
         assert set(f_list) == exp_files
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "object_behavior, dir_behavior, expected",
+    [
+        (
+            {"return": {"name": TEST_BUCKET + "/file", "type": "file", "size": 100}},
+            {"exception": FileNotFoundError},
+            {"return": {"type": "file"}},
+        ),
+        (
+            {"exception": FileNotFoundError},
+            {"return": {"name": TEST_BUCKET + "/file", "type": "directory", "size": 0}},
+            {"return": {"type": "directory"}},
+        ),
+        (
+            {
+                "return": {
+                    "name": TEST_BUCKET + "/file/",
+                    "type": "directory",
+                    "size": 0,
+                }
+            },
+            {
+                "return": {
+                    "name": TEST_BUCKET + "/file",
+                    "type": "directory",
+                    "size": 0,
+                    "extra": "info",
+                }
+            },
+            {"return": {"type": "directory", "extra": "info"}},
+        ),
+        (
+            {"exception": Exception("Generic error")},
+            {"exception": FileNotFoundError},
+            {"exception": Exception, "match": "Generic error"},
+        ),
+        (
+            {"exception": FileNotFoundError},
+            {"exception": Exception("Directory error")},
+            {"exception": Exception, "match": "Directory error"},
+        ),
+        (
+            {"exception": FileNotFoundError},
+            {"exception": FileNotFoundError},
+            {"exception": FileNotFoundError},
+        ),
+    ],
+)
+async def test_info_parallel(gcs, object_behavior, dir_behavior, expected):
+    path = TEST_BUCKET + "/file"
+
+    with (
+        mock.patch.object(
+            gcs, "_get_object", new_callable=mock.AsyncMock
+        ) as mock_get_object,
+        mock.patch.object(
+            gcs, "_get_directory_info", new_callable=mock.AsyncMock
+        ) as mock_get_dir,
+    ):
+
+        if "return" in object_behavior:
+            mock_get_object.return_value = object_behavior["return"]
+        elif "exception" in object_behavior:
+            mock_get_object.side_effect = object_behavior["exception"]
+
+        if "return" in dir_behavior:
+            mock_get_dir.return_value = dir_behavior["return"]
+        elif "exception" in dir_behavior:
+            mock_get_dir.side_effect = dir_behavior["exception"]
+
+        if "exception" in expected:
+            with pytest.raises(expected["exception"], match=expected.get("match")):
+                await gcs._info(path)
+        else:
+            res = await gcs._info(path)
+            for k, v in expected["return"].items():
+                assert res[k] == v
+
+        assert mock_get_object.call_count == 1
+        assert mock_get_dir.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_info_parallel_dir_first(gcs):
+    import asyncio
+
+    path = TEST_BUCKET + "/dir"
+
+    with (
+        mock.patch.object(
+            gcs, "_get_object", new_callable=mock.AsyncMock
+        ) as mock_get_object,
+        mock.patch.object(
+            gcs, "_get_directory_info", new_callable=mock.AsyncMock
+        ) as mock_get_dir,
+    ):
+
+        # Make _get_object slower than _get_directory_info
+        async def slow_get_object(*args, **kwargs):
+            await asyncio.sleep(0.1)
+            return {"name": path, "type": "file", "size": 100}
+
+        mock_get_object.side_effect = slow_get_object
+        # Directory check finishes immediately and succeeds
+        mock_get_dir.return_value = {"name": path, "type": "directory", "size": 0}
+
+        res = await gcs._info(path)
+        assert res["type"] == "file"
+
+        assert mock_get_object.call_count == 1
+        assert mock_get_dir.call_count == 1