diff --git a/cloudbuild/run_tests.sh b/cloudbuild/run_tests.sh old mode 100644 new mode 100755 index 8db017c9..7ba63657 --- a/cloudbuild/run_tests.sh +++ b/cloudbuild/run_tests.sh @@ -110,6 +110,13 @@ case "$TEST_SUITE" in "--deselect=gcsfs/tests/test_core.py::test_mv_file_cache" ) + # Zonal tests with prefetcher cache does not call _cat_file method. + # The following tests depends upon mocking the _cat_file method for regional. + ZONAL_DESELECTS+=( + "--deselect=gcsfs/tests/test_core.py::test_prefetcher_logical_chunk_override" + "--deselect=gcsfs/tests/test_core.py::test_fetch_logical_chunk_exception" + ) + pytest "${ARGS[@]}" "${ZONAL_DESELECTS[@]}" gcsfs/tests/test_core.py ;; esac diff --git a/gcsfs/caching.py b/gcsfs/caching.py index 56a4856f..98167fb0 100644 --- a/gcsfs/caching.py +++ b/gcsfs/caching.py @@ -1,5 +1,7 @@ +import asyncio from collections import deque +import fsspec.asyn from fsspec.caching import BaseCache, register_cache @@ -118,4 +120,315 @@ def _fetch(self, start: int | None, end: int | None) -> bytes: return b"".join(parts) -register_cache(ReadAheadChunked, clobber=True) +class Prefetcher(BaseCache): + """ + Asynchronous prefetching cache that reads ahead. + + This cache spawns a background producer task that fetches sequential + blocks of data before they are explicitly requested. It is highly optimized + for sequential reads but can recover from arbitrary seeks by restarting + the prefetch loop. + + Parameters + ---------- + blocksize : int + Base size of the chunks to read ahead, in bytes. + fetcher : Callable + A coroutine of the form `f(start, end)` which gets bytes from the remote. + size : int + Total size of the file being read. + max_prefetch_size : int, optional + Maximum bytes to prefetch ahead of the current user offset. + Defaults to `max(2 * blocksize, 128MB)`. + concurrency : int, optional + Number of concurrent network requests to use for large chunks. Defaults to 4. 
+ """ + + name = "prefetcher" + + MIN_CHUNK_SIZE = 5 * 1024 * 1024 + DEFAULT_PREFETCH_SIZE = 128 * 1024 * 1024 + + def __init__( + self, + blocksize: int, + fetcher, + size: int, + max_prefetch_size=None, + concurrency=4, + **kwargs, + ): + super().__init__(blocksize, fetcher, size) + self.fetcher = kwargs.pop("fetcher_override", self.fetcher) + self.concurrency = concurrency + self._user_max_prefetch_size = max_prefetch_size + self.sequential_streak = 0 + self.user_offset = 0 + self.current_offset = 0 + self.queue = asyncio.Queue() + self.is_stopped = False + self._active_tasks = set() + self._wakeup_producer = asyncio.Event() + self._current_block = b"" + self._current_block_idx = 0 + self.loop = fsspec.asyn.get_loop() + self.read_history = deque(maxlen=10) + self.history_sum = 0 + + async def _start_producer(): + self._producer_task = asyncio.create_task(self._producer_loop()) + + fsspec.asyn.sync(self.loop, _start_producer) + + def _get_adaptive_blocksize(self) -> int: + """Returns the adaptive blocksize configuration.""" + count = len(self.read_history) + if not count: + avg_size = self.blocksize + else: + avg_size = self.history_sum // count + + # Cap the adaptive blocksize only if the user explicitly set a max prefetch size + if self._user_max_prefetch_size is not None: + return min(avg_size, self._user_max_prefetch_size) + + return avg_size + + @property + def max_prefetch_size(self) -> int: + """Dynamically calculates max prefetch based on user intent or current blocksize.""" + if self._user_max_prefetch_size is not None: + return self._user_max_prefetch_size + + return max(2 * self._get_adaptive_blocksize(), self.DEFAULT_PREFETCH_SIZE) + + async def _cancel_all_tasks(self, wait=False): + self.is_stopped = True + self._wakeup_producer.set() + + tasks_to_wait = [] + + if hasattr(self, "_producer_task") and isinstance( + self._producer_task, asyncio.Task + ): + if not self._producer_task.done(): + self._producer_task.cancel() + 
async def _producer_loop(self):
    """Background producer that keeps the prefetch queue filled.

    Sleeps until ``_wakeup_producer`` is set, then schedules fetcher
    tasks (pushed onto ``self.queue``) until the data buffered ahead of
    ``user_offset`` fills the current prefetch window, EOF is reached,
    or the cache is stopped.  A fatal error is forwarded to the consumer
    by putting the exception object itself onto the queue.
    """
    try:
        while not self.is_stopped:
            # Block until a reader (or shutdown) asks for more data.
            await self._wakeup_producer.wait()
            self._wakeup_producer.clear()

            block_size = self._get_adaptive_blocksize()
            # Window grows with the sequential streak, capped by the
            # configured/derived maximum prefetch size.
            prefetch_size = min(
                (self.sequential_streak + 1) * block_size,
                self.max_prefetch_size,
            )

            while (
                not self.is_stopped
                and (self.current_offset - self.user_offset) < prefetch_size
                and self.current_offset < self.size
            ):
                space_remaining = self.size - self.current_offset
                prefetch_space_available = prefetch_size - (
                    self.current_offset - self.user_offset
                )
                # Mid-file, don't issue a request smaller than one block;
                # stop and wait for the window to reopen instead.
                if (
                    space_remaining >= block_size
                    and prefetch_space_available < block_size
                ):
                    break

                if prefetch_size >= self.MIN_CHUNK_SIZE:
                    # Large windows only issue requests of at least
                    # MIN_CHUNK_SIZE (NOTE(review): presumably so that
                    # splitting across connections stays worthwhile).
                    if prefetch_space_available >= self.MIN_CHUNK_SIZE:
                        actual_size = min(
                            max(self.MIN_CHUNK_SIZE, block_size),
                            space_remaining,
                        )
                    else:
                        break
                else:
                    actual_size = min(block_size, space_remaining)

                # Choose how many parallel range requests the fetcher may
                # use for this chunk.
                if self.sequential_streak < 2:
                    sfactor = (
                        self.concurrency
                        if actual_size >= self.MIN_CHUNK_SIZE
                        else min(self.concurrency, 1)
                    )  # random usecase
                else:
                    sfactor = (
                        min(
                            self.concurrency,
                            max(1, actual_size * self.concurrency // prefetch_size),
                        )
                        if actual_size >= self.MIN_CHUNK_SIZE
                        else 1
                    )  # sequential usecase

                download_task = asyncio.create_task(
                    self.fetcher(
                        self.current_offset, actual_size, split_factor=sfactor
                    )
                )
                # Track the task so close()/seek() can cancel it; the done
                # callback removes it again once it finishes.
                self._active_tasks.add(download_task)
                download_task.add_done_callback(self._active_tasks.discard)

                await self.queue.put(download_task)
                self.current_offset += actual_size

    except asyncio.CancelledError:
        # Normal shutdown path (seek/close cancels us); exit quietly.
        pass
    except Exception as e:
        # Deliver the failure to read() via the queue and halt producing.
        await self.queue.put(e)
        self.is_stopped = True
async def _async_fetch(self, start, end):
    """Collect ``end - start`` bytes starting at ``start``.

    Serves what it can from the partially consumed ``_current_block``,
    then awaits further blocks from the producer queue via ``read()``.
    A request at an offset other than ``user_offset`` triggers a seek,
    which restarts the producer.
    """
    if start != self.user_offset:
        # We seeked elsewhere, reset the current block
        self._current_block = b""
        self._current_block_idx = 0
        await self.seek(start)

    requested_size = end - start
    chunks = []
    collected = 0

    # Update read history for the adaptive blocksize logic
    if requested_size > 0:
        if len(self.read_history) == self.read_history.maxlen:
            # Subtract the sample that append() is about to push out of
            # the bounded deque, keeping the running sum consistent.
            self.history_sum -= self.read_history[0]
        self.read_history.append(requested_size)
        self.history_sum += requested_size

    available_in_block = len(self._current_block) - self._current_block_idx
    if available_in_block > 0:
        take = min(requested_size, available_in_block)

        if take == len(self._current_block) and self._current_block_idx == 0:
            # Zero-copy fast path: hand back the whole block object.
            chunks.append(self._current_block)
        else:
            chunks.append(
                self._current_block[
                    self._current_block_idx : self._current_block_idx + take
                ]
            )

        self._current_block_idx += take
        collected += take
        self.user_offset += take

    while collected < requested_size and self.user_offset < self.size:
        block = await self.read()
        if not block:
            # Defensive: an empty block means no more data is coming.
            break

        needed = requested_size - collected
        if len(block) > needed:
            # Keep the unconsumed tail of this block for the next call.
            chunks.append(block[:needed])
            self._current_block = block
            self._current_block_idx = needed
            collected += needed
            self.user_offset += needed
            break
        else:
            chunks.append(block)
            collected += len(block)
            self.user_offset += len(block)
            self._current_block = b""
            self._current_block_idx = 0

    if len(chunks) == 1:
        # Avoid a copy when a single chunk satisfies the whole request.
        out = chunks[0]
    else:
        out = b"".join(chunks)

    self.total_requested_bytes += len(out)
    return out
start is None: + start = 0 + if stop is None: + stop = self.size + if start >= self.size or start >= stop: + return b"" + return fsspec.asyn.sync(self.loop, self._async_fetch, start, stop) + + def close(self): + """Clean shutdown. Cancels tasks and waits for them to abort.""" + fsspec.asyn.sync(self.loop, self._cancel_all_tasks, True) + + +for gcs_cache in [ReadAheadChunked, Prefetcher]: + register_cache(gcs_cache, clobber=True) diff --git a/gcsfs/core.py b/gcsfs/core.py index def29372..71774a32 100644 --- a/gcsfs/core.py +++ b/gcsfs/core.py @@ -26,6 +26,7 @@ from fsspec.utils import setup_logging, stringify_path from . import __version__ as version +from .caching import Prefetcher from .checkers import get_consistency_checker from .credentials import GoogleCredentials from .inventory_report import InventoryReport @@ -1998,6 +1999,15 @@ def __init__( if not key: raise OSError("Attempt to open a bucket") self.generation = _coalesce_generation(generation, path_generation) + + if cache_options is None: + cache_options = {} + + if "r" in mode: + _FETCHER_OVERRIDE = {Prefetcher.name: self._fetch_logical_chunk} + if cache_type in _FETCHER_OVERRIDE: + cache_options["fetcher_override"] = _FETCHER_OVERRIDE[cache_type] + super().__init__( gcsfs, path, @@ -2197,6 +2207,51 @@ def _fetch_range(self, start=None, end=None): return b"" raise + async def _fetch_logical_chunk(self, start_offset, total_size, split_factor=1): + """ + Async fetcher mapped to the Prefetcher cache for regional buckets. + Uses concurrent HTTP range requests for split downloads. 
+ """ + if split_factor == 1: + return await self.gcsfs._cat_file( + self.path, start=start_offset, end=start_offset + total_size + ) + + part_size = total_size // split_factor + tasks = [] + + for i in range(split_factor): + offset = start_offset + (i * part_size) + actual_size = ( + part_size if i < split_factor - 1 else total_size - (i * part_size) + ) + + tasks.append( + asyncio.create_task( + self.gcsfs._cat_file( + self.path, start=offset, end=offset + actual_size + ) + ) + ) + + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + raise res + return b"".join(results) + except asyncio.CancelledError as e: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise e + + def close(self): + if hasattr(self, "cache") and self.cache and hasattr(self.cache, "close"): + self.cache.close() + super().close() + def _convert_fixed_key_metadata(metadata, *, from_google=False): """ diff --git a/gcsfs/extended_gcsfs.py b/gcsfs/extended_gcsfs.py index 550f5d46..94c07fae 100644 --- a/gcsfs/extended_gcsfs.py +++ b/gcsfs/extended_gcsfs.py @@ -2,6 +2,8 @@ import logging import os import uuid +import weakref +from concurrent.futures import ThreadPoolExecutor from enum import Enum from glob import has_magic from io import BytesIO @@ -67,6 +69,10 @@ def __init__(self, *args, finalize_on_close=False, **kwargs): if self.credentials.token == "anon": self.credential = AnonymousCredentials() self._storage_layout_cache = {} + self.memmove_executor = ThreadPoolExecutor( + max_workers=kwargs.get("memmove_max_workers", 8) + ) + weakref.finalize(self, self.memmove_executor.shutdown) @property def grpc_client(self): diff --git a/gcsfs/tests/test_caching.py b/gcsfs/tests/test_caching.py index e4bc0fd5..1f5ab3b7 100644 --- a/gcsfs/tests/test_caching.py +++ b/gcsfs/tests/test_caching.py @@ -1,6 +1,10 @@ +import asyncio +from unittest import mock + +import fsspec.asyn 
class TrackedAsyncMockFetcher:
    """Async fetcher stub that records every call and serves bytes from memory."""

    def __init__(self, data: bytes):
        self.data = data
        self.should_fail = False
        self.calls = []

    async def __call__(self, start, size, split_factor=1):
        record = {"start": start, "size": size, "split_factor": split_factor}
        self.calls.append(record)
        if self.should_fail:
            raise RuntimeError("Mocked network error")
        # Tiny delay to mimic network latency.
        await asyncio.sleep(0.001)
        stop = min(start + size, len(self.data))
        return self.data[start:stop]
def test_prefetcher_exact_block_reads(prefetcher_setup, source_data):
    """Blocksize-aligned reads bump the streak and hit the backend each time."""
    cache, fetcher = prefetcher_setup

    # Two consecutive blocksize-sized reads from offset 0.
    for i in (0, 1):
        chunk = cache._fetch(i * 10, (i + 1) * 10)
        assert chunk == source_data[i * 10 : (i + 1) * 10]
        assert cache.sequential_streak == i + 1

    assert len(fetcher.calls) >= 2
def test_prefetcher_partial_read_from_queued_block(prefetcher_setup, source_data):
    """Zero-copy bookkeeping when a queued block is larger than the request."""
    cache, fetcher = prefetcher_setup

    # Hand the cache a ready 10-byte block, as if the producer had fetched it.
    ready = asyncio.Future()
    ready.set_result(source_data[0:10])
    cache.queue.put_nowait(ready)

    # Consume 4 of the 10 queued bytes.
    first = cache._fetch(0, 4)
    assert first == source_data[0:4]

    # Exactly 6 bytes remain buffered for later reads.
    assert len(cache._current_block) - cache._current_block_idx == 6
    assert cache.user_offset == 4

    # The next read drains the buffered remainder without a new fetch.
    second = cache._fetch(4, 8)
    assert second == source_data[4:8]
    assert cache.user_offset == 8
    assert len(cache._current_block) - cache._current_block_idx == 2
def test_prefetcher_seek_same_offset(prefetcher_setup):
    """Seeking to the current user_offset is a no-op that keeps all buffers."""
    cache, _ = prefetcher_setup
    cache._fetch(0, 5)

    # Snapshot the state that a real seek would reset.
    snapshot = (
        cache.sequential_streak,
        cache._current_block,
        cache._current_block_idx,
    )

    fsspec.asyn.sync(cache.loop, cache.seek, cache.user_offset)

    assert (
        cache.sequential_streak,
        cache._current_block,
        cache._current_block_idx,
    ) == snapshot
def test_prefetcher_close_while_active(prefetcher_setup):
    """Closing mid-stream must cancel background work and drain the queue."""
    cache, _ = prefetcher_setup
    cache._fetch(0, 10)
    cache._fetch(10, 20)

    # There should be background work (tasks in flight or queued blocks).
    assert len(cache._active_tasks) > 0 or not cache.queue.empty()
    assert cache.is_stopped is False

    cache.close()

    assert cache.is_stopped is True
    assert len(cache._active_tasks) == 0
    assert cache.queue.empty() is True
def test_prefetcher_seek_resets_history(prefetcher_setup):
    """A seek clears adaptive history so stale read sizes don't linger."""
    cache, _ = prefetcher_setup

    # Build up some read history first.
    for offset in (0, 100, 200):
        cache._fetch(offset, offset + 100)
    assert cache.history_sum > 0

    fsspec.asyn.sync(cache.loop, cache.seek, 500)

    assert cache.history_sum == 0
    assert len(cache.read_history) == 0
    assert cache._get_adaptive_blocksize() == cache.blocksize
+ """ + cache, _ = prefetcher_setup + + with mock.patch.object( + cache, "_get_adaptive_blocksize", side_effect=Exception("Mocked Error!") + ): + cache._wakeup_producer.set() + + fsspec.asyn.sync(cache.loop, asyncio.sleep, 0.05) + with pytest.raises(Exception, match="Mocked Error!"): + cache._fetch(0, 10) + + +def test_prefetcher_producer_early_stop(prefetcher_setup): + """ + Verify that if the cache is stopped while the producer is waiting, + waking it up causes it to immediately break the loop. + """ + cache, _ = prefetcher_setup + cache.is_stopped = True + cache._wakeup_producer.set() + fsspec.asyn.sync(cache.loop, asyncio.sleep, 0.05) + assert cache._producer_task.done() is True + + +def test_prefetcher_read_after_producer_stops(prefetcher_setup): + """ + Test that reading from a Prefetcher after the producer has fatally + stopped (and the queue is drained of the original error) raises our + new RuntimeError safeguard. + """ + cache, fetcher = prefetcher_setup + fetcher.should_fail = True + with pytest.raises(RuntimeError, match="Mocked network error"): + cache._fetch(0, 10) + + assert cache.is_stopped is True + assert cache.queue.empty() is True + assert cache.user_offset < cache.size + with pytest.raises( + RuntimeError, match="Could not fetch data, the producer is stopped" + ): + cache._fetch(cache.user_offset, cache.user_offset + 10) + + +def test_prefetcher_zero_copy_full_current_block(prefetcher_setup, source_data): + """ + Test the zero-copy optimization path where the requested size matches + the entire available _current_block, and the index is 0. 
+ """ + cache, _ = prefetcher_setup + cache._current_block = source_data[0:10] + cache._current_block_idx = 0 + res = cache._fetch(0, 10) + + assert res == source_data[0:10] + assert res is cache._current_block + assert cache.user_offset == 10 + + +def test_prefetcher_break_on_empty_block(prefetcher_setup): + """ + Test that _async_fetch safely breaks its collection loop if read() + unexpectedly returns an empty bytes object. + """ + cache, _ = prefetcher_setup + task = asyncio.Future() + task.set_result(b"") + cache.queue.put_nowait(task) + res = cache._fetch(0, 10) + assert res == b"" + + +def test_prefetcher_cancelled_error_propagation(prefetcher_setup): + """ + Verify that an asyncio.CancelledError is re-raised without altering + the is_stopped flag, differentiating a deliberate cancellation from a crash. + """ + cache, _ = prefetcher_setup + + async def simulate_cancellation(): + task = cache.loop.create_future() + task.cancel() + cache.queue.put_nowait(task) + assert cache.is_stopped is False + with pytest.raises(asyncio.CancelledError): + await cache.read() + + fsspec.asyn.sync(cache.loop, simulate_cancellation) + assert cache.is_stopped is False diff --git a/gcsfs/tests/test_core.py b/gcsfs/tests/test_core.py index 285ee03f..60fd3e3c 100644 --- a/gcsfs/tests/test_core.py +++ b/gcsfs/tests/test_core.py @@ -1,3 +1,4 @@ +import asyncio import io import os import uuid @@ -8,6 +9,7 @@ from urllib.parse import parse_qs, unquote, urlparse from uuid import uuid4 +import fsspec.asyn import fsspec.core import pytest import requests @@ -18,6 +20,7 @@ import gcsfs.tests.settings from gcsfs import GCSFileSystem from gcsfs import __version__ as version +from gcsfs.caching import Prefetcher from gcsfs.credentials import GoogleCredentials from gcsfs.tests.conftest import a, allfiles, b, csv_files, files, text_files from gcsfs.tests.utils import tempdir, tmpfile @@ -1921,3 +1924,193 @@ def test_mv_file_raises_error_for_specific_generation(gcs): gcs.mv_file(src, dest) 
def test_gcsfile_prefetcher_seek(gcs):
    """Forward and backward seeks must survive the producer restart logic."""
    path = f"{TEST_BUCKET}/prefetcher_seek.txt"
    payload = b"A" * 1000 + b"B" * 1000 + b"C" * 1000
    gcs.pipe(path, payload)

    with gcs.open(path, "rb", cache_type="prefetcher", block_size=1000) as fobj:
        # Sequential warm-up read.
        assert fobj.read(500) == b"A" * 500

        # Forward seek, deliberately misaligned with the block size.
        fobj.seek(1500)
        assert fobj.read(500) == b"B" * 500

        # Backward seek: cancels in-flight prefetches and restarts.
        fobj.seek(0)
        assert fobj.read(1000) == b"A" * 1000

        # Seeking to EOF yields no further data.
        fobj.seek(3000)
        assert fobj.read(10) == b""
+ """ + fn = f"{TEST_BUCKET}/prefetcher_cleanup.txt" + data = b"X" * (1024 * 1024) + gcs.pipe(fn, data) + + f = gcs.open(fn, "rb", cache_type="prefetcher", block_size=1024 * 100) + + # Read a tiny bit to ensure the producer task is spawned and running + f.read(10) + + cache = f.cache + assert isinstance(cache, Prefetcher) + assert not cache.is_stopped + + # Close the file wrapper + f.close() + + # Verify that the internal tasks were flagged to stop and cleaned up + assert cache.is_stopped is True + assert len(cache._active_tasks) == 0 + + +@pytest.mark.asyncio +async def test_prefetcher_logical_chunk_override(gcs): + """ + Test that the custom `_fetch_logical_chunk` override in GCSFile + is correctly passed to and used by the Prefetcher cache. + """ + fn = f"{TEST_BUCKET}/prefetcher_override.txt" + data = b"Y" * 1024 + gcs.pipe(fn, data) + original_fetch = gcsfs.core.GCSFile._fetch_logical_chunk + + async def mock_fetch(self_obj, start_offset, total_size, split_factor=1): + return await original_fetch( + self_obj, start_offset, total_size, split_factor=split_factor + ) + + with mock.patch.object( + gcsfs.core.GCSFile, + "_fetch_logical_chunk", + autospec=True, + side_effect=mock_fetch, + ) as mock_fetch_obj: + + with gcs.open(fn, "rb", cache_type="prefetcher", block_size=256) as f: + assert f.cache.fetcher.__name__ == "_fetch_logical_chunk" + f.read(512) + + assert mock_fetch_obj.call_count >= 1 + + +def test_fetch_logical_chunk_single(gcs): + """Test standard single-chunk fetch when split_factor is 1.""" + fn = f"{TEST_BUCKET}/fetch_single.txt" + data = b"0123456789" * 10 + gcs.pipe(fn, data) + + with gcs.open(fn, "rb") as f: + # Run the async method in the gcs fsspec loop + res = fsspec.asyn.sync(gcs.loop, f._fetch_logical_chunk, 0, 100, split_factor=1) + assert res == data + + +def test_fetch_logical_chunk_multi(gcs): + """Test concurrent multi-chunk fetch when split_factor > 1.""" + fn = f"{TEST_BUCKET}/fetch_multi.txt" + data = b"A" * 500 + b"B" * 500 + 
def test_fetch_logical_chunk_exception(gcs):
    """A failure in any concurrent part must propagate to the caller."""
    path = f"{TEST_BUCKET}/fetch_exception.txt"
    payload = b"0123456789" * 10
    gcs.pipe(path, payload)

    with gcs.open(path, "rb") as fobj:
        real_cat = fobj.gcsfs._cat_file

        async def failing_cat(p, start, end, **kwargs):
            # Fail only for the non-initial part of the split download.
            if start > 0:
                raise RuntimeError("Simulated download failure")
            return await real_cat(p, start=start, end=end, **kwargs)

        with mock.patch.object(fobj.gcsfs, "_cat_file", side_effect=failing_cat):
            with pytest.raises(RuntimeError, match="Simulated download failure"):
                fsspec.asyn.sync(
                    gcs.loop, fobj._fetch_logical_chunk, 0, 100, split_factor=2
                )
@pytest.mark.asyncio
async def test_mrd_pool_close():
    # Pool with a single downloader: closing the pool must close the MRD and
    # forget it.
    gcsfs_mock = mock.Mock()
    gcsfs_mock._get_grpc_client = mock.AsyncMock()

    mrd_instance_mock = mock.AsyncMock()

    with mock.patch(
        "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd",
        return_value=mrd_instance_mock,
    ):
        pool = MRDPool(gcsfs_mock, "bucket", "obj", "123", pool_size=1)
        await pool.initialize()

        await pool.close()
        mrd_instance_mock.close.assert_awaited_once()
        assert len(pool._all_mrds) == 0


@pytest.fixture
def mock_gcsfs():
    # Minimal gcsfs stand-in: only the awaited gRPC client accessor is stubbed.
    gcsfs_mock = mock.Mock()
    gcsfs_mock._get_grpc_client = mock.AsyncMock()
    return gcsfs_mock


@pytest.mark.asyncio
@mock.patch(
    "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd",
    new_callable=mock.AsyncMock,
)
async def test_mrd_pool_scaling(create_mrd_mock, mock_gcsfs):
    mrd_instance_mock = mock.AsyncMock()
    mrd_instance_mock.persisted_size = 1024
    create_mrd_mock.return_value = mrd_instance_mock

    pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2)

    await pool.initialize()
    assert pool.persisted_size == 1024
    assert pool._active_count == 1
    create_mrd_mock.assert_awaited_once()

    async with pool.get_mrd() as mrd1:
        assert mrd1 == mrd_instance_mock

        # Since mrd1 is in use, getting another one should spawn a new MRD
        async with pool.get_mrd() as _:
            assert pool._active_count == 2
            assert create_mrd_mock.call_count == 2

    # Both should have been returned to the free queue
    assert pool._free_mrds.qsize() == 2


@pytest.mark.asyncio
@mock.patch(
    "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd",
    new_callable=mock.AsyncMock,
)
async def test_mrd_pool_double_initialize(create_mrd_mock, mock_gcsfs):
    pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2)

    await pool.initialize()
    await pool.initialize()  # Second call should be a no-op

    assert pool._active_count == 1
    create_mrd_mock.assert_awaited_once()


@pytest.mark.asyncio
@mock.patch(
    "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd",
    new_callable=mock.AsyncMock,
)
async def test_mrd_pool_get_mrd_creation_error(create_mrd_mock, mock_gcsfs):
    # First creation succeeds during initialization
    valid_mrd = mock.AsyncMock()

    # Second creation fails when pool tries to scale
    create_mrd_mock.side_effect = [valid_mrd, Exception("Network Error")]

    pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2)
    await pool.initialize()

    # Consume the initialized MRD
    async def consume_and_error():
        async with pool.get_mrd() as _:
            # Try to get a second one, which forces a spawn that will fail
            with pytest.raises(Exception, match="Network Error"):
                async with pool.get_mrd() as _:
                    pass

    await consume_and_error()

    # Active count should remain 1 because the second creation failed and rolled back
    assert pool._active_count == 1


@pytest.mark.asyncio
@mock.patch(
    "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd",
    new_callable=mock.AsyncMock,
)
async def test_mrd_pool_close_with_exceptions(create_mrd_mock, mock_gcsfs):
    # A downloader whose close() raises must not break pool teardown.
    bad_mrd_instance = mock.AsyncMock()
    bad_mrd_instance.close.side_effect = Exception("Close failed")
    create_mrd_mock.return_value = bad_mrd_instance

    pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=1)
    await pool.initialize()

    # Should not raise an exception, even though the internal close() fails
    await pool.close()

    bad_mrd_instance.close.assert_awaited_once()
    assert len(pool._all_mrds) == 0


@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove")
def test_direct_memmove_buffer_error_handling(mock_memmove):
    size = 20
    buffer_array = (ctypes.c_char * size)()
    start_address = ctypes.addressof(buffer_array)
    end_address = start_address + size

    # Simulate an access violation or similar error during memory copy
    mock_memmove.side_effect = MemoryError("Segfault simulated")

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

    # First write triggers the background error
    future = buf.write(b"bad data")

    # Wait for the background thread to actually fail
    with pytest.raises(MemoryError):
        future.result()

    # Subsequent writes should raise the stored error immediately
    with pytest.raises(MemoryError, match="Segfault simulated"):
        buf.write(b"more data")

    # Close should also raise the stored error.
    with pytest.raises(MemoryError, match="Segfault simulated"):
        buf.close()

    executor.shutdown()


def test_direct_memmove_buffer():
    data1 = b"hello"
    data2 = b"world"

    # Calculate exact size to prevent the new underflow check from failing
    size = len(data1) + len(data2)
    buffer_array = (ctypes.c_char * size)()
    start_address = ctypes.addressof(buffer_array)
    end_address = start_address + size

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
    buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

    future1 = buf.write(data1)
    future2 = buf.write(data2)

    future1.result()
    future2.result()
    buf.close()

    result_bytes = ctypes.string_at(start_address, len(data1) + len(data2))
    assert result_bytes == b"helloworld"

    executor.shutdown()


def test_direct_memmove_buffer_overflow():
    """Tests that writing past the allocated end_address raises a BufferError."""
    size = 10
    buffer_array = (ctypes.c_char * size)()
    start_address = ctypes.addressof(buffer_array)
    end_address = start_address + size

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

    # Fill the buffer exactly to capacity
    buf.write(b"1234567890")

    # Attempting to write even 1 more byte should trigger the overflow protection
    with pytest.raises(BufferError, match="Attempted to write"):
        buf.write(b"1")

    buf.close()
    executor.shutdown()


def test_direct_memmove_buffer_underflow():
    """Tests that closing an incompletely filled buffer raises a BufferError."""
    size = 10
    buffer_array = (ctypes.c_char * size)()
    start_address = ctypes.addressof(buffer_array)
    end_address = start_address + size

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

    # Write fewer bytes than the expected capacity
    buf.write(b"12345")

    # Closing should detect that current_offset (5) < expected size (10)
    with pytest.raises(BufferError, match="Buffer contains uninitialized data"):
        buf.close()

    executor.shutdown()


@pytest.fixture
def mock_gcsfs():
    # ZonalFile-oriented gcsfs stand-in: path splitting, info and loop are stubbed.
    fs = mock.Mock()
    fs._split_path.return_value = ("test-bucket", "test-key", "123")
    fs.split_path.return_value = ("test-bucket", "test-key", "123")
    fs.info.return_value = {"size": 1000, "generation": "123", "name": "test-key"}
    fs.loop = mock.Mock()
    return fs


def test_zonal_file_prefetcher_initialization(mock_gcsfs):
    """Test that setting cache_type to 'prefetcher' injects the logical chunk fetcher."""

    with (
        mock.patch("gcsfs.zonal_file.MRDPool") as mrd_pool_mock,
        mock.patch("gcsfs.zonal_file.asyn.sync"),
    ):

        mrd_pool_instance = mock.Mock()
        mrd_pool_instance.persisted_size = 1000
        mrd_pool_mock.return_value = mrd_pool_instance

        cache_options = {"concurrency": 2}

        zf = ZonalFile(
            gcsfs=mock_gcsfs,
            path="gs://test-bucket/test-key",
            mode="rb",
            cache_type="prefetcher",
            pool_size=1,
            cache_options=cache_options,
        )

        assert zf.pool_size == 1
        assert zf.cache.name == Prefetcher.name
        assert zf.cache.size == 1000

        zf.close()
+ mock.patch("gcsfs.zonal_file.MRDPool") as mrd_pool_mock, + mock.patch("gcsfs.zonal_file.asyn.sync"), + ): + + mrd_pool_instance = mock.Mock() + mrd_pool_instance.persisted_size = 1000 + mrd_pool_mock.return_value = mrd_pool_instance + + cache_options = {"concurrency": 2} + + zf = ZonalFile( + gcsfs=mock_gcsfs, + path="gs://test-bucket/test-key", + mode="rb", + cache_type="prefetcher", + pool_size=1, + cache_options=cache_options, + ) + + assert zf.pool_size == 1 + assert zf.cache.name == Prefetcher.name + assert zf.cache.size == 1000 + + zf.close() + + +@pytest.mark.asyncio +async def test_fetch_logical_chunk_split_logic(mock_gcsfs): + """Test that chunks larger than 16MB are split correctly.""" + + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer"), + ): + + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + zf.pool_size = 4 + + zf.mrd_pool = mock.Mock() + zf.gcsfs.memmove_executor = mock.Mock() + + mrd_mock = mock.AsyncMock() + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mrd_mock + + zf.mrd_pool.get_mrd = fake_get_mrd + + total_size = 32 * 1024 * 1024 + await zf._fetch_logical_chunk( + start_offset=0, total_size=total_size, split_factor=2 + ) + + # Assert the split logic directly on the downloader mock + assert mrd_mock.download_ranges.call_count == 2 + + # Sort calls by offset to ensure consistent assertions (tasks can run in any order) + calls = mrd_mock.download_ranges.call_args_list + args = [c[0][0][0] for c in calls] # extracts the (offset, size, buffer) tuple + args.sort(key=lambda x: x[0]) + + assert args[0][0] == 0 # Offset 1 + assert args[0][1] == 16 * 1024 * 1024 # Size 1 + + assert args[1][0] == 16 * 1024 * 1024 # Offset 2 + assert args[1][1] == 16 * 1024 * 1024 # Size 2 + + # Explicitly close while asyn.sync is still mocked + zf.close() + + +@pytest.mark.asyncio +async def 
test_zonal_fetch_logical_chunk_cancellation(mock_gcsfs): + """Test the BaseException block (cancellation) cleans up and cancels inner tasks.""" + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer"), + ): + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + zf.mrd_pool = mock.Mock() + zf.gcsfs.memmove_executor = mock.Mock() + + mrd_mock = mock.AsyncMock() + + # Create a side effect that hangs to simulate pending downloads + async def slow_download(*args, **kwargs): + await asyncio.sleep(10) + + mrd_mock.download_ranges = mock.AsyncMock(side_effect=slow_download) + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mrd_mock + + zf.mrd_pool.get_mrd = fake_get_mrd + + # Spawn the fetcher as a task + task = asyncio.create_task( + zf._fetch_logical_chunk(start_offset=0, total_size=100, split_factor=2) + ) + + # Yield control to let the inner tasks get created and block on sleep + await asyncio.sleep(0.1) + + # Cancel the outer task + task.cancel() + + # Ensure the CancelledError bubbles out exactly as expected + with pytest.raises(asyncio.CancelledError): + await task + + zf.close() + + +@pytest.mark.asyncio +async def test_zonal_fetch_logical_chunk_single(mock_gcsfs): + """Test successful single chunk download (split_factor=1).""" + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer") as mem_buf_mock, + mock.patch( + "gcsfs.zonal_file.PyBytes_FromStringAndSize", return_value=b"0" * 100 + ), + mock.patch("gcsfs.zonal_file.PyBytes_AsString", return_value=12345), + ): + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + zf.mrd_pool = mock.Mock() + + mrd_mock = mock.AsyncMock() + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mrd_mock + + zf.mrd_pool.get_mrd = fake_get_mrd + + res = await 
zf._fetch_logical_chunk( + start_offset=0, total_size=100, split_factor=1 + ) + + assert res == b"0" * 100 + mrd_mock.download_ranges.assert_awaited_once() + # Verify the finally block ran and closed the buffer + mem_buf_mock.return_value.close.assert_called_once() + zf.close() + + +@pytest.mark.asyncio +async def test_zonal_fetch_logical_chunk_single_exception(mock_gcsfs): + """Test exception handling and buffer cleanup in single chunk download.""" + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer") as mem_buf_mock, + mock.patch( + "gcsfs.zonal_file.PyBytes_FromStringAndSize", return_value=b"0" * 100 + ), + mock.patch("gcsfs.zonal_file.PyBytes_AsString", return_value=12345), + ): + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + zf.mrd_pool = mock.Mock() + + mrd_mock = mock.AsyncMock() + mrd_mock.download_ranges.side_effect = RuntimeError("Single chunk failure") + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mrd_mock + + zf.mrd_pool.get_mrd = fake_get_mrd + + with pytest.raises(RuntimeError, match="Single chunk failure"): + await zf._fetch_logical_chunk( + start_offset=0, total_size=100, split_factor=1 + ) + + # Verify the finally block ran despite the error + mem_buf_mock.return_value.close.assert_called_once() + zf.close() + + +@pytest.mark.asyncio +async def test_zonal_fetch_logical_chunk_multi_exception(mock_gcsfs): + """Test that standard Exceptions in concurrent downloads are caught and propagated.""" + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer") as mem_buf_mock, + mock.patch( + "gcsfs.zonal_file.PyBytes_FromStringAndSize", return_value=b"0" * 100 + ), + mock.patch("gcsfs.zonal_file.PyBytes_AsString", return_value=12345), + ): + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + 
zf.mrd_pool = mock.Mock() + + mrd_mock = mock.AsyncMock() + + call_count = 0 + + async def fake_download_ranges(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("Simulated chunk download failure") + return None + + mrd_mock.download_ranges = mock.AsyncMock(side_effect=fake_download_ranges) + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mrd_mock + + zf.mrd_pool.get_mrd = fake_get_mrd + + with pytest.raises(RuntimeError, match="Simulated chunk download failure"): + await zf._fetch_logical_chunk( + start_offset=0, total_size=100, split_factor=2 + ) + + assert mem_buf_mock.return_value.close.call_count == 2 + zf.close() + + +@pytest.mark.asyncio +async def test_zonal_fetch_logical_chunk_multi_cancellation(mock_gcsfs): + """Test the BaseException block (cancellation) cleans up and cancels inner tasks.""" + with ( + mock.patch("gcsfs.zonal_file.MRDPool"), + mock.patch("gcsfs.zonal_file.asyn.sync"), + mock.patch("gcsfs.zonal_file.DirectMemmoveBuffer") as mem_buf_mock, + mock.patch( + "gcsfs.zonal_file.PyBytes_FromStringAndSize", return_value=b"0" * 100 + ), + mock.patch("gcsfs.zonal_file.PyBytes_AsString", return_value=12345), + mock.patch("asyncio.gather", new_callable=mock.AsyncMock) as gather_mock, + ): + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + zf.mrd_pool = mock.Mock() + + @contextlib.asynccontextmanager + async def fake_get_mrd(): + yield mock.AsyncMock() + + zf.mrd_pool.get_mrd = fake_get_mrd + + gather_mock.side_effect = [asyncio.CancelledError("Cancelled by test"), []] + + with pytest.raises(asyncio.CancelledError): + await zf._fetch_logical_chunk( + start_offset=0, total_size=100, split_factor=2 + ) + assert gather_mock.call_count == 2 + assert mem_buf_mock.return_value.close.call_count == 2 + + zf.close() diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py index 24145f50..1cf912a6 100644 --- a/gcsfs/zb_hns_utils.py +++ 
PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize
PyBytes_FromStringAndSize.argtypes = (ctypes.c_void_p, ctypes.c_ssize_t)
PyBytes_FromStringAndSize.restype = ctypes.py_object

PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.argtypes = (ctypes.py_object,)
PyBytes_AsString.restype = ctypes.c_void_p

# Sentinel meaning "derive the value automatically" (e.g. MRD pool sizing).
_AUTO = "auto"


class DirectMemmoveBuffer:
    """
    A buffer-like object that writes data directly to memory asynchronously.

    `write` queues `ctypes.memmove` operations onto a thread pool executor,
    bounding the number of in-flight copies with a semaphore. Useful for
    high-performance transfers where memory copies need to be offloaded from
    the main thread.
    """

    def __init__(self, start_address, end_address, executor, max_pending=5):
        """
        Initializes the DirectMemmoveBuffer.

        Args:
            start_address (int): The starting memory address where data will be written.
            end_address (int): The absolute ending memory address. Writes exceeding
                this boundary are rejected to prevent overflows.
            executor (concurrent.futures.Executor): The thread pool executor that
                runs the memmove operations.
            max_pending (int, optional): Maximum number of pending write operations
                allowed in the queue. Defaults to 5.
        """
        self.start_address = start_address
        self.end_address = end_address
        # Next write position, relative to start_address.
        self.current_offset = 0
        self.semaphore = threading.Semaphore(max_pending)
        # First exception raised by a background memmove, if any.
        self._error = None
        self._pending_count = 0
        self._lock = threading.Lock()
        # Set whenever no copies are pending; close() waits on it.
        self._done_event = threading.Event()
        self._done_event.set()
        self.executor = executor

    def write(self, data):
        """
        Schedules a write operation to memory.

        Computes the destination from the current offset, advances the offset,
        and submits the memmove to the executor. Blocks while `max_pending`
        operations are already queued.

        Args:
            data (bytes): The data to be written to memory.

        Returns:
            concurrent.futures.Future: Future for the scheduled memory move.

        Raises:
            BufferError: If the write would run past `end_address`.
            Exception: If any previous asynchronous write encountered an error.
        """
        size = len(data)
        with self._lock:
            # Check under the lock so a failure published by a background
            # copy thread is observed reliably before reserving more space.
            if self._error:
                raise self._error
            dest = self.start_address + self.current_offset
            if dest + size > self.end_address:
                error_msg = (
                    f"Attempted to write {size} bytes "
                    f"at offset {self.current_offset}. "
                    f"Max capacity is {self.end_address - self.start_address} bytes."
                )
                raise BufferError(error_msg)

            self.current_offset += size

        data_bytes = data if isinstance(data, bytes) else bytes(data)

        self.semaphore.acquire()
        with self._lock:
            if self._pending_count == 0:
                self._done_event.clear()
            self._pending_count += 1
        try:
            return self.executor.submit(self._do_memmove, dest, data_bytes, size)
        except Exception:
            # Submission failed (e.g. executor shut down): roll back the
            # bookkeeping so close() cannot deadlock waiting on _done_event.
            self._finish_one()
            raise

    def _finish_one(self):
        # Mark one queued copy as finished and wake close() when fully idle.
        self.semaphore.release()
        with self._lock:
            self._pending_count -= 1
            if self._pending_count == 0:
                self._done_event.set()

    def _do_memmove(self, dest, data_bytes, size):
        # Runs on the executor; records the first failure so the writer
        # thread sees it on the next write()/close().
        try:
            ctypes.memmove(dest, data_bytes, size)
        except Exception as e:
            self._error = e
            raise e
        finally:
            self._finish_one()

    def close(self):
        """
        Waits for all pending write operations to complete and checks for errors.

        Blocks the calling thread until the queue of memory operations is
        entirely processed, then validates the buffer was completely filled.

        Raises:
            Exception: If any background write operation failed during execution.
            BufferError: If fewer bytes were written than the buffer's capacity
                (the remainder would be uninitialized memory).
        """
        self._done_event.wait()
        if self._error:
            raise self._error

        expected_size = self.end_address - self.start_address
        if self.current_offset < expected_size:
            error_msg = (
                f"Expected {expected_size} bytes, "
                f"but only received {self.current_offset} bytes. "
                f"Buffer contains uninitialized data."
            )
            raise BufferError(error_msg)


class MRDPool:
    """Manages a pool of AsyncMultiRangeDownloader objects with on-demand scaling."""

    def __init__(self, gcsfs, bucket_name, object_name, generation, pool_size):
        """
        Initializes the MRDPool.

        Args:
            gcsfs (gcsfs.GCSFileSystem): The GCS filesystem client used for downloads.
            bucket_name (str): The name of the GCS bucket.
            object_name (str): The target object/blob name in the bucket.
            generation (int or str): The specific generation of the object to download.
            pool_size (int): The maximum number of concurrent downloaders in the pool.
        """
        self.gcsfs = gcsfs
        self.bucket_name = bucket_name
        self.object_name = object_name
        self.generation = generation
        self.pool_size = pool_size

        self._free_mrds = asyncio.Queue(maxsize=pool_size)
        self._active_count = 0
        self._creation_lock = asyncio.Lock()
        # Size reported by the first downloader; None until initialize().
        self.persisted_size = None
        self._initialized = False
        # Every downloader ever created, for close().
        self._all_mrds = []

    async def _create_mrd(self):
        # Ensure the gRPC client exists before building a downloader.
        await self.gcsfs._get_grpc_client()
        mrd = await init_mrd(
            self.gcsfs.grpc_client, self.bucket_name, self.object_name, self.generation
        )
        self._all_mrds.append(mrd)
        return mrd

    async def initialize(self):
        """Creates the first downloader and records the object size. Idempotent."""
        async with self._creation_lock:
            if not self._initialized:
                mrd = await self._create_mrd()
                self.persisted_size = mrd.persisted_size
                self._free_mrds.put_nowait(mrd)
                self._active_count += 1
                self._initialized = True

    @contextlib.asynccontextmanager
    async def get_mrd(self):
        """
        Dynamically provisions MRDs using an async context manager.

        If a downloader is available in the pool, it is yielded immediately. If the
        pool is empty but hasn't reached `pool_size`, a new downloader is spawned
        on demand. Automatically returns the downloader to the free queue upon exit.

        Yields:
            AsyncMultiRangeDownloader: An active downloader ready for requests.

        Raises:
            Exception: Bubbles up any exceptions encountered during MRD creation
                (the reserved pool slot is released on failure).
        """
        spawn_new = False

        if self._free_mrds.empty():
            async with self._creation_lock:
                if self._active_count < self.pool_size:
                    self._active_count += 1
                    spawn_new = True

        if spawn_new:
            try:
                mrd = await self._create_mrd()
            except Exception:
                # Roll back the reserved slot so the pool can retry later.
                self._active_count -= 1
                raise
        else:
            mrd = await self._free_mrds.get()

        try:
            yield mrd
        finally:
            self._free_mrds.put_nowait(mrd)

    async def close(self):
        """
        Cleanly shut down all MRDs.

        Closes every instantiated downloader concurrently, suppressing
        individual close errors, and bounds the teardown with a 2-second
        timeout so shutdown cannot hang indefinitely.
        """
        tasks = [mrd.close() for mrd in self._all_mrds]
        try:
            if tasks:
                # Individual failures are swallowed (return_exceptions) and a
                # stuck close cannot block teardown past the timeout.
                with contextlib.suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(
                        asyncio.gather(*tasks, return_exceptions=True), timeout=2
                    )
        finally:
            self._all_mrds.clear()
logger.warning( "AsyncMultiRangeDownloader (MRD) exists but has no 'persisted_size'. " "This may result in incorrect behavior for unfinalized objects." ) + + # These caches support overriding the default fetcher method. + _FETCHER_OVERRIDE = {Prefetcher.name: self._fetch_logical_chunk} + if cache_type in _FETCHER_OVERRIDE: + cache_options["fetcher_override"] = _FETCHER_OVERRIDE[cache_type] + elif "w" in self.mode or "a" in self.mode: pass else: @@ -162,23 +187,100 @@ def _fetch_range(self, start=None, end=None, chunk_lengths=None): "The end and chunk_lengths arguments are mutually exclusive and cannot be used together." ) - try: - if chunk_lengths is not None: - return asyn.sync( - self.fs.loop, - self.gcsfs._fetch_range_split, - self.path, - start=start, - chunk_lengths=chunk_lengths, - size=self.size, - mrd=self.mrd, + async def _do_fetch(): + async with self.mrd_pool.get_mrd() as mrd: + if chunk_lengths is not None: + return await self.gcsfs._fetch_range_split( + self.path, + start=start, + chunk_lengths=chunk_lengths, + size=self.size, + mrd=mrd, + ) + return await self.gcsfs._cat_file( + self.path, start=start, end=end, mrd=mrd ) - return self.gcsfs.cat_file(self.path, start=start, end=end, mrd=self.mrd) + + try: + return asyn.sync(self.fs.loop, _do_fetch) except RuntimeError as e: if "not satisfiable" in str(e): return b"" if chunk_lengths is None else [b""] raise + async def _fetch_logical_chunk(self, start_offset, total_size, split_factor=1): + """ + A custom asynchronous fetcher designed to be mapped to the Prefetcher cache. + + Pre-allocates an uninitialized Python bytes object at the C-level and pulls + data directly into that memory space using `DirectMemmoveBuffer`. + + Args: + start_offset (int): The starting byte offset in the remote file. + total_size (int): The total number of bytes to fetch. + split_factor (int): The number of parts in slicing process. + + Returns: + bytes: The fully downloaded logical chunk. 
+ + Raises: + Exception: Any exception encountered during concurrent chunk downloads. + """ + result_bytes = PyBytes_FromStringAndSize(None, total_size) + buffer_ptr = PyBytes_AsString(result_bytes) + loop = asyncio.get_running_loop() + + if split_factor == 1: + buf = DirectMemmoveBuffer( + buffer_ptr, buffer_ptr + total_size, self.gcsfs.memmove_executor + ) + try: + async with self.mrd_pool.get_mrd() as mrd: + await mrd.download_ranges([(start_offset, total_size, buf)]) + finally: + await loop.run_in_executor(None, buf.close) + else: + part_size = total_size // split_factor + tasks = [] + buffers = [] + + async def _download(o, s, b): + async with self.mrd_pool.get_mrd() as mrd: + await mrd.download_ranges([(o, s, b)]) + + for i in range(split_factor): + offset = start_offset + (i * part_size) + actual_size = ( + part_size if i < split_factor - 1 else total_size - (i * part_size) + ) + + part_address = buffer_ptr + (offset - start_offset) + buf = DirectMemmoveBuffer( + part_address, + part_address + actual_size, + self.gcsfs.memmove_executor, + ) + buffers.append(buf) + task = asyncio.create_task(_download(offset, actual_size, buf)) + tasks.append(task) + + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + raise res + except BaseException: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + finally: + for buf in buffers: + await loop.run_in_executor(None, buf.close) + + return result_bytes + def write(self, data): """ Writes data using AsyncAppendableObjectWriter. @@ -308,16 +410,21 @@ def close(self): """ if self.closed: return - # super is closed before aaow since flush may need aaow - super().close() - # Helper method safely handles mrd=None. 
- asyn.sync(self.gcsfs.loop, zb_hns_utils.close_mrd, self.mrd) - - # Only close aaow if the stream is open - if self.aaow and self.aaow._is_stream_open: - asyn.sync( - self.gcsfs.loop, - zb_hns_utils.close_aaow, - self.aaow, - finalize_on_close=self.finalize_on_close, - ) + if hasattr(self, "cache") and self.cache and hasattr(self.cache, "close"): + self.cache.close() + + try: + # super is closed before aaow since flush may need aaow + super().close() + finally: + if hasattr(self, "mrd_pool") and self.mrd_pool: + asyn.sync(self.gcsfs.loop, self.mrd_pool.close) + + # Only close aaow if the stream is open + if self.aaow and self.aaow._is_stream_open: + asyn.sync( + self.gcsfs.loop, + zb_hns_utils.close_aaow, + self.aaow, + finalize_on_close=self.finalize_on_close, + )