diff --git a/gcsfs/concurrency.py b/gcsfs/concurrency.py
new file mode 100644
index 000000000..10b32c9c1
--- /dev/null
+++ b/gcsfs/concurrency.py
@@ -0,0 +1,27 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+
+@asynccontextmanager
+async def parallel_tasks_first_completed(coros):
+    """
+    Starts coroutines in parallel and enters the context as soon as
+    at least one task has completed. Automatically cancels pending tasks
+    when exiting the context.
+    """
+    tasks = [asyncio.create_task(c) for c in coros]
+    try:
+        # Suspend until the first task finishes for maximum responsiveness
+        done, pending = await asyncio.wait(
+            tasks, return_when=asyncio.FIRST_COMPLETED
+        )
+        yield tasks, done, pending
+    finally:
+        # Ensure 'losing' tasks are cancelled immediately
+        for t in tasks:
+            if not t.done():
+                t.cancel()
+        # Await the cancellations so cleanup runs to completion and any
+        # exceptions from losing tasks are retrieved (prevents "Task
+        # exception was never retrieved" warnings at loop shutdown).
+        await asyncio.gather(*tasks, return_exceptions=True)
diff --git a/gcsfs/core.py b/gcsfs/core.py
index e6ee5f5eb..ebe264184 100644
--- a/gcsfs/core.py
+++ b/gcsfs/core.py
@@ -27,6 +27,7 @@
 from . import __version__ as version
 from .checkers import get_consistency_checker
+from .concurrency import parallel_tasks_first_completed
 from .credentials import GoogleCredentials
 from .inventory_report import InventoryReport
 from .retry import errs, retry_request, validate_response
@@ -1048,17 +1049,24 @@ async def _info(self, path, generation=None, **kwargs):
         """File information about this path."""
         path = self._strip_protocol(path).rstrip("/")
         if "/" not in path:
-            try:
-                out = await self._call("GET", f"b/{path}", json_out=True)
-                out.update(size=0, type="directory")
-            except OSError:
-                # GET bucket failed, try ls; will have no metadata
-                exists = await self._ls(path)
-                if exists:
-                    out = {"name": path, "size": 0, "type": "directory"}
-                else:
+            async with parallel_tasks_first_completed(
+                [
+                    self._call("GET", f"b/{path}", json_out=True),
+                    self._ls(path, max_results=1),
+                ]
+            ) as (tasks, _done, _pending):
+                get_task, ls_task = tasks
+
+                try:
+                    out = await get_task
+                    out.update(size=0, type="directory")
+                    return out
+                except OSError:
+                    # GET bucket failed, try ls; will have no metadata
+                    if await ls_task:
+                        return {"name": path, "size": 0, "type": "directory"}
                     raise FileNotFoundError(path)
-            return out
+
         # Check directory cache for parent dir
         parent_path = self._parent(path)
         parent_cache = self._ls_from_cache(parent_path)
diff --git a/gcsfs/tests/integration/test_async_gcsfs.py b/gcsfs/tests/integration/test_async_gcsfs.py
index 7b1345c9c..500b6603f 100644
--- a/gcsfs/tests/integration/test_async_gcsfs.py
+++ b/gcsfs/tests/integration/test_async_gcsfs.py
@@ -230,6 +230,37 @@ async def test_async_info(async_gcs, hns_file_path):
     assert info["type"] == "file"
 
 
+@pytest.mark.asyncio
+async def test_async_info_fallback(async_gcs, hns_file_path):
+    """Test that _info falls back to _ls when _call (GET) fails."""
+    # Create a dummy file to ensure listing works if we use a path under it
+    # We want to test bucket level _info fallback, so we test on the bucket itself.
+    bucket, _, _ = async_gcs.split_path(hns_file_path)
+
+    # We pipe a file to ensure the bucket is not empty. If the bucket is empty,
+    # _ls returns [], which the fallback logic evaluates as falsy, causing it
+    # to raise FileNotFoundError even if the bucket exists.
+    file_path = f"{hns_file_path}/fallback_file"
+    await async_gcs._pipe_file(file_path, b"data")
+
+    original_call = async_gcs._call
+
+    async def mock_call(*args, **kwargs):
+        if len(args) >= 2 and args[0] == "GET" and args[1] == f"b/{bucket}":
+            raise OSError("Simulated 403 Forbidden")
+        return await original_call(*args, **kwargs)
+
+    async_gcs._call = mock_call
+
+    try:
+        # Calling _info on the bucket root should fall back to _ls
+        info = await async_gcs._info(bucket)
+        assert info["name"] == bucket
+        assert info["type"] == "directory"
+    finally:
+        async_gcs._call = original_call
+
+
 @pytest.mark.asyncio
 async def test_async_rm_recursive(async_gcs, hns_file_path):
     """Test async _rm recursive."""
diff --git a/gcsfs/tests/test_core.py b/gcsfs/tests/test_core.py
index f984d3887..fd6d15bb3 100644
--- a/gcsfs/tests/test_core.py
+++ b/gcsfs/tests/test_core.py
@@ -2177,6 +2177,119 @@ def test_mv_file_raises_error_for_specific_generation(gcs):
     gcs.version_aware = original_version_aware
 
 
+@pytest.mark.asyncio
+async def test_info_bucket_optimization(gcs):
+    bucket = "test-bucket"
+
+    # Mock _call to fail with OSError on GET b/test-bucket
+    # and mock _ls to return a list
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.side_effect = OSError("Failed to GET bucket")
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.return_value = ["test-bucket/"]
+
+            # Use await gcs._info as it is an async method
+            info = await gcs._info(bucket)
+
+            # Verify _call was called for the bucket GET
+            mock_call.assert_called_with("GET", f"b/{bucket}", json_out=True)
+
+            # Verify _ls was called with max_results=1
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+            assert info == {"name": bucket, "size": 0, "type": "directory"}
+
+
+@pytest.mark.asyncio
+async def test_info_bucket_not_found_optimization(gcs):
+    bucket = "non-existent-bucket"
+
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.side_effect = OSError("Failed to GET bucket")
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.return_value = []
+
+            with pytest.raises(FileNotFoundError):
+                await gcs._info(bucket)
+
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+
+@pytest.mark.asyncio
+async def test_info_bucket_success_parallel(gcs):
+    bucket = "test-bucket"
+
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.return_value = {"name": bucket, "kind": "storage#bucket"}
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.return_value = ["test-bucket/"]
+
+            info = await gcs._info(bucket)
+
+            mock_call.assert_called_with("GET", f"b/{bucket}", json_out=True)
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+            assert info == {
+                "name": bucket,
+                "kind": "storage#bucket",
+                "size": 0,
+                "type": "directory",
+            }
+
+
+@pytest.mark.asyncio
+async def test_info_bucket_ls_exception(gcs):
+    bucket = "test-bucket"
+
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.side_effect = OSError("Failed to GET bucket")
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.side_effect = ValueError("LS error")
+
+            with pytest.raises(ValueError, match="LS error"):
+                await gcs._info(bucket)
+
+            mock_call.assert_called_with("GET", f"b/{bucket}", json_out=True)
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+
+@pytest.mark.asyncio
+async def test_info_bucket_other_exception(gcs):
+    bucket = "test-bucket"
+
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.side_effect = ValueError("Some other error")
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.return_value = ["test-bucket/"]
+
+            with pytest.raises(ValueError, match="Some other error"):
+                await gcs._info(bucket)
+
+            mock_call.assert_called_with("GET", f"b/{bucket}", json_out=True)
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+
+@pytest.mark.asyncio
+async def test_info_bucket_fallback_success(gcs):
+    bucket = "test-bucket"
+
+    with mock.patch.object(gcs, "_call", new_callable=mock.AsyncMock) as mock_call:
+        mock_call.side_effect = OSError("Access denied")
+        with mock.patch.object(gcs, "_ls", new_callable=mock.AsyncMock) as mock_ls:
+            mock_ls.return_value = ["test-bucket/some-file"]
+
+            info = await gcs._info(bucket)
+
+            mock_call.assert_called_with("GET", f"b/{bucket}", json_out=True)
+            mock_ls.assert_awaited_once_with(bucket, max_results=1)
+
+            assert info == {
+                "name": bucket,
+                "size": 0,
+                "type": "directory",
+            }
+
+
 def test_tree(gcs):
     unique_id = uuid.uuid4().hex
     base_dir = f"{TEST_BUCKET}/test_tree_regional_{unique_id}"