From 65cd4ef9a6d21a4aed48b03aaa3c56330f56139d Mon Sep 17 00:00:00 2001 From: 0xSwego <0xSwego@gmail.com> Date: Mon, 29 Jun 2026 14:54:23 +0100 Subject: [PATCH 1/2] Upgrade ShardedZarrStore for pyramid zarrs --- py_hamt/hamt_to_sharded_converter.py | 53 +- py_hamt/sharded_zarr_store.py | 1466 ++++++++++++++++++++------ tests/test_sharded_store_grafting.py | 2 +- tests/test_sharded_zarr_store.py | 7 +- tests/test_sharded_zarr_store_v2.py | 910 ++++++++++++++++ 5 files changed, 2089 insertions(+), 349 deletions(-) create mode 100644 tests/test_sharded_zarr_store_v2.py diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index b0e8921..2f8b050 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -2,13 +2,17 @@ import asyncio import time -import xarray as xr from multiformats import CID from .hamt import HAMT -from .sharded_zarr_store import ShardedZarrStore +from .sharded_zarr_store import SHARDED_ZARR_V2, ShardedZarrStore from .store_httpx import KuboCAS -from .zarr_hamt_store import ZarrHAMTStore + + +def _is_zarr_chunk_key(key: str) -> bool: + if key.endswith(("zarr.json", ".zarray", ".zattrs", ".zgroup")): + return False + return key.startswith("c/") or "/c/" in key async def convert_hamt_to_sharded( @@ -32,43 +36,44 @@ async def convert_hamt_to_sharded( hamt_ro = await HAMT.build( cas=cas, root_node_id=hamt_root_cid, values_are_bytes=True, read_only=True ) - source_store = ZarrHAMTStore(hamt_ro, read_only=True) - source_dataset = xr.open_zarr(store=source_store, consolidated=True) - # 2. Introspect the source array to get its configuration - print("Reading metadata from source store...") - - # Read the stores metadata to get array shape and chunk shape - data_var_name = next(iter(source_dataset.data_vars)) - ordered_dims = list(source_dataset[data_var_name].dims) - array_shape_tuple = tuple(source_dataset.sizes[dim] for dim in ordered_dims) - chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) - array_shape = array_shape_tuple - chunk_shape = chunk_shape_tuple - - # 3. Create the destination ShardedZarrStore for writing + + # 2. Create the destination ShardedZarrStore for writing. print( - f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard..." + f"Initializing new ShardedZarrStore v2 with {chunks_per_shard} chunks per shard..." ) dest_store = await ShardedZarrStore.open( cas=cas, read_only=False, - array_shape=array_shape, - chunk_shape=chunk_shape, chunks_per_shard=chunks_per_shard, + manifest_version=SHARDED_ZARR_V2, ) print("Destination store initialized.") - # 4. Iterate and copy all data from source to destination + # 3. Copy metadata first so each chunked array path registers its own shard + # index before chunk pointers are inserted. print("Starting data migration...") count = 0 async for key in hamt_ro.keys(): + if _is_zarr_chunk_key(key): + continue count += 1 - # Read the raw data (metadata or chunk) from the source - cid: CID = await hamt_ro.get_pointer(key) + cid = await hamt_ro.get_pointer(key) + if not isinstance(cid, CID): # pragma: no cover + raise TypeError(f"Expected CID pointer for key {key!r}.") cid_base32_str = str(cid.encode("base32")) + await dest_store.set_pointer(key, cid_base32_str) + if count % 200 == 0: # pragma: no cover + print(f"Migrated {count} keys...") # pragma: no cover - # Write the exact same key-value pair to the destination. + async for key in hamt_ro.keys(): + if not _is_zarr_chunk_key(key): + continue + count += 1 + cid = await hamt_ro.get_pointer(key) + if not isinstance(cid, CID): # pragma: no cover + raise TypeError(f"Expected CID pointer for key {key!r}.") + cid_base32_str = str(cid.encode("base32")) await dest_store.set_pointer(key, cid_base32_str) if count % 200 == 0: # pragma: no cover print(f"Migrated {count} keys...") # pragma: no cover diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 4a20736..94d228d 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -6,11 +6,13 @@ import time from collections import OrderedDict, defaultdict from collections.abc import AsyncIterator, Iterable -from typing import DefaultDict, Dict, List, Optional, Set, Tuple +from dataclasses import dataclass +from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast import dag_cbor import zarr.abc.store import zarr.core.buffer +from dag_cbor.ipld import IPLDKind from multiformats.cid import CID from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest from zarr.core.common import BytesLike @@ -18,6 +20,153 @@ from . import instrumentation from .store_httpx import ContentAddressedStore +SHARDED_ZARR_V1 = "sharded_zarr_v1" +SHARDED_ZARR_V2 = "sharded_zarr_v2" +ZARR_METADATA_SUFFIXES = ("zarr.json", ".zarray", ".zattrs", ".zgroup") + +ShardCacheKey = int | tuple[str, int] + + +@dataclass(frozen=True) +class ChunkKey: + """A parsed Zarr v3 chunk key.""" + + array_path: str + coords: tuple[int, ...] + + +@dataclass +class ArrayIndex: + """Path-local shard index and chunk geometry for one Zarr array.""" + + array_path: str + array_shape: tuple[int, ...] + chunk_shape: tuple[int, ...] + chunks_per_shard: int + shard_cids: list[Optional[CID]] + order: str = "C" + + def __post_init__(self) -> None: + self.array_path = ShardedZarrStore._normalize_array_path(self.array_path) + self.array_shape = tuple(self.array_shape) + self.chunk_shape = tuple(self.chunk_shape) + self._validate_geometry() + self.chunks_per_dim = self._calculate_chunks_per_dim() + self.total_chunks = math.prod(self.chunks_per_dim) + self.num_shards = ( + (self.total_chunks + self.chunks_per_shard - 1) // self.chunks_per_shard + if self.total_chunks > 0 + else 0 + ) + if len(self.shard_cids) != self.num_shards: + raise ValueError( + f"Inconsistent number of shards. Expected {self.num_shards}, found {len(self.shard_cids)}." + ) + + def _validate_geometry(self) -> None: + if not isinstance(self.chunks_per_shard, int) or self.chunks_per_shard <= 0: + raise ValueError("chunks_per_shard must be a positive integer.") + if len(self.array_shape) != len(self.chunk_shape): + raise ValueError("array_shape and chunk_shape must have the same rank.") + if not all(cs > 0 for cs in self.chunk_shape): + raise ValueError("All chunk_shape dimensions must be positive.") + if not all(s >= 0 for s in self.array_shape): + raise ValueError("All array_shape dimensions must be non-negative.") + if self.order != "C": + raise ValueError("Only row-major ('C') shard ordering is supported.") + + def _calculate_chunks_per_dim(self) -> tuple[int, ...]: + return tuple( + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self.array_shape, self.chunk_shape, strict=True) + ) + + @classmethod + def new( + cls, + array_path: str, + array_shape: tuple[int, ...], + chunk_shape: tuple[int, ...], + chunks_per_shard: int, + *, + order: str = "C", + ) -> "ArrayIndex": + chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(array_shape, chunk_shape, strict=True) + ) + total_chunks = math.prod(chunks_per_dim) + num_shards = ( + (total_chunks + chunks_per_shard - 1) // chunks_per_shard + if total_chunks > 0 + else 0 + ) + return cls( + array_path=array_path, + array_shape=array_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + shard_cids=[None] * num_shards, + order=order, + ) + + @classmethod + def from_manifest(cls, array_path: str, manifest: dict) -> "ArrayIndex": + shard_cids = manifest.get("shard_cids") + if not isinstance(shard_cids, list): + raise ValueError("shard_cids is not a list.") + + sharding_config = manifest.get("sharding_config", {}) + if not isinstance(sharding_config, dict): + raise ValueError("sharding_config is not a dictionary.") + + return cls( + array_path=array_path, + array_shape=tuple(manifest["array_shape"]), + chunk_shape=tuple(manifest["chunk_shape"]), + chunks_per_shard=sharding_config["chunks_per_shard"], + order=sharding_config.get("order", "C"), + shard_cids=list(shard_cids), + ) + + def resize(self, new_shape: tuple[int, ...]) -> None: + if len(new_shape) != len(self.array_shape): + raise ValueError( + "New shape must have the same number of dimensions as the old shape." + ) + + old_shard_cids = self.shard_cids + self.array_shape = tuple(new_shape) + self._validate_geometry() + self.chunks_per_dim = self._calculate_chunks_per_dim() + self.total_chunks = math.prod(self.chunks_per_dim) + old_num_shards = self.num_shards + self.num_shards = ( + (self.total_chunks + self.chunks_per_shard - 1) // self.chunks_per_shard + if self.total_chunks > 0 + else 0 + ) + + if self.num_shards > old_num_shards: + self.shard_cids = old_shard_cids + [None] * ( + self.num_shards - old_num_shards + ) + elif self.num_shards < old_num_shards: + self.shard_cids = old_shard_cids[: self.num_shards] + else: + self.shard_cids = old_shard_cids + + def to_manifest(self) -> dict: + return { + "array_shape": list(self.array_shape), + "chunk_shape": list(self.chunk_shape), + "sharding_config": { + "chunks_per_shard": self.chunks_per_shard, + "order": self.order, + }, + "shard_cids": self.shard_cids, + } + class MemoryBoundedLRUCache: """ @@ -30,9 +179,9 @@ class MemoryBoundedLRUCache: def __init__(self, max_memory_bytes: int = 100 * 1024 * 1024): # 100MB default self.max_memory_bytes = max_memory_bytes - self._cache: OrderedDict[int, List[Optional[CID]]] = OrderedDict() - self._dirty_shards: Set[int] = set() - self._shard_sizes: Dict[int, int] = {} # Cached sizes for each shard + self._cache: OrderedDict[ShardCacheKey, List[Optional[CID]]] = OrderedDict() + self._dirty_shards: Set[ShardCacheKey] = set() + self._shard_sizes: Dict[ShardCacheKey, int] = {} self._actual_memory_usage = 0 self._cache_lock = asyncio.Lock() @@ -45,7 +194,7 @@ def _get_shard_size(self, shard_data: List[Optional[CID]]) -> int: total += sys.getsizeof(item) return total - async def get(self, shard_idx: int) -> Optional[List[Optional[CID]]]: + async def get(self, shard_idx: ShardCacheKey) -> Optional[List[Optional[CID]]]: """Get a shard from cache, moving it to end (most recently used).""" async with self._cache_lock: if shard_idx not in self._cache: @@ -55,27 +204,26 @@ async def get(self, shard_idx: int) -> Optional[List[Optional[CID]]]: return shard_data async def put( - self, shard_idx: int, shard_data: List[Optional[CID]], is_dirty: bool = False + self, + shard_idx: ShardCacheKey, + shard_data: List[Optional[CID]], + is_dirty: bool = False, ) -> None: """Add or update a shard in cache, evicting old items if needed.""" async with self._cache_lock: shard_size = self._get_shard_size(shard_data) - # If shard exists, remove its old size if shard_idx in self._cache: self._cache.pop(shard_idx) self._actual_memory_usage -= self._shard_sizes.pop(shard_idx, 0) - # Track dirty status if is_dirty: self._dirty_shards.add(shard_idx) - # Add new shard self._cache[shard_idx] = shard_data self._shard_sizes[shard_idx] = shard_size self._actual_memory_usage += shard_size - # Evict old items if over memory limit, never evict dirty shards while ( self._actual_memory_usage > self.max_memory_bytes and len(self._cache) > 1 @@ -85,30 +233,26 @@ async def put( while self._cache: candidate_idx, candidate_data = self._cache.popitem(last=False) if candidate_idx not in self._dirty_shards: - # Evict this clean LRU self._actual_memory_usage -= self._shard_sizes.pop( candidate_idx, 0 ) evicted = True break - else: - # Dirty: move to MRU - self._cache[candidate_idx] = candidate_data - checked_dirty_shards.add(candidate_idx) - # If we've checked all dirty shards, no clean shards available - if len(checked_dirty_shards) == len(self._dirty_shards): - break + + self._cache[candidate_idx] = candidate_data + checked_dirty_shards.add(candidate_idx) + if len(checked_dirty_shards) == len(self._dirty_shards): + break if not evicted: - # No clean shards to evict break - async def mark_dirty(self, shard_idx: int) -> None: + async def mark_dirty(self, shard_idx: ShardCacheKey) -> None: """Mark a shard as dirty (should not be evicted).""" async with self._cache_lock: if shard_idx in self._cache: self._dirty_shards.add(shard_idx) - async def mark_clean(self, shard_idx: int) -> None: + async def mark_clean(self, shard_idx: ShardCacheKey) -> None: """Mark a shard as clean (can be evicted).""" async with self._cache_lock: self._dirty_shards.discard(shard_idx) @@ -121,7 +265,7 @@ async def clear(self) -> None: self._shard_sizes.clear() self._actual_memory_usage = 0 - async def __contains__(self, shard_idx: int) -> bool: + async def __contains__(self, shard_idx: ShardCacheKey) -> bool: async with self._cache_lock: return shard_idx in self._cache @@ -145,17 +289,21 @@ class ShardedZarrStore(zarr.abc.store.Store): """ Implements the Zarr Store API using a sharded layout for chunk CIDs. - This store divides the flat index of chunk CIDs into multiple "shards". - Each shard is a DAG-CBOR array where each element is either a CID link - to a chunk or a null value if the chunk is empty. This structure allows - for efficient traversal by IPLD-aware systems. - - The store's root object contains: - 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. - 2. A list of CIDs, where each CID points to a shard object. - 3. Sharding configuration details (e.g., chunks_per_shard). + ``sharded_zarr_v1`` roots keep the original single global shard index for + compatibility. ``sharded_zarr_v2`` roots keep one shard index per Zarr array + path, allowing grouped arrays to reuse chunk coordinates without collisions. """ + _V1_COORDINATE_ARRAY_PREFIXES = { + "time", + "lat", + "lon", + "latitude", + "longitude", + "forecast_reference_time", + "step", + } + def __init__( self, cas: ContentAddressedStore, @@ -168,46 +316,76 @@ def __init__( super().__init__(read_only=read_only) self.cas = cas self._root_cid = root_cid - self._root_obj: dict + self._root_obj: dict = {} + self._manifest_version = SHARDED_ZARR_V1 self._resize_lock = asyncio.Lock() - # An event to signal when a resize is in-progress. - # It starts in the "set" state, allowing all operations to proceed. self._resize_complete = asyncio.Event() self._resize_complete.set() - self._shard_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock) + self._shard_locks: DefaultDict[ShardCacheKey, asyncio.Lock] = defaultdict( + asyncio.Lock + ) self._shard_data_cache = MemoryBoundedLRUCache(max_cache_memory_bytes) - self._pending_shard_loads: Dict[int, asyncio.Event] = {} + self._pending_shard_loads: Dict[ShardCacheKey, asyncio.Event] = {} self._metadata_read_cache: Dict[str, bytes] = {} - self._array_shape: Tuple[int, ...] - self._chunk_shape: Tuple[int, ...] - self._chunks_per_dim: Tuple[int, ...] - self._chunks_per_shard: int + self.array_indices: Dict[str, ArrayIndex] = {} + self._primary_array_path: Optional[str] = None + self._default_chunks_per_shard: Optional[int] = None + + self._array_shape: Tuple[int, ...] = () + self._chunk_shape: Tuple[int, ...] = () + self._chunks_per_dim: Tuple[int, ...] = () + self._chunks_per_shard: int = 0 self._num_shards: int = 0 self._total_chunks: int = 0 self._dirty_root = False - def __update_geometry(self): - """Calculates derived geometric properties from the base shapes.""" - - if not all(cs > 0 for cs in self._chunk_shape): - raise ValueError("All chunk_shape dimensions must be positive.") - if not all(s >= 0 for s in self._array_shape): - raise ValueError("All array_shape dimensions must be non-negative.") + @staticmethod + def _normalize_array_path(array_path: str) -> str: + return array_path.strip("/") + + @staticmethod + def _array_path_from_metadata_key(key: str) -> Optional[str]: + if key in {"zarr.json", ".zarray"}: + return "" + if key.endswith("/zarr.json"): + return key[: -len("/zarr.json")] + if key.endswith("/.zarray"): + return key[: -len("/.zarray")] + return None - self._chunks_per_dim = tuple( - math.ceil(a / c) if c > 0 else 0 - for a, c in zip(self._array_shape, self._chunk_shape) + @staticmethod + def _format_chunk_key(array_path: str, coords: tuple[int, ...]) -> str: + coord_path = "/".join(str(coord) for coord in coords) + if array_path: + return f"{array_path}/c/{coord_path}" + return f"c/{coord_path}" + + @staticmethod + def _coords_from_linear_index( + linear_index: int, chunks_per_dim: tuple[int, ...] + ) -> tuple[int, ...]: + coords: list[int] = [] + remaining = linear_index + for stride in reversed(chunks_per_dim): + coords.append(remaining % stride) + remaining //= stride + return tuple(reversed(coords)) + + def __update_geometry(self) -> None: + """Calculates legacy v1 geometric properties from the base shapes.""" + index = ArrayIndex.new( + array_path="", + array_shape=self._array_shape, + chunk_shape=self._chunk_shape, + chunks_per_shard=self._chunks_per_shard, ) - self._total_chunks = math.prod(self._chunks_per_dim) - - if not self._total_chunks == 0: - self._num_shards = ( - self._total_chunks + self._chunks_per_shard - 1 - ) // self._chunks_per_shard + self._chunks_per_dim = index.chunks_per_dim + self._total_chunks = index.total_chunks + self._num_shards = index.num_shards @classmethod async def open( @@ -220,9 +398,15 @@ async def open( chunk_shape: Optional[Tuple[int, ...]] = None, chunks_per_shard: Optional[int] = None, max_cache_memory_bytes: int = 100 * 1024 * 1024, # 100MB default + manifest_version: Optional[str] = None, + primary_array_path: str = "", ) -> "ShardedZarrStore": """ Asynchronously opens an existing ShardedZarrStore or initializes a new one. + + Shape-based creation remains the v1 compatibility path. To create a new + path-aware v2 store, pass ``manifest_version="sharded_zarr_v2"`` or omit + ``array_shape``/``chunk_shape`` and provide ``chunks_per_shard``. """ store = cls( cas, read_only, root_cid, max_cache_memory_bytes=max_cache_memory_bytes @@ -230,7 +414,27 @@ async def open( if root_cid: await store._load_root_from_cid() elif not read_only: - if array_shape is None or chunk_shape is None: + if manifest_version not in {None, SHARDED_ZARR_V1, SHARDED_ZARR_V2}: + raise ValueError(f"Incompatible manifest version: {manifest_version}.") + + if ( + manifest_version in {None, SHARDED_ZARR_V1} + and array_shape is None + and chunk_shape is None + and chunks_per_shard is None + ): + raise ValueError( + "array_shape and chunk_shape must be provided for a new store." + ) + if manifest_version in {None, SHARDED_ZARR_V1} and ( + (array_shape is None) != (chunk_shape is None) + ): + raise ValueError( + "array_shape and chunk_shape must be provided for a new store." + ) + if manifest_version == SHARDED_ZARR_V1 and ( + array_shape is None or chunk_shape is None + ): raise ValueError( "array_shape and chunk_shape must be provided for a new store." ) @@ -238,7 +442,26 @@ async def open( if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: raise ValueError("chunks_per_shard must be a positive integer.") - store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard) + use_v2 = manifest_version == SHARDED_ZARR_V2 or ( + array_shape is None and chunk_shape is None + ) + if use_v2: + if (array_shape is None) != (chunk_shape is None): + raise ValueError( + "array_shape and chunk_shape must both be provided when seeding a v2 array index." + ) + store._initialize_new_root_v2( + chunks_per_shard=chunks_per_shard, + array_shape=array_shape, + chunk_shape=chunk_shape, + primary_array_path=primary_array_path, + ) + else: + if array_shape is None or chunk_shape is None: # pragma: no cover + raise ValueError( + "array_shape and chunk_shape must be provided for a new store." + ) + store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard) else: raise ValueError("root_cid must be provided for a read-only store.") return store @@ -248,15 +471,17 @@ def _initialize_new_root( array_shape: Tuple[int, ...], chunk_shape: Tuple[int, ...], chunks_per_shard: int, - ): - self._array_shape = array_shape - self._chunk_shape = chunk_shape + ) -> None: + self._manifest_version = SHARDED_ZARR_V1 + self._array_shape = tuple(array_shape) + self._chunk_shape = tuple(chunk_shape) self._chunks_per_shard = chunks_per_shard + self._default_chunks_per_shard = chunks_per_shard self.__update_geometry() self._root_obj = { - "manifest_version": "sharded_zarr_v1", + "manifest_version": SHARDED_ZARR_V1, "metadata": {}, "chunks": { "array_shape": list(self._array_shape), @@ -267,30 +492,89 @@ def _initialize_new_root( "shard_cids": [None] * self._num_shards, }, } + self.array_indices = { + "": ArrayIndex( + array_path="", + array_shape=self._array_shape, + chunk_shape=self._chunk_shape, + chunks_per_shard=self._chunks_per_shard, + shard_cids=self._root_obj["chunks"]["shard_cids"], + ) + } + self._primary_array_path = "" self._dirty_root = True - async def _load_root_from_cid(self): + def _initialize_new_root_v2( + self, + *, + chunks_per_shard: int, + array_shape: Optional[Tuple[int, ...]] = None, + chunk_shape: Optional[Tuple[int, ...]] = None, + primary_array_path: str = "", + ) -> None: + self._manifest_version = SHARDED_ZARR_V2 + self._default_chunks_per_shard = chunks_per_shard + self._root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "store_type": "py_hamt.sharded_zarr", + "zarr_format": 3, + "sharding_config": { + "chunks_per_shard": chunks_per_shard, + "order": "C", + }, + "metadata": {}, + "arrays": {}, + } + self.array_indices = {} + self._primary_array_path = None + self._array_shape = () + self._chunk_shape = () + self._chunks_per_dim = () + self._chunks_per_shard = chunks_per_shard + self._num_shards = 0 + self._total_chunks = 0 + + if array_shape is not None and chunk_shape is not None: + self._register_or_update_array_index( + array_path=primary_array_path, + array_shape=tuple(array_shape), + chunk_shape=tuple(chunk_shape), + chunks_per_shard=chunks_per_shard, + ) + self._dirty_root = True + + async def _load_root_from_cid(self) -> None: root_bytes = await self.cas.load(self._root_cid) try: - self._root_obj = dag_cbor.decode(root_bytes) - if not isinstance(self._root_obj, dict) or "chunks" not in self._root_obj: - raise ValueError( - "Root object is not a valid dictionary with 'chunks' key." - ) - if not isinstance(self._root_obj["chunks"]["shard_cids"], list): - raise ValueError("shard_cids is not a list.") + decoded_root = dag_cbor.decode(root_bytes) + if not isinstance(decoded_root, dict): + raise ValueError("Root object is not a valid dictionary.") + self._root_obj = decoded_root except Exception as e: - raise ValueError(f"Failed to decode root object: {e}") + raise ValueError(f"Failed to decode root object: {e}") from e - if self._root_obj.get("manifest_version") != "sharded_zarr_v1": + manifest_version = self._root_obj.get("manifest_version") + if manifest_version == SHARDED_ZARR_V1: + self._load_v1_root() + elif manifest_version == SHARDED_ZARR_V2: + self._load_v2_root() + else: raise ValueError( - f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'." + f"Incompatible manifest version: {manifest_version!r}. Expected '{SHARDED_ZARR_V1}' or '{SHARDED_ZARR_V2}'." ) + def _load_v1_root(self) -> None: + if "chunks" not in self._root_obj: + raise ValueError("Root object is not a valid dictionary with 'chunks' key.") chunk_info = self._root_obj["chunks"] + if not isinstance(chunk_info.get("shard_cids"), list): + raise ValueError("shard_cids is not a list.") + + self._manifest_version = SHARDED_ZARR_V1 self._array_shape = tuple(chunk_info["array_shape"]) self._chunk_shape = tuple(chunk_info["chunk_shape"]) self._chunks_per_shard = chunk_info["sharding_config"]["chunks_per_shard"] + self._default_chunks_per_shard = self._chunks_per_shard self.__update_geometry() @@ -298,9 +582,283 @@ async def _load_root_from_cid(self): raise ValueError( f"Inconsistent number of shards. Expected {self._num_shards}, found {len(chunk_info['shard_cids'])}." ) + self.array_indices = { + "": ArrayIndex( + array_path="", + array_shape=self._array_shape, + chunk_shape=self._chunk_shape, + chunks_per_shard=self._chunks_per_shard, + shard_cids=chunk_info["shard_cids"], + ) + } + self._primary_array_path = "" + + def _load_v2_root(self) -> None: + metadata = self._root_obj.get("metadata") + arrays = self._root_obj.get("arrays") + if not isinstance(metadata, dict) or not isinstance(arrays, dict): + raise ValueError( + "Root object is not a valid v2 dictionary with 'metadata' and 'arrays' keys." + ) + + self._manifest_version = SHARDED_ZARR_V2 + self.array_indices = {} + self._primary_array_path = None + root_sharding_config = self._root_obj.get("sharding_config", {}) + if isinstance(root_sharding_config, dict): + self._default_chunks_per_shard = root_sharding_config.get( + "chunks_per_shard" + ) + else: + self._default_chunks_per_shard = None + + for array_path, array_manifest in arrays.items(): + if not isinstance(array_path, str) or not isinstance(array_manifest, dict): + raise ValueError("arrays must map string paths to dictionaries.") + try: + array_index = ArrayIndex.from_manifest(array_path, array_manifest) + except ValueError as exc: + if str(exc).startswith("Inconsistent number of shards"): + raise ValueError( + f"Inconsistent number of shards for array '{array_path}'. {exc}" + ) from exc + raise + self.array_indices[array_index.array_path] = array_index + if self._primary_array_path is None: + self._primary_array_path = array_index.array_path + + if self.array_indices: + primary_index = self.array_indices[self._primary_array_path or ""] + self._default_chunks_per_shard = primary_index.chunks_per_shard + self._set_legacy_geometry_from_index(primary_index) + else: + self._array_shape = () + self._chunk_shape = () + self._chunks_per_dim = () + self._chunks_per_shard = 0 + self._num_shards = 0 + self._total_chunks = 0 + + def _set_legacy_geometry_from_index(self, array_index: ArrayIndex) -> None: + self._array_shape = array_index.array_shape + self._chunk_shape = array_index.chunk_shape + self._chunks_per_dim = array_index.chunks_per_dim + self._chunks_per_shard = array_index.chunks_per_shard + self._num_shards = array_index.num_shards + self._total_chunks = array_index.total_chunks + + def _sync_arrays_to_root(self) -> None: + if self._manifest_version == SHARDED_ZARR_V2: + self._root_obj["arrays"] = { + array_path: array_index.to_manifest() + for array_path, array_index in self.array_indices.items() + } + + def _register_or_update_array_index( + self, + *, + array_path: str, + array_shape: tuple[int, ...], + chunk_shape: tuple[int, ...], + chunks_per_shard: Optional[int] = None, + ) -> ArrayIndex: + normalized_path = self._normalize_array_path(array_path) + if chunks_per_shard is None: + chunks_per_shard = self._default_chunks_per_shard + if chunks_per_shard is None: + raise RuntimeError("Store is missing a default chunks_per_shard value.") + + existing = self.array_indices.get(normalized_path) + if existing is None: + array_index = ArrayIndex.new( + array_path=normalized_path, + array_shape=array_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + self.array_indices[normalized_path] = array_index + if self._primary_array_path is None: + self._primary_array_path = normalized_path + self._set_legacy_geometry_from_index(array_index) + else: + existing.array_shape = tuple(array_shape) + existing.chunk_shape = tuple(chunk_shape) + existing._validate_geometry() + existing.resize(tuple(array_shape)) + array_index = existing + + if self._primary_array_path == normalized_path: + self._set_legacy_geometry_from_index(array_index) + self._sync_arrays_to_root() + self._dirty_root = True + return array_index + + @staticmethod + def _decode_metadata_json(raw_data: bytes) -> Optional[dict]: + try: + decoded = json.loads(raw_data.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return None + return decoded if isinstance(decoded, dict) else None + + @staticmethod + def _extract_array_metadata( + metadata_json: dict, + ) -> Optional[tuple[tuple[int, ...], tuple[int, ...]]]: + shape = metadata_json.get("shape") + if shape is None: + return None + + chunk_shape = None + chunk_grid = metadata_json.get("chunk_grid") + if isinstance(chunk_grid, dict): + configuration = chunk_grid.get("configuration") + if isinstance(configuration, dict): + chunk_shape = configuration.get("chunk_shape") + + if chunk_shape is None: + chunk_shape = metadata_json.get("chunks") + + if chunk_shape is None: + return None + + return tuple(int(dim) for dim in shape), tuple(int(dim) for dim in chunk_shape) + + def _infer_v1_migration_source_array_path(self, primary_array_path: str) -> str: + metadata = self._root_obj.get("metadata", {}) + candidates = [primary_array_path] + primary_leaf = primary_array_path.rsplit("/", 1)[-1] + if primary_leaf not in candidates: + candidates.append(primary_leaf) + candidates.append("") + + for candidate in candidates: + metadata_keys = ( + ("zarr.json", ".zarray") + if candidate == "" + else (f"{candidate}/zarr.json", f"{candidate}/.zarray") + ) + if any(key in metadata for key in metadata_keys): + return candidate + return primary_leaf + + @staticmethod + def _rewrite_v1_metadata_key_for_migration( + key: str, source_array_path: str, primary_array_path: str + ) -> str: + parent_path = ( + primary_array_path.rsplit("/", 1)[0] if "/" in primary_array_path else "" + ) + + if source_array_path: + source_prefix = f"{source_array_path}/" + if key.startswith(source_prefix): + return f"{primary_array_path}/{key[len(source_prefix) :]}" + elif key.startswith("c/"): + return f"{primary_array_path}/{key}" + elif key in {"zarr.json", ".zarray"}: + return f"{primary_array_path}/{key}" + + if parent_path and "/" in key and not key.startswith(f"{parent_path}/"): + return f"{parent_path}/{key}" + return key + + async def _add_missing_group_metadata( + self, metadata: dict[str, IPLDKind], array_path: str + ) -> None: + parts = array_path.split("/") + group_paths = ["/".join(parts[:idx]) for idx in range(len(parts))] + group_metadata = json.dumps({ + "zarr_format": 3, + "node_type": "group", + "attributes": {}, + }).encode("utf-8") + + for group_path in group_paths: + metadata_key = ( + "zarr.json" if group_path == "" else f"{group_path}/zarr.json" + ) + metadata[metadata_key] = await self.cas.save(group_metadata, codec="raw") + + async def _register_array_metadata_from_bytes( + self, key: str, raw_data: bytes + ) -> None: + array_path = self._array_path_from_metadata_key(key) + if array_path is None: + return + + metadata_json = self._decode_metadata_json(raw_data) + if metadata_json is None: + return + + array_metadata = self._extract_array_metadata(metadata_json) + if array_metadata is None and self._manifest_version == SHARDED_ZARR_V2: + return + if array_metadata is None: + shape = metadata_json.get("shape") + if shape is None: + return + new_array_shape = tuple(int(dim) for dim in shape) + new_chunk_shape = self._chunk_shape + else: + new_array_shape, new_chunk_shape = array_metadata + + if self._manifest_version == SHARDED_ZARR_V2: + self._register_or_update_array_index( + array_path=array_path, + array_shape=new_array_shape, + chunk_shape=new_chunk_shape, + ) + return + + if ( + len(new_array_shape) == len(self._array_shape) + and new_array_shape != self._array_shape + ): + async with self._resize_lock: + if ( + len(new_array_shape) == len(self._array_shape) + and new_array_shape != self._array_shape + ): + self._resize_complete.clear() + try: + await self.resize_store(new_shape=new_array_shape) + finally: + self._resize_complete.set() + + async def _ensure_v2_parent_group_metadata(self, key: str) -> None: + if self._manifest_version != SHARDED_ZARR_V2: + return + + metadata_path = self._array_path_from_metadata_key(key) + if metadata_path is None: + return + + normalized_path = self._normalize_array_path(metadata_path) + parent_paths = [""] + if normalized_path: + parts = normalized_path.split("/") + parent_paths.extend("/".join(parts[:idx]) for idx in range(1, len(parts))) + + for parent_path in parent_paths: + metadata_key = ( + "zarr.json" if parent_path == "" else f"{parent_path}/zarr.json" + ) + if metadata_key in self._root_obj["metadata"]: + continue + group_metadata = json.dumps({ + "zarr_format": 3, + "node_type": "group", + "attributes": {}, + }).encode("utf-8") + metadata_cid = await self.cas.save(group_metadata, codec="raw") + self._root_obj["metadata"][metadata_key] = metadata_cid + self._metadata_read_cache[metadata_key] = group_metadata + self._dirty_root = True async def _fetch_and_cache_full_shard( self, + cache_key: ShardCacheKey, shard_idx: int, shard_cid: str, max_retries: int = 3, @@ -308,12 +866,6 @@ async def _fetch_and_cache_full_shard( ) -> None: """ Fetch a shard from CAS and cache it, with retry logic for transient errors. - - Args: - shard_idx: The index of the shard to fetch. - shard_cid: The CID of the shard. - max_retries: Maximum number of retry attempts for transient errors. - retry_delay: Delay between retry attempts in seconds. """ for attempt in range(max_retries): try: @@ -321,80 +873,85 @@ async def _fetch_and_cache_full_shard( decoded_shard = dag_cbor.decode(shard_data_bytes) if not isinstance(decoded_shard, list): raise TypeError(f"Shard {shard_idx} did not decode to a list.") - await self._shard_data_cache.put(shard_idx, decoded_shard) - # Always set the Event to unblock waiting coroutines - if shard_idx in self._pending_shard_loads: - self._pending_shard_loads[shard_idx].set() - del self._pending_shard_loads[shard_idx] - return # Success + shard_data: List[Optional[CID]] = [] + for item in decoded_shard: + if item is not None and not isinstance(item, CID): + raise TypeError(f"Shard {shard_idx} contains a non-CID entry.") + shard_data.append(item) + await self._shard_data_cache.put(cache_key, shard_data) + if cache_key in self._pending_shard_loads: + self._pending_shard_loads[cache_key].set() + del self._pending_shard_loads[cache_key] + return except (ConnectionError, TimeoutError) as e: - # Handle transient errors (e.g., network issues) if attempt < max_retries - 1: - await asyncio.sleep( - retry_delay * (2**attempt) - ) # Exponential backoff + await asyncio.sleep(retry_delay * (2**attempt)) continue - else: - raise RuntimeError( - f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" - ) + raise RuntimeError( + f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" + ) from e - def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: - # 1. Exclude .json files immediately (metadata) - if key.endswith(".json"): + def _parse_chunk_key(self, key: str) -> Optional[ChunkKey]: + if key.endswith(ZARR_METADATA_SUFFIXES): return None - excluded_array_prefixes = { - "time", - "lat", - "lon", - "latitude", - "longitude", - "forecast_reference_time", - "step", - } chunk_marker = "/c/" - marker_idx = key.rfind(chunk_marker) # Use rfind for robustness - if marker_idx == -1: - # Key does not contain "/c/", so it's not a chunk data key - # in the expected format (e.g., could be .zattrs, .zgroup at various levels). + marker_idx = key.rfind(chunk_marker) + if marker_idx != -1: + array_path = key[:marker_idx] + coord_part = key[marker_idx + len(chunk_marker) :] + elif key.startswith("c/"): + if self._manifest_version == SHARDED_ZARR_V1: + return None + array_path = "" + coord_part = key[len("c/") :] + else: return None - # Extract the part of the key before "/c/", which might represent the array/group path - # e.g., "temp" from "temp/c/0/0/0" - # e.g., "group1/lat" from "group1/lat/c/0" - # e.g., "" if key is "c/0/0/0" (root array) - path_before_c = key[:marker_idx] + normalized_path = self._normalize_array_path(array_path) + if self._manifest_version == SHARDED_ZARR_V1: + actual_array_name = ( + normalized_path.split("/")[-1] if normalized_path else "" + ) + if actual_array_name in self._V1_COORDINATE_ARRAY_PREFIXES: + return None - # Determine the actual array name (the last component of the path before "/c/") - actual_array_name = "" - if path_before_c: - actual_array_name = path_before_c.split("/")[-1] + parts = coord_part.split("/") + coords = tuple(map(int, parts)) - # If the determined array name is in our exclusion list, return None. - if actual_array_name in excluded_array_prefixes: - return None + if self._manifest_version == SHARDED_ZARR_V1: + self._validate_chunk_coords(coords, self.array_indices[""]) + elif normalized_path in self.array_indices: + self._validate_chunk_coords(coords, self.array_indices[normalized_path]) - # The part after "/c/" contains the chunk coordinates - coord_part = key[marker_idx + len(chunk_marker) :] - parts = coord_part.split("/") + return ChunkKey(array_path=normalized_path, coords=coords) - coords = tuple(map(int, parts)) - # Validate coordinates against the chunk grid of the store's configured array - for i, c_coord in enumerate(coords): - if not (0 <= c_coord < self._chunks_per_dim[i]): + @staticmethod + def _validate_chunk_coords( + chunk_coords: tuple[int, ...], array_index: ArrayIndex + ) -> None: + if len(chunk_coords) != len(array_index.chunks_per_dim): + raise IndexError("tuple index out of range") + for i, c_coord in enumerate(chunk_coords): + if not (0 <= c_coord < array_index.chunks_per_dim[i]): raise IndexError( - f"Chunk coordinate {c_coord} at dimension {i} is out of bounds for dimension size {self._chunks_per_dim[i]}." + f"Chunk coordinate {c_coord} at dimension {i} is out of bounds for dimension size {array_index.chunks_per_dim[i]}." ) - return coords def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: + return self._get_linear_chunk_index_for_index( + tuple(chunk_coords), self.array_indices[""] + ) + + @staticmethod + def _get_linear_chunk_index_for_index( + chunk_coords: tuple[int, ...], array_index: ArrayIndex + ) -> int: linear_index = 0 multiplier = 1 - # Convert N-D chunk coordinates to a flat 1-D index (row-major order) - for i in reversed(range(len(self._chunks_per_dim))): + for i in reversed(range(len(array_index.chunks_per_dim))): linear_index += chunk_coords[i] * multiplier - multiplier *= self._chunks_per_dim[i] + multiplier *= array_index.chunks_per_dim[i] return linear_index def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: @@ -402,24 +959,82 @@ def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: index_in_shard = linear_chunk_index % self._chunks_per_shard return shard_idx, index_in_shard + @staticmethod + def _get_shard_info_for_index( + linear_chunk_index: int, array_index: ArrayIndex + ) -> Tuple[int, int]: + shard_idx = linear_chunk_index // array_index.chunks_per_shard + index_in_shard = linear_chunk_index % array_index.chunks_per_shard + return shard_idx, index_in_shard + + def _array_index_for_path(self, array_path: Optional[str]) -> ArrayIndex: + if self._manifest_version == SHARDED_ZARR_V1: + return self.array_indices[""] + + normalized_path = self._normalize_array_path(array_path or "") + try: + return self.array_indices[normalized_path] + except KeyError as exc: + raise KeyError( + f"No array index registered for chunk path '{normalized_path}'." + ) from exc + + def _cache_key(self, array_path: Optional[str], shard_idx: int) -> ShardCacheKey: + if self._manifest_version == SHARDED_ZARR_V1: + return shard_idx + return (self._normalize_array_path(array_path or ""), shard_idx) + + def _map_byte_request( + self, byte_range: Optional[zarr.abc.store.ByteRequest] + ) -> tuple[Optional[int], Optional[int], Optional[int]]: + req_offset = None + req_length = None + req_suffix = None + + if byte_range: + if isinstance(byte_range, RangeByteRequest): + req_offset = byte_range.start + if byte_range.end is not None: + if byte_range.start > byte_range.end: + raise ValueError( + f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" + ) + req_length = byte_range.end - byte_range.start + elif isinstance(byte_range, OffsetByteRequest): + req_offset = byte_range.offset + elif isinstance(byte_range, SuffixByteRequest): + req_suffix = byte_range.suffix + return req_offset, req_length, req_suffix + + async def _get_legacy_metadata_chunk( + self, + key: str, + prototype: zarr.core.buffer.BufferPrototype, + byte_range: Optional[zarr.abc.store.ByteRequest], + ) -> Optional[zarr.core.buffer.Buffer]: + metadata_cid_obj = self._root_obj["metadata"].get(key) + if metadata_cid_obj is None: + return None + req_offset, req_length, req_suffix = self._map_byte_request(byte_range) + data = await self.cas.load( + str(metadata_cid_obj), + offset=req_offset, + length=req_length, + suffix=req_suffix, + ) + return prototype.buffer.from_bytes(data) + async def _load_or_initialize_shard_cache( - self, shard_idx: int + self, shard_idx: int, array_path: Optional[str] = None ) -> List[Optional[CID]]: """ Load a shard into the cache or initialize an empty shard if it doesn't exist. - - Args: - shard_idx: The index of the shard to load or initialize. - - Returns: - List[Optional[CID]]: The shard data (list of CIDs or None). - - Raises: - ValueError: If the shard index is out of bounds. - RuntimeError: If the shard cannot be loaded or initialized. """ started_at = time.perf_counter() - cached_shard = await self._shard_data_cache.get(shard_idx) + array_index = self._array_index_for_path(array_path) + cache_key = self._cache_key(array_index.array_path, shard_idx) + + cached_shard = await self._shard_data_cache.get(cache_key) if cached_shard is not None: instrumentation.record_shard_load( shard_idx=shard_idx, @@ -429,39 +1044,43 @@ async def _load_or_initialize_shard_cache( ) return cached_shard - if shard_idx in self._pending_shard_loads: + if cache_key in self._pending_shard_loads: try: - # Wait for the pending load with a timeout (e.g., 60 seconds) await asyncio.wait_for( - self._pending_shard_loads[shard_idx].wait(), timeout=60.0 + self._pending_shard_loads[cache_key].wait(), timeout=60.0 ) - cached_shard = await self._shard_data_cache.get(shard_idx) + cached_shard = await self._shard_data_cache.get(cache_key) if cached_shard is not None: return cached_shard - else: - raise RuntimeError( - f"Shard {shard_idx} not found in cache after pending load completed." - ) + raise RuntimeError( + f"Shard {shard_idx} not found in cache after pending load completed." + ) except asyncio.TimeoutError: - # Clean up the pending load to allow retry - if shard_idx in self._pending_shard_loads: - self._pending_shard_loads[shard_idx].set() - del self._pending_shard_loads[shard_idx] + if cache_key in self._pending_shard_loads: + self._pending_shard_loads[cache_key].set() + del self._pending_shard_loads[cache_key] raise RuntimeError(f"Timeout waiting for shard {shard_idx} to load.") - if not (0 <= shard_idx < self._num_shards): + if not (0 <= shard_idx < array_index.num_shards): raise ValueError(f"Shard index {shard_idx} out of bounds.") - shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] + shard_cid_obj = array_index.shard_cids[shard_idx] if shard_cid_obj: - self._pending_shard_loads[shard_idx] = asyncio.Event() + self._pending_shard_loads[cache_key] = asyncio.Event() shard_cid_str = str(shard_cid_obj) - await self._fetch_and_cache_full_shard(shard_idx, shard_cid_str) + try: + await self._fetch_and_cache_full_shard( + cache_key, shard_idx, shard_cid_str + ) + finally: + pending_load = self._pending_shard_loads.pop(cache_key, None) + if pending_load is not None: + pending_load.set() else: - empty_shard = [None] * self._chunks_per_shard - await self._shard_data_cache.put(shard_idx, empty_shard) + empty_shard: List[Optional[CID]] = [None] * array_index.chunks_per_shard + await self._shard_data_cache.put(cache_key, empty_shard) - result = await self._shard_data_cache.get(shard_idx) + result = await self._shard_data_cache.get(cache_key) if result is None: raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") instrumentation.record_shard_load( @@ -470,7 +1089,7 @@ async def _load_or_initialize_shard_cache( seconds=time.perf_counter() - started_at, entries=len(result), ) - return result # type: ignore[return-value] + return result async def set_partial_values( self, key_start_values: Iterable[Tuple[str, int, BytesLike]] @@ -490,24 +1109,18 @@ async def get_partial_values( def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": """ - Return this store (if the flag already matches) or a *shallow* - clone that presents the requested read‑only status. - - The clone **shares** the same CAS instance and internal state; - no flushing, network traffic or async work is done. + Return this store (if the flag already matches) or a shallow clone with + the requested read-only status. """ - # Fast path if read_only == self.read_only: - return self # Same mode, return same instance + return self - # Create new instance with different read_only flag - # Creates a *bare* instance without running its __init__ clone = type(self).__new__(type(self)) - # Copy all attributes from the current instance clone.cas = self.cas clone._root_cid = self._root_cid clone._root_obj = self._root_obj + clone._manifest_version = self._manifest_version clone._resize_lock = self._resize_lock clone._resize_complete = self._resize_complete @@ -517,6 +1130,10 @@ def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": clone._pending_shard_loads = self._pending_shard_loads clone._metadata_read_cache = self._metadata_read_cache + clone.array_indices = self.array_indices + clone._primary_array_path = self._primary_array_path + clone._default_chunks_per_shard = self._default_chunks_per_shard + clone._array_shape = self._array_shape clone._chunk_shape = self._chunk_shape clone._chunks_per_dim = self._chunks_per_dim @@ -526,60 +1143,69 @@ def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": clone._dirty_root = self._dirty_root - # Re‑initialise the zarr base class so that Zarr sees the flag zarr.abc.store.Store.__init__(clone, read_only=read_only) return clone def __eq__(self, other: object) -> bool: if not isinstance(other, ShardedZarrStore): return False - # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. return self._root_cid == other._root_cid - # If nothing to flush, return the root CID. async def flush(self) -> str: async with self._shard_data_cache._cache_lock: dirty_shards = list(self._shard_data_cache._dirty_shards) if dirty_shards: - for shard_idx in sorted(dirty_shards): - # Get the list of CIDs/Nones from the cache - shard_data_list = await self._shard_data_cache.get(shard_idx) + for cache_key in sorted(dirty_shards, key=str): + shard_data_list = await self._shard_data_cache.get(cache_key) if shard_data_list is None: - raise RuntimeError(f"Dirty shard {shard_idx} not found in cache") + raise RuntimeError(f"Dirty shard {cache_key} not found in cache") - # Encode this list into a DAG-CBOR byte representation - shard_data_bytes = dag_cbor.encode(shard_data_list) - - # Save the DAG-CBOR block and get its CID + shard_data_bytes = dag_cbor.encode(cast(IPLDKind, shard_data_list)) new_shard_cid_obj = await self.cas.save( shard_data_bytes, - codec="dag-cbor", # Use 'dag-cbor' codec + codec="dag-cbor", ) - - if ( - self._root_obj["chunks"]["shard_cids"][shard_idx] - != new_shard_cid_obj - ): - # Store the CID object directly - self._root_obj["chunks"]["shard_cids"][shard_idx] = ( - new_shard_cid_obj + if not isinstance(new_shard_cid_obj, CID): # pragma: no cover + raise TypeError( + "ShardedZarrStore requires CAS.save to return CIDs." ) - self._dirty_root = True - # Mark shard as clean after flushing - await self._shard_data_cache.mark_clean(shard_idx) + + if self._manifest_version == SHARDED_ZARR_V1: + if not isinstance(cache_key, int): # pragma: no cover + raise TypeError("v1 shard cache keys must be integers.") + shard_idx = int(cache_key) + if ( + self._root_obj["chunks"]["shard_cids"][shard_idx] + != new_shard_cid_obj + ): + self._root_obj["chunks"]["shard_cids"][shard_idx] = ( + new_shard_cid_obj + ) + self.array_indices[""].shard_cids[shard_idx] = new_shard_cid_obj + self._dirty_root = True + else: + if isinstance(cache_key, int): # pragma: no cover + raise TypeError("v2 shard cache keys must include array paths.") + array_path, shard_idx = cache_key + array_index = self.array_indices[array_path] + if array_index.shard_cids[shard_idx] != new_shard_cid_obj: + array_index.shard_cids[shard_idx] = new_shard_cid_obj + self._dirty_root = True + self._sync_arrays_to_root() + + await self._shard_data_cache.mark_clean(cache_key) if self._dirty_root: - # Ensure all metadata CIDs are CID objects for correct encoding self._root_obj["metadata"] = { k: (CID.decode(v) if isinstance(v, str) else v) for k, v in self._root_obj["metadata"].items() } + self._sync_arrays_to_root() root_obj_bytes = dag_cbor.encode(self._root_obj) new_root_cid = await self.cas.save(root_obj_bytes, codec="dag-cbor") self._root_cid = str(new_root_cid) self._dirty_root = False - # Ignore because root_cid will always exist after initialization or flush. return self._root_cid # type: ignore[return-value] async def get( @@ -599,10 +1225,9 @@ async def get( hit = False kind = "metadata" shard_idx_for_trace: int | None = None - chunk_coords = self._parse_chunk_key(key) + parsed_chunk = self._parse_chunk_key(key) try: - # Metadata request - if chunk_coords is None: + if parsed_chunk is None: metadata_cid_obj = self._root_obj["metadata"].get(key) if metadata_cid_obj is None: return None @@ -616,44 +1241,40 @@ async def get( self._metadata_read_cache[key] = data hit = True return prototype.buffer.from_bytes(data) - # Chunk data request + kind = "chunk" - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + try: + array_index = self._array_index_for_path(parsed_chunk.array_path) + except KeyError: + return await self._get_legacy_metadata_chunk( + key, prototype, byte_range + ) + linear_chunk_index = self._get_linear_chunk_index_for_index( + parsed_chunk.coords, array_index + ) + shard_idx, index_in_shard = self._get_shard_info_for_index( + linear_chunk_index, array_index + ) shard_idx_for_trace = shard_idx - # This will load the full shard into cache if it's not already there. - shard_lock = self._shard_locks[shard_idx] + cache_key = self._cache_key(array_index.array_path, shard_idx) + shard_lock = self._shard_locks[cache_key] async with shard_lock: target_shard_list = await self._load_or_initialize_shard_cache( - shard_idx + shard_idx, array_index.array_path ) - # Get the CID object (or None) from the cached list. chunk_cid_obj = target_shard_list[index_in_shard] - if chunk_cid_obj is None: - return None # Chunk is empty/doesn't exist. + legacy_buffer = await self._get_legacy_metadata_chunk( + key, prototype, byte_range + ) + hit = legacy_buffer is not None + return legacy_buffer chunk_cid_str = str(chunk_cid_obj) - req_offset = None - req_length = None - req_suffix = None - - if byte_range: - if isinstance(byte_range, RangeByteRequest): - req_offset = byte_range.start - if byte_range.end is not None: - if byte_range.start > byte_range.end: - raise ValueError( - f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" - ) - req_length = byte_range.end - byte_range.start - elif isinstance(byte_range, OffsetByteRequest): - req_offset = byte_range.offset - elif isinstance(byte_range, SuffixByteRequest): - req_suffix = byte_range.suffix + req_offset, req_length, req_suffix = self._map_byte_request(byte_range) data = await self.cas.load( chunk_cid_str, offset=req_offset, @@ -678,77 +1299,82 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: raise PermissionError("Cannot write to a read-only store.") await self._resize_complete.wait() - if key.endswith("zarr.json") and not key == "zarr.json": - metadata_json = json.loads(value.to_bytes().decode("utf-8")) - new_array_shape = metadata_json.get("shape") - # Some metadata entries (e.g., group metadata) do not have a shape field. - if new_array_shape: - # Only resize when the metadata shape represents the primary array. - if ( - len(new_array_shape) == len(self._array_shape) - and tuple(new_array_shape) != self._array_shape - ): - async with self._resize_lock: - # Double-check after acquiring the lock, in case another task - # just finished this exact resize while we were waiting. - if ( - len(new_array_shape) == len(self._array_shape) - and tuple(new_array_shape) != self._array_shape - ): - # Block all other tasks until resize is complete. - self._resize_complete.clear() - try: - await self.resize_store( - new_shape=tuple(new_array_shape) - ) - finally: - # All waiting tasks will now un-pause and proceed safely. - self._resize_complete.set() - raw_data_bytes = value.to_bytes() - # Save the data to CAS first to get its CID. - # Metadata is often saved as 'raw', chunks as well unless compressed. + await self._ensure_v2_parent_group_metadata(key) + await self._register_array_metadata_from_bytes(key, raw_data_bytes) + try: data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") - await self.set_pointer(key, str(data_cid_obj)) + await self._set_pointer(key, str(data_cid_obj), register_metadata=False) if self._parse_chunk_key(key) is None: self._metadata_read_cache[key] = raw_data_bytes except Exception as e: - raise RuntimeError(f"Failed to save data for key {key}: {e}") + raise RuntimeError(f"Failed to save data for key {key}: {e}") from e return None # type: ignore[return-value] async def set_pointer(self, key: str, pointer: str) -> None: - chunk_coords = self._parse_chunk_key(key) + await self._set_pointer(key, pointer, register_metadata=True) - pointer_cid_obj = CID.decode(pointer) # Convert string to CID object + async def _set_pointer( + self, key: str, pointer: str, *, register_metadata: bool + ) -> None: + parsed_chunk = self._parse_chunk_key(key) + pointer_cid_obj = CID.decode(pointer) - if chunk_coords is None: # Metadata key + if parsed_chunk is None: + if register_metadata: + await self._ensure_v2_parent_group_metadata(key) self._root_obj["metadata"][key] = pointer_cid_obj self._dirty_root = True + if ( + register_metadata + and self._array_path_from_metadata_key(key) is not None + ): + raw_metadata = await self.cas.load(str(pointer_cid_obj)) + await self._register_array_metadata_from_bytes(key, raw_metadata) return None - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + array_index = self._array_index_for_path(parsed_chunk.array_path) + linear_chunk_index = self._get_linear_chunk_index_for_index( + parsed_chunk.coords, array_index + ) + shard_idx, index_in_shard = self._get_shard_info_for_index( + linear_chunk_index, array_index + ) - shard_lock = self._shard_locks[shard_idx] + cache_key = self._cache_key(array_index.array_path, shard_idx) + shard_lock = self._shard_locks[cache_key] async with shard_lock: - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + target_shard_list = await self._load_or_initialize_shard_cache( + shard_idx, array_index.array_path + ) if target_shard_list[index_in_shard] != pointer_cid_obj: target_shard_list[index_in_shard] = pointer_cid_obj - await self._shard_data_cache.mark_dirty(shard_idx) + await self._shard_data_cache.mark_dirty(cache_key) return None async def exists(self, key: str) -> bool: try: - chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: # Metadata + parsed_chunk = self._parse_chunk_key(key) + if parsed_chunk is None: + return key in self._root_obj.get("metadata", {}) + try: + array_index = self._array_index_for_path(parsed_chunk.array_path) + except KeyError: return key in self._root_obj.get("metadata", {}) - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - # Load shard if not cached and check the index - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - return target_shard_list[index_in_shard] is not None + linear_chunk_index = self._get_linear_chunk_index_for_index( + parsed_chunk.coords, array_index + ) + shard_idx, index_in_shard = self._get_shard_info_for_index( + linear_chunk_index, array_index + ) + target_shard_list = await self._load_or_initialize_shard_cache( + shard_idx, array_index.array_path + ) + return target_shard_list[ + index_in_shard + ] is not None or key in self._root_obj.get("metadata", {}) except (ValueError, IndexError, KeyError): return False @@ -758,7 +1384,7 @@ def supports_writes(self) -> bool: @property def supports_partial_writes(self) -> bool: - return False # Each chunk CID is written atomically into a shard slot + return False @property def supports_deletes(self) -> bool: @@ -768,88 +1394,210 @@ async def delete(self, key: str) -> None: if self.read_only: raise PermissionError("Cannot delete from a read-only store.") - chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: # Metadata - # Coordinate/metadata deletions should be idempotent for caller convenience. + parsed_chunk = self._parse_chunk_key(key) + if parsed_chunk is None: if self._root_obj["metadata"].pop(key, None) is not None: self._metadata_read_cache.pop(key, None) self._dirty_root = True return None - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + try: + array_index = self._array_index_for_path(parsed_chunk.array_path) + except KeyError: + if self._root_obj["metadata"].pop(key, None) is not None: + self._metadata_read_cache.pop(key, None) + self._dirty_root = True + return None + linear_chunk_index = self._get_linear_chunk_index_for_index( + parsed_chunk.coords, array_index + ) + shard_idx, index_in_shard = self._get_shard_info_for_index( + linear_chunk_index, array_index + ) - shard_lock = self._shard_locks[shard_idx] + cache_key = self._cache_key(array_index.array_path, shard_idx) + shard_lock = self._shard_locks[cache_key] async with shard_lock: - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + target_shard_list = await self._load_or_initialize_shard_cache( + shard_idx, array_index.array_path + ) if target_shard_list[index_in_shard] is not None: target_shard_list[index_in_shard] = None - await self._shard_data_cache.mark_dirty(shard_idx) + await self._shard_data_cache.mark_dirty(cache_key) + elif self._root_obj["metadata"].pop(key, None) is not None: + self._metadata_read_cache.pop(key, None) + self._dirty_root = True @property def supports_listing(self) -> bool: return True async def list(self) -> AsyncIterator[str]: + yielded: set[str] = set() for key in list(self._root_obj.get("metadata", {})): + yielded.add(key) yield key + if self._manifest_version == SHARDED_ZARR_V2: + async for chunk_key in self._iter_chunk_keys(): + if chunk_key not in yielded: + yield chunk_key + + async def _iter_chunk_keys(self) -> AsyncIterator[str]: + for array_path, array_index in self.array_indices.items(): + for shard_idx in range(array_index.num_shards): + cache_key = self._cache_key(array_path, shard_idx) + shard_data = await self._shard_data_cache.get(cache_key) + if shard_data is None: + if array_index.shard_cids[shard_idx] is None: + continue + shard_data = await self._load_or_initialize_shard_cache( + shard_idx, array_path + ) + + for index_in_shard, cid_obj in enumerate(shard_data): + if cid_obj is None: + continue + linear_index = ( + shard_idx * array_index.chunks_per_shard + index_in_shard + ) + if linear_index >= array_index.total_chunks: + continue + coords = self._coords_from_linear_index( + linear_index, array_index.chunks_per_dim + ) + yield self._format_chunk_key(array_path, coords) + async def list_prefix(self, prefix: str) -> AsyncIterator[str]: async for key in self.list(): if key.startswith(prefix): yield key - async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): + def _list_dir_candidate_keys(self) -> Set[str]: + keys = set(self._root_obj.get("metadata", {})) + if self._manifest_version != SHARDED_ZARR_V2: + return keys + + for array_path in self.array_indices: + if array_path: + parts = array_path.split("/") + keys.update("/".join(parts[:idx]) for idx in range(1, len(parts) + 1)) + keys.add("c" if array_path == "" else f"{array_path}/c") + return keys + + def _is_v2_chunk_listing_prefix(self, normalized_prefix: str) -> bool: + if self._manifest_version != SHARDED_ZARR_V2: + return False + for array_path in self.array_indices: + chunk_prefix = "c" if array_path == "" else f"{array_path}/c" + if normalized_prefix == chunk_prefix or normalized_prefix.startswith( + f"{chunk_prefix}/" + ): + return True + return False + + async def graft_store( + self, + store_to_graft_cid: str, + chunk_offset: Tuple[int, ...], + *, + source_array_path: Optional[str] = None, + target_array_path: Optional[str] = None, + ) -> None: if self.read_only: raise PermissionError("Cannot graft onto a read-only store.") store_to_graft = await ShardedZarrStore.open( cas=self.cas, read_only=True, root_cid=store_to_graft_cid ) - source_chunk_grid = store_to_graft._chunks_per_dim - for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): - linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) - local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info( - linear_local_index + source_path = ( + source_array_path + if source_array_path is not None + else store_to_graft._primary_array_path + ) + if source_path is None: + return None + + source_index = store_to_graft._array_index_for_path(source_path) + target_path = ( + target_array_path if target_array_path is not None else source_path + ) + target_index = self._array_index_for_path(target_path) + if len(chunk_offset) != len(source_index.chunks_per_dim) or len( + chunk_offset + ) != len(target_index.chunks_per_dim): + raise ValueError( + "chunk_offset must have the same number of dimensions as both source and target arrays." + ) + + for local_coords in itertools.product(*[ + range(s) for s in source_index.chunks_per_dim + ]): + linear_local_index = self._get_linear_chunk_index_for_index( + tuple(local_coords), source_index + ) + local_shard_idx, index_in_local_shard = self._get_shard_info_for_index( + linear_local_index, source_index ) - # Load the source shard into its cache source_shard_list = await store_to_graft._load_or_initialize_shard_cache( - local_shard_idx + local_shard_idx, source_index.array_path ) pointer_cid_obj = source_shard_list[index_in_local_shard] if pointer_cid_obj is None: continue - # Calculate global coordinates and write to the main store's index global_coords = tuple( c_local + c_offset - for c_local, c_offset in zip(local_coords, chunk_offset) + for c_local, c_offset in zip(local_coords, chunk_offset, strict=True) ) - linear_global_index = self._get_linear_chunk_index(global_coords) - global_shard_idx, index_in_global_shard = self._get_shard_info( - linear_global_index + try: + self._validate_chunk_coords(global_coords, target_index) + except IndexError as exc: + raise ValueError( + f"Graft target chunk coordinates {global_coords} are out of bounds." + ) from exc + linear_global_index = self._get_linear_chunk_index_for_index( + global_coords, target_index + ) + global_shard_idx, index_in_global_shard = self._get_shard_info_for_index( + linear_global_index, target_index ) - shard_lock = self._shard_locks[global_shard_idx] + cache_key = self._cache_key(target_index.array_path, global_shard_idx) + shard_lock = self._shard_locks[cache_key] async with shard_lock: target_shard_list = await self._load_or_initialize_shard_cache( - global_shard_idx + global_shard_idx, target_index.array_path ) if target_shard_list[index_in_global_shard] != pointer_cid_obj: target_shard_list[index_in_global_shard] = pointer_cid_obj - await self._shard_data_cache.mark_dirty(global_shard_idx) + await self._shard_data_cache.mark_dirty(cache_key) - async def resize_store(self, new_shape: Tuple[int, ...]): + async def resize_store( + self, new_shape: Tuple[int, ...], *, array_path: Optional[str] = None + ) -> None: """ - Resizes the store's main shard index to accommodate a new overall array shape. - This is a metadata-only operation on the store's root object. - Used when doing skeleton writes or appends via xarray where the array shape changes. + Resizes one shard index to accommodate a new array shape. """ if self.read_only: raise PermissionError("Cannot resize a read-only store.") + + if self._manifest_version == SHARDED_ZARR_V2: + target_path = ( + array_path if array_path is not None else self._primary_array_path + ) + if target_path is None: + raise RuntimeError("Store is not properly initialized for resizing.") + array_index = self._array_index_for_path(target_path) + array_index.resize(tuple(new_shape)) + if self._primary_array_path == array_index.array_path: + self._set_legacy_geometry_from_index(array_index) + self._sync_arrays_to_root() + self._dirty_root = True + return None + if ( - # self._root_obj is None self._chunk_shape is None or self._chunks_per_shard is None or self._array_shape is None @@ -863,7 +1611,7 @@ async def resize_store(self, new_shape: Tuple[int, ...]): self._array_shape = tuple(new_shape) self._chunks_per_dim = tuple( math.ceil(a / c) if c > 0 else 0 - for a, c in zip(self._array_shape, self._chunk_shape) + for a, c in zip(self._array_shape, self._chunk_shape, strict=True) ) self._total_chunks = math.prod(self._chunks_per_dim) old_num_shards = self._num_shards if self._num_shards is not None else 0 @@ -882,17 +1630,29 @@ async def resize_store(self, new_shape: Tuple[int, ...]): "shard_cids" ][: self._num_shards] + self.array_indices[""] = ArrayIndex( + array_path="", + array_shape=self._array_shape, + chunk_shape=self._chunk_shape, + chunks_per_shard=self._chunks_per_shard, + shard_cids=self._root_obj["chunks"]["shard_cids"], + ) self._dirty_root = True + return None - async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): + async def resize_variable( + self, variable_name: str, new_shape: Tuple[int, ...] + ) -> None: """ - Resizes the Zarr metadata for a specific variable (e.g., '.json' file). - This does NOT change the store's main shard index. + Resizes the Zarr metadata and shard index for a specific variable. """ if self.read_only: raise PermissionError("Cannot resize a read-only store.") - zarr_metadata_key = f"{variable_name}/zarr.json" + normalized_name = self._normalize_array_path(variable_name) + zarr_metadata_key = ( + "zarr.json" if normalized_name == "" else f"{normalized_name}/zarr.json" + ) old_zarr_metadata_cid = self._root_obj["metadata"].get(zarr_metadata_key) if not old_zarr_metadata_cid: @@ -902,30 +1662,98 @@ async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): old_zarr_metadata_bytes = await self.cas.load(old_zarr_metadata_cid) zarr_metadata_json = json.loads(old_zarr_metadata_bytes) - zarr_metadata_json["shape"] = list(new_shape) new_zarr_metadata_bytes = json.dumps(zarr_metadata_json, indent=2).encode( "utf-8" ) - # Metadata is a raw blob of bytes new_zarr_metadata_cid = await self.cas.save( new_zarr_metadata_bytes, codec="raw" ) self._root_obj["metadata"][zarr_metadata_key] = new_zarr_metadata_cid self._metadata_read_cache[zarr_metadata_key] = new_zarr_metadata_bytes + if self._manifest_version == SHARDED_ZARR_V2: + await self._register_array_metadata_from_bytes( + zarr_metadata_key, new_zarr_metadata_bytes + ) self._dirty_root = True + async def migrate_v1_to_v2(self, primary_array_path: str) -> str: + """ + Rewrite this store root as a v2 manifest, reusing the existing v1 shards + under ``primary_array_path``. + """ + normalized_path = self._normalize_array_path(primary_array_path) + if not normalized_path: + raise ValueError("primary_array_path must be a non-empty array path.") + if self._manifest_version != SHARDED_ZARR_V1: + raise ValueError("Only sharded_zarr_v1 stores can be migrated to v2.") + + await self.flush() + await self._shard_data_cache.clear() + + source_array_path = self._infer_v1_migration_source_array_path(normalized_path) + old_metadata = dict(self._root_obj.get("metadata", {})) + migrated_metadata = { + self._rewrite_v1_metadata_key_for_migration( + key, source_array_path, normalized_path + ): cid + for key, cid in old_metadata.items() + } + await self._add_missing_group_metadata(migrated_metadata, normalized_path) + old_shard_cids = list(self._root_obj["chunks"]["shard_cids"]) + migrated_index = ArrayIndex( + array_path=normalized_path, + array_shape=self._array_shape, + chunk_shape=self._chunk_shape, + chunks_per_shard=self._chunks_per_shard, + shard_cids=old_shard_cids, + ) + + self._manifest_version = SHARDED_ZARR_V2 + self.array_indices = {normalized_path: migrated_index} + self._primary_array_path = normalized_path + self._default_chunks_per_shard = migrated_index.chunks_per_shard + self._set_legacy_geometry_from_index(migrated_index) + self._root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "store_type": "py_hamt.sharded_zarr", + "zarr_format": 3, + "sharding_config": { + "chunks_per_shard": migrated_index.chunks_per_shard, + "order": migrated_index.order, + }, + "metadata": migrated_metadata, + "arrays": {normalized_path: migrated_index.to_manifest()}, + } + self._metadata_read_cache.clear() + self._dirty_root = True + return await self.flush() + async def list_dir(self, prefix: str) -> AsyncIterator[str]: seen: Set[str] = set() - if prefix == "": - async for key in self.list(): # Iterates metadata keys - # e.g., if key is "group1/.zgroup" or "array1/.json", first_component is "group1" or "array1" - # if key is ".zgroup", first_component is ".zgroup" - first_component = key.split("/", 1)[0] + normalized_prefix = prefix.strip("/") + match_prefix = f"{normalized_prefix}/" if normalized_prefix else "" + + if self._is_v2_chunk_listing_prefix(normalized_prefix): + async for key in self._iter_chunk_keys(): + if not key.startswith(match_prefix): + continue + suffix = key[len(match_prefix) :] + first_component = suffix.split("/", 1)[0] if first_component not in seen: seen.add(first_component) yield first_component - else: - raise NotImplementedError("Listing with a prefix is not implemented yet.") + return + + for key in self._list_dir_candidate_keys(): + if not key.startswith(match_prefix): + continue + suffix = key[len(match_prefix) :] + if suffix == "": + continue + first_component = suffix.split("/", 1)[0] + if first_component not in seen: + seen.add(first_component) + yield first_component diff --git a/tests/test_sharded_store_grafting.py b/tests/test_sharded_store_grafting.py index 59acfd3..edd3171 100644 --- a/tests/test_sharded_store_grafting.py +++ b/tests/test_sharded_store_grafting.py @@ -280,7 +280,7 @@ async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): proto = zarr.core.buffer.default_buffer_prototype() await source_store.set("temp/c/0/0", proto.buffer.from_bytes(b"data")) source_root_cid = await source_store.flush() - with pytest.raises(ValueError, match="Shard index 10 out of bounds."): + with pytest.raises(ValueError, match="out of bounds"): await target_store.graft_store( source_root_cid, chunk_offset=(10, 0) ) # Out of bounds for target (4x4 chunks) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 79446ac..fb66d30 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -782,11 +782,8 @@ async def test_listing_and_metadata( # Test listing with a prefix prefix = "temp/" - with pytest.raises( - NotImplementedError, match="Listing with a prefix is not implemented yet." - ): - async for key in store_read.list_dir(prefix): - print(f"Key with prefix '{prefix}': {key}") + prefixed_dir_keys = {key async for key in store_read.list_dir(prefix)} + assert {"zarr.json"}.issubset(prefixed_dir_keys) with pytest.raises( ValueError, match="Byte range requests are not supported for metadata keys." diff --git a/tests/test_sharded_zarr_store_v2.py b/tests/test_sharded_zarr_store_v2.py new file mode 100644 index 0000000..f031105 --- /dev/null +++ b/tests/test_sharded_zarr_store_v2.py @@ -0,0 +1,910 @@ +import json + +import dag_cbor +import numpy as np +import pytest +import xarray as xr +import zarr.core.buffer +from dag_cbor.ipld import IPLDKind +from multiformats import CID, multihash +from zarr.abc.store import RangeByteRequest + +from py_hamt import HAMT +from py_hamt.hamt_to_sharded_converter import ( + _is_zarr_chunk_key, + convert_hamt_to_sharded, +) +from py_hamt.sharded_zarr_store import SHARDED_ZARR_V2, ArrayIndex, ShardedZarrStore +from py_hamt.store_httpx import ContentAddressedStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +class LocalCIDCAS(ContentAddressedStore): + """Small CID-addressed in-memory CAS for ShardedZarrStore tests.""" + + def __init__(self) -> None: + self.store: dict[str, bytes] = {} + + @staticmethod + def _key(cid: IPLDKind) -> str: + decoded = CID.decode(cid) if isinstance(cid, str) else cid + if not isinstance(decoded, CID): + raise TypeError(f"Expected CID, got {type(cid).__name__}") + return str(decoded.encode("base32")) + + async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CID: + digest = multihash.digest(data, "sha2-256") + cid = CID("base32", 1, codec, digest) + self.store[self._key(cid)] = data + return cid + + async def load( + self, + id: IPLDKind, + offset: int | None = None, + length: int | None = None, + suffix: int | None = None, + ) -> bytes: + data = self.store[self._key(id)] + if offset is not None: + if length is None: + return data[offset:] + return data[offset : offset + length] + if suffix is not None: + return data[-suffix:] + return data + + +def _pyramid_level(data: np.ndarray) -> xr.Dataset: + return xr.Dataset( + {"FPAR": (("time", "y", "x"), data)}, + coords={ + "time": np.arange(data.shape[0]), + "y": np.arange(data.shape[1]), + "x": np.arange(data.shape[2]), + }, + ) + + +@pytest.mark.asyncio +async def test_v2_grouped_pyramid_arrays_are_path_aware() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=2, + manifest_version=SHARDED_ZARR_V2, + ) + + level_0 = _pyramid_level(np.arange(16).reshape(2, 2, 4)).chunk( + {"time": 1, "y": 1, "x": 2} + ) + level_1 = _pyramid_level(np.arange(8).reshape(2, 2, 2) + 100).chunk( + {"time": 1, "y": 2, "x": 1} + ) + level_2 = _pyramid_level(np.arange(4).reshape(2, 1, 2) + 200).chunk( + {"time": 1, "y": 1, "x": 1} + ) + + level_0.to_zarr(store=store, group="0", mode="w", zarr_format=3) + level_1.to_zarr(store=store, group="1", mode="a", zarr_format=3) + level_2.to_zarr(store=store, group="2", mode="a", zarr_format=3) + + assert store._root_obj["manifest_version"] == SHARDED_ZARR_V2 + assert store.array_indices["0/FPAR"].array_shape == (2, 2, 4) + assert store.array_indices["1/FPAR"].chunk_shape == (1, 2, 1) + assert store.array_indices["2/FPAR"].array_shape == (2, 1, 2) + + assert await store.exists("0/FPAR/c/0/0/0") + assert await store.exists("1/FPAR/c/0/0/0") + assert await store.exists("2/FPAR/c/0/0/0") + assert await store.exists("0/x/c/0") + assert await store.exists("0/y/c/0") + assert await store.exists("0/time/c/0") + + root_cid = await store.flush() + read_store = await ShardedZarrStore.open( + cas=cas, read_only=True, root_cid=root_cid + ) + + xr.testing.assert_identical( + level_0, xr.open_zarr(store=read_store, group="0").compute() + ) + xr.testing.assert_identical( + level_1, xr.open_zarr(store=read_store, group="1").compute() + ) + xr.testing.assert_identical( + level_2, xr.open_zarr(store=read_store, group="2").compute() + ) + + root_entries = {entry async for entry in read_store.list_dir("")} + assert {"0", "1", "2", "zarr.json"}.issubset(root_entries) + level_entries = {entry async for entry in read_store.list_dir("0")} + assert {"FPAR", "x", "y", "time", "zarr.json"}.issubset(level_entries) + + prefix_keys = {key async for key in read_store.list_prefix("0/FPAR/")} + assert "0/FPAR/zarr.json" in prefix_keys + assert "0/FPAR/c/0/0/0" in prefix_keys + + proto = zarr.core.buffer.default_buffer_prototype() + full_chunk = await read_store.get("0/FPAR/c/0/0/0", proto) + assert full_chunk is not None + partial_chunk = await read_store.get( + "0/FPAR/c/0/0/0", proto, RangeByteRequest(start=0, end=5) + ) + assert partial_chunk is not None + assert partial_chunk.to_bytes() == full_chunk.to_bytes()[:5] + + write_store = await ShardedZarrStore.open( + cas=cas, read_only=False, root_cid=root_cid + ) + await write_store.delete("1/FPAR/c/0/0/0") + assert not await write_store.exists("1/FPAR/c/0/0/0") + assert await write_store.exists("0/FPAR/c/0/0/0") + + +@pytest.mark.asyncio +async def test_v2_resize_is_array_local() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=2, + manifest_version=SHARDED_ZARR_V2, + ) + proto = zarr.core.buffer.default_buffer_prototype() + metadata_a = { + "zarr_format": 3, + "node_type": "array", + "shape": [2, 2], + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [1, 1]}}, + } + metadata_b = { + "zarr_format": 3, + "node_type": "array", + "shape": [4], + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [2]}}, + } + + await store.set( + "0/FPAR/zarr.json", + proto.buffer.from_bytes(json.dumps(metadata_a).encode()), + ) + await store.set( + "1/FPAR/zarr.json", + proto.buffer.from_bytes(json.dumps(metadata_b).encode()), + ) + await store.set("0/FPAR/c/0/0", proto.buffer.from_bytes(b"array-a")) + await store.set("1/FPAR/c/0", proto.buffer.from_bytes(b"array-b")) + + await store.resize_store((3, 2), array_path="0/FPAR") + + assert store.array_indices["0/FPAR"].array_shape == (3, 2) + assert store.array_indices["1/FPAR"].array_shape == (4,) + assert await store.exists("0/FPAR/c/0/0") + assert await store.exists("1/FPAR/c/0") + + await store.resize_variable("0/FPAR", (4, 2)) + assert store.array_indices["0/FPAR"].array_shape == (4, 2) + + +@pytest.mark.asyncio +async def test_v1_migrate_to_v2_reuses_shards() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(2, 2), + chunk_shape=(1, 1), + chunks_per_shard=2, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("FPAR/c/0/0", proto.buffer.from_bytes(b"legacy")) + + migrated_cid = await store.migrate_v1_to_v2("0/FPAR") + migrated_store = await ShardedZarrStore.open( + cas=cas, read_only=True, root_cid=migrated_cid + ) + + assert migrated_store._root_obj["manifest_version"] == SHARDED_ZARR_V2 + assert "0/FPAR" in migrated_store.array_indices + migrated_chunk = await migrated_store.get("0/FPAR/c/0/0", proto) + assert migrated_chunk is not None + assert migrated_chunk.to_bytes() == b"legacy" + assert await migrated_store.get("FPAR/c/0/0", proto) is None + + +@pytest.mark.asyncio +async def test_v1_migrate_to_v2_rewrites_metadata_for_group_open() -> None: + cas = LocalCIDCAS() + data = np.arange(4).reshape(1, 2, 2) + source = xr.Dataset( + {"FPAR": (("time", "lat", "lon"), data)}, + coords={ + "time": np.arange(1), + "lat": np.arange(2), + "lon": np.arange(2), + }, + ).chunk({"time": 1, "lat": 1, "lon": 1}) + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(1, 2, 2), + chunk_shape=(1, 1, 1), + chunks_per_shard=2, + ) + + source.to_zarr(store=store, mode="w", zarr_format=3) + migrated_cid = await store.migrate_v1_to_v2("0/FPAR") + migrated_store = await ShardedZarrStore.open( + cas=cas, read_only=True, root_cid=migrated_cid + ) + + assert "0/zarr.json" in migrated_store._root_obj["metadata"] + assert "0/FPAR/zarr.json" in migrated_store._root_obj["metadata"] + assert "FPAR/zarr.json" not in migrated_store._root_obj["metadata"] + xr.testing.assert_identical( + source, xr.open_zarr(store=migrated_store, group="0").compute() + ) + assert ( + ShardedZarrStore._rewrite_v1_metadata_key_for_migration( + "zarr.json", "", "0/FPAR" + ) + == "0/FPAR/zarr.json" + ) + + +@pytest.mark.asyncio +async def test_v1_root_chunk_keys_remain_metadata_compatible() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(1,), + chunk_shape=(1,), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("c/0", proto.buffer.from_bytes(b"root-array-chunk")) + root_cid = await store.flush() + + read_store = await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=root_cid) + assert await read_store.exists("c/0") + chunk = await read_store.get("c/0", proto) + assert chunk is not None + assert chunk.to_bytes() == b"root-array-chunk" + + +@pytest.mark.asyncio +async def test_v1_root_metadata_chunks_migrate_to_primary_path() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(1,), + chunk_shape=(1,), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await store.set( + "zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + } + ).encode() + ), + ) + await store.set("c/0", proto.buffer.from_bytes(b"root-array-chunk")) + + migrated_cid = await store.migrate_v1_to_v2("0/root") + migrated_store = await ShardedZarrStore.open( + cas=cas, read_only=False, root_cid=migrated_cid + ) + + assert await migrated_store.exists("0/root/c/0") + chunk = await migrated_store.get("0/root/c/0", proto) + assert chunk is not None + assert chunk.to_bytes() == b"root-array-chunk" + await migrated_store.delete("0/root/c/0") + assert not await migrated_store.exists("0/root/c/0") + + +@pytest.mark.asyncio +async def test_v1_c_named_group_chunks_use_shard_index() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(1, 1, 1), + chunk_shape=(1, 1, 1), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + + await store.set("c/FPAR/c/0/0/0", proto.buffer.from_bytes(b"v1-c-group")) + root_cid = await store.flush() + read_store = await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=root_cid) + + chunk = await read_store.get("c/FPAR/c/0/0/0", proto) + assert chunk is not None + assert chunk.to_bytes() == b"v1-c-group" + + +@pytest.mark.asyncio +async def test_migrated_v1_coordinate_chunks_remain_readable() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(2,), + chunk_shape=(1,), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await store.set( + "lat/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [2], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [2]}, + }, + } + ).encode() + ), + ) + await store.set("lat/c/0", proto.buffer.from_bytes(b"coordinate-chunk")) + + migrated_cid = await store.migrate_v1_to_v2("0/FPAR") + migrated_store = await ShardedZarrStore.open( + cas=cas, read_only=True, root_cid=migrated_cid + ) + + assert await migrated_store.exists("0/lat/c/0") + chunk = await migrated_store.get("0/lat/c/0", proto) + assert chunk is not None + assert chunk.to_bytes() == b"coordinate-chunk" + partial = await migrated_store.get( + "0/lat/c/0", proto, RangeByteRequest(start=0, end=10) + ) + assert partial is not None + assert partial.to_bytes() == b"coordinate" + + migrated_write_store = await ShardedZarrStore.open( + cas=cas, read_only=False, root_cid=migrated_cid + ) + await migrated_write_store.delete("0/lat/c/0") + assert not await migrated_write_store.exists("0/lat/c/0") + + +@pytest.mark.asyncio +async def test_empty_v2_root_reopen_retains_default_sharding_config() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=3, + manifest_version=SHARDED_ZARR_V2, + ) + root_cid = await store.flush() + reopened = await ShardedZarrStore.open( + cas=cas, read_only=False, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + + await reopened.set( + "a/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [2], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + } + ).encode() + ), + ) + await reopened.set("a/c/0", proto.buffer.from_bytes(b"chunk")) + + assert reopened.array_indices["a"].chunks_per_shard == 3 + assert await reopened.exists("a/c/0") + + +@pytest.mark.asyncio +async def test_v2_c_named_arrays_groups_and_metadata_suffixes() -> None: + cas = LocalCIDCAS() + proto = zarr.core.buffer.default_buffer_prototype() + array_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + ) + await array_store.set( + "c/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + } + ).encode() + ), + ) + await array_store.set("c/c/0", proto.buffer.from_bytes(b"top-level-c")) + top_level_chunk = await array_store.get("c/c/0", proto) + assert top_level_chunk is not None + assert top_level_chunk.to_bytes() == b"top-level-c" + + group_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + ) + await group_store.set("c/.zattrs", proto.buffer.from_bytes(b"{}")) + attrs = await group_store.get("c/.zattrs", proto) + assert attrs is not None + assert attrs.to_bytes() == b"{}" + await group_store.set( + "c/FPAR/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [1, 1, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1, 1]}, + }, + } + ).encode() + ), + ) + await group_store.set("c/FPAR/c/0/0/0", proto.buffer.from_bytes(b"group-c")) + group_chunk = await group_store.get("c/FPAR/c/0/0/0", proto) + assert group_chunk is not None + assert group_chunk.to_bytes() == b"group-c" + + +@pytest.mark.asyncio +async def test_v2_list_dir_can_walk_explicit_chunk_prefixes() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await store.set( + "a/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [2, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1]}, + }, + } + ).encode() + ), + ) + await store.set("a/c/1/0", proto.buffer.from_bytes(b"chunk")) + await store.set( + "b/zarr.json", + proto.buffer.from_bytes( + json.dumps( + { + "zarr_format": 3, + "node_type": "array", + "shape": [1, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1]}, + }, + } + ).encode() + ), + ) + await store.set("b/c/0/0", proto.buffer.from_bytes(b"other-chunk")) + + assert {entry async for entry in store.list_dir("a/c")} == {"1"} + assert {entry async for entry in store.list_dir("a/c/1")} == {"0"} + assert {entry async for entry in store.list_dir("a/c/1/0")} == set() + + +@pytest.mark.asyncio +async def test_v2_root_array_chunks_and_defensive_validation_paths() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + array_shape=(1,), + chunk_shape=(1,), + ) + proto = zarr.core.buffer.default_buffer_prototype() + + await store.set("c/0", proto.buffer.from_bytes(b"root-v2")) + assert await store.exists("c/0") + chunk = await store.get("c/0", proto) + assert chunk is not None + assert chunk.to_bytes() == b"root-v2" + + with pytest.raises(ValueError, match="Shard index 2 out of bounds"): + await store._load_or_initialize_shard_cache(2, "") + + root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "store_type": "py_hamt.sharded_zarr", + "zarr_format": 3, + "sharding_config": "invalid", + "metadata": {}, + "arrays": {}, + } + root_cid = await cas.save(dag_cbor.encode(root_obj), codec="dag-cbor") + opened = await ShardedZarrStore.open( + cas=cas, read_only=False, root_cid=str(root_cid.encode("base32")) + ) + assert opened._default_chunks_per_shard is None + + empty_index_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + array_shape=(1,), + chunk_shape=(1,), + ) + assert {key async for key in empty_index_store.list()} == set() + + +@pytest.mark.asyncio +async def test_failed_shard_validation_clears_pending_load() -> None: + cas = LocalCIDCAS() + store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + array_shape=(1,), + chunk_shape=(1,), + ) + bad_shard_cid = await cas.save(dag_cbor.encode([123]), codec="dag-cbor") + store.array_indices[""].shard_cids[0] = bad_shard_cid + + with pytest.raises(TypeError, match="non-CID"): + await store._load_or_initialize_shard_cache(0, "") + assert ("", 0) not in store._pending_shard_loads + + +@pytest.mark.asyncio +async def test_graft_rejects_offsets_that_alias_invalid_coordinates() -> None: + cas = LocalCIDCAS() + source = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(1, 1), + chunk_shape=(1, 1), + chunks_per_shard=4, + ) + target = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=(2, 2), + chunk_shape=(1, 1), + chunks_per_shard=4, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await source.set("temp/c/0/0", proto.buffer.from_bytes(b"source")) + source_cid = await source.flush() + + with pytest.raises(ValueError, match="chunk_offset"): + await target.graft_store(source_cid, chunk_offset=(0,)) + with pytest.raises(ValueError, match="out of bounds"): + await target.graft_store(source_cid, chunk_offset=(0, 2)) + assert not await target.exists("temp/c/1/0") + + +def test_converter_chunk_key_classifier_keeps_c_named_metadata_first() -> None: + assert not _is_zarr_chunk_key("c/zarr.json") + assert not _is_zarr_chunk_key("0/c/zarr.json") + assert _is_zarr_chunk_key("0/c/c/0") + + +@pytest.mark.asyncio +async def test_converter_discovers_grouped_arrays() -> None: + cas = LocalCIDCAS() + hamt = await HAMT.build(cas=cas, values_are_bytes=True) + source_store = ZarrHAMTStore(hamt) + + level_0 = _pyramid_level(np.arange(4).reshape(1, 2, 2)).chunk( + {"time": 1, "y": 1, "x": 1} + ) + level_1 = _pyramid_level(np.arange(2).reshape(1, 1, 2) + 10).chunk( + {"time": 1, "y": 1, "x": 1} + ) + level_0.to_zarr( + store=source_store, + group="0", + mode="w", + zarr_format=3, + consolidated=False, + ) + level_1.to_zarr( + store=source_store, + group="1", + mode="a", + zarr_format=3, + consolidated=False, + ) + await hamt.make_read_only() + + sharded_root = await convert_hamt_to_sharded(cas, str(hamt.root_node_id), 2) + converted_store = await ShardedZarrStore.open( + cas=cas, read_only=True, root_cid=sharded_root + ) + + assert converted_store._root_obj["manifest_version"] == SHARDED_ZARR_V2 + assert {"0/FPAR", "1/FPAR"}.issubset(converted_store.array_indices) + xr.testing.assert_identical( + level_0, + xr.open_zarr(store=converted_store, group="0", consolidated=False).compute(), + ) + xr.testing.assert_identical( + level_1, + xr.open_zarr(store=converted_store, group="1", consolidated=False).compute(), + ) + + +def test_array_index_validation_paths() -> None: + with pytest.raises(ValueError, match="Inconsistent number of shards"): + ArrayIndex( + array_path="a", + array_shape=(4,), + chunk_shape=(1,), + chunks_per_shard=2, + shard_cids=[None], + ) + with pytest.raises(ValueError, match="chunks_per_shard"): + ArrayIndex( + array_path="a", + array_shape=(1,), + chunk_shape=(1,), + chunks_per_shard=0, + shard_cids=[], + ) + with pytest.raises(ValueError, match="same rank"): + ArrayIndex( + array_path="a", + array_shape=(1, 1), + chunk_shape=(1,), + chunks_per_shard=1, + shard_cids=[], + ) + with pytest.raises(ValueError, match="row-major"): + ArrayIndex( + array_path="a", + array_shape=(1,), + chunk_shape=(1,), + chunks_per_shard=1, + shard_cids=[None], + order="F", + ) + with pytest.raises(ValueError, match="shard_cids is not a list"): + ArrayIndex.from_manifest( + "a", + { + "array_shape": [1], + "chunk_shape": [1], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": "bad", + }, + ) + with pytest.raises(ValueError, match="sharding_config"): + ArrayIndex.from_manifest( + "a", + { + "array_shape": [1], + "chunk_shape": [1], + "sharding_config": "bad", + "shard_cids": [None], + }, + ) + + index = ArrayIndex.new("a", (4,), (1,), 2) + with pytest.raises(ValueError, match="same number of dimensions"): + index.resize((1, 1)) + index.resize((1,)) + assert index.num_shards == 1 + assert len(index.shard_cids) == 1 + + +@pytest.mark.asyncio +async def test_v2_validation_paths() -> None: + cas = LocalCIDCAS() + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version="bad", + ) + with pytest.raises(ValueError, match="both be provided"): + await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + array_shape=(1,), + ) + with pytest.raises(ValueError, match="must be provided"): + await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version="sharded_zarr_v1", + ) + + seeded_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + array_shape=(1,), + chunk_shape=(1,), + primary_array_path="seed", + ) + assert seeded_store.array_indices["seed"].array_shape == (1,) + + assert ShardedZarrStore._array_path_from_metadata_key("a/.zarray") == "a" + assert ShardedZarrStore._format_chunk_key("", (0,)) == "c/0" + + empty_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + chunks_per_shard=1, + manifest_version=SHARDED_ZARR_V2, + ) + assert empty_store._root_obj["sharding_config"]["chunks_per_shard"] == 1 + with pytest.raises(RuntimeError, match="not properly initialized"): + await empty_store.resize_store((1,)) + with pytest.raises(ValueError, match="non-empty"): + await seeded_store.migrate_v1_to_v2("") + with pytest.raises(ValueError, match="Only sharded_zarr_v1"): + await seeded_store.migrate_v1_to_v2("seed") + + assert ShardedZarrStore._array_path_from_metadata_key("attrs") is None + assert ShardedZarrStore._decode_metadata_json(b"not-json") is None + assert ShardedZarrStore._decode_metadata_json(b"[]") is None + assert ShardedZarrStore._extract_array_metadata({}) is None + assert ShardedZarrStore._extract_array_metadata({"shape": [1]}) is None + + await empty_store._register_array_metadata_from_bytes( + "attrs", json.dumps({"shape": [1]}).encode() + ) + await empty_store._register_array_metadata_from_bytes("a/zarr.json", b"not-json") + await empty_store._register_array_metadata_from_bytes( + "a/zarr.json", json.dumps({"shape": [1]}).encode() + ) + assert empty_store.array_indices == {} + with pytest.raises(RuntimeError, match="chunks_per_shard"): + empty_store._default_chunks_per_shard = None + empty_store._register_or_update_array_index( + array_path="a", array_shape=(1,), chunk_shape=(1,) + ) + + await empty_store.delete("missing/c/0") + empty_store._root_obj["metadata"]["missing"] = None + empty_store._root_obj["metadata"]["missing/"] = None + assert {entry async for entry in empty_store.list_dir("missing")} == set() + + empty_source_cid = await empty_store.flush() + await seeded_store.graft_store(empty_source_cid, chunk_offset=(0,)) + + +@pytest.mark.asyncio +async def test_v2_invalid_root_and_shard_validation() -> None: + cas = LocalCIDCAS() + + invalid_v2_roots = [ + {"manifest_version": SHARDED_ZARR_V2, "metadata": [], "arrays": {}}, + { + "manifest_version": SHARDED_ZARR_V2, + "metadata": {}, + "arrays": { + "bad": { + "array_shape": [1], + "chunk_shape": [1], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": "bad", + } + }, + }, + { + "manifest_version": SHARDED_ZARR_V2, + "metadata": {}, + "arrays": { + "bad": { + "array_shape": [2], + "chunk_shape": [1], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None], + } + }, + }, + ] + for root_obj in invalid_v2_roots: + root_cid = await cas.save(dag_cbor.encode(root_obj), codec="dag-cbor") + with pytest.raises(ValueError): + await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=str(root_cid)) + + store = ShardedZarrStore(cas=cas, read_only=True) + store._root_obj = {"manifest_version": SHARDED_ZARR_V2, "metadata": {}, "arrays": {1: {}}} + with pytest.raises(ValueError, match="arrays must map"): + store._load_v2_root() + + bad_shard_cid = await cas.save(dag_cbor.encode([1]), codec="dag-cbor") + root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "metadata": {}, + "arrays": { + "a": { + "array_shape": [1], + "chunk_shape": [1], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [bad_shard_cid], + } + }, + } + root_cid = await cas.save(dag_cbor.encode(root_obj), codec="dag-cbor") + store = await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=str(root_cid)) + with pytest.raises(TypeError, match="non-CID"): + await store.get("a/c/0", zarr.core.buffer.default_buffer_prototype()) + + extra_chunk_cid = await cas.save(b"extra", codec="raw") + sparse_shard_cid = await cas.save( + dag_cbor.encode([None, extra_chunk_cid]), codec="dag-cbor" + ) + root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "metadata": {}, + "arrays": { + "a": { + "array_shape": [1], + "chunk_shape": [1], + "sharding_config": {"chunks_per_shard": 2}, + "shard_cids": [sparse_shard_cid], + } + }, + } + root_cid = await cas.save(dag_cbor.encode(root_obj), codec="dag-cbor") + store = await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=str(root_cid)) + assert {key async for key in store.list_prefix("a/c/")} == set() From ac1d1d29424f6a4e2a8bf82ae0b1137baf225249 Mon Sep 17 00:00:00 2001 From: 0xSwego <0xSwego@gmail.com> Date: Mon, 29 Jun 2026 17:32:42 +0100 Subject: [PATCH 2/2] fix: align pyramid zarr checks with CI --- py_hamt/hamt_to_sharded_converter.py | 4 +- tests/test_sharded_zarr_store_v2.py | 194 +++++++++++++-------------- 2 files changed, 97 insertions(+), 101 deletions(-) diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index 2f8b050..2ba24a9 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -6,7 +6,7 @@ from .hamt import HAMT from .sharded_zarr_store import SHARDED_ZARR_V2, ShardedZarrStore -from .store_httpx import KuboCAS +from .store_httpx import ContentAddressedStore, KuboCAS def _is_zarr_chunk_key(key: str) -> bool: @@ -16,7 +16,7 @@ def _is_zarr_chunk_key(key: str) -> bool: async def convert_hamt_to_sharded( - cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int + cas: ContentAddressedStore, hamt_root_cid: str, chunks_per_shard: int ) -> str: """ Converts a Zarr dataset from a HAMT-based store to a ShardedZarrStore. diff --git a/tests/test_sharded_zarr_store_v2.py b/tests/test_sharded_zarr_store_v2.py index f031105..0588365 100644 --- a/tests/test_sharded_zarr_store_v2.py +++ b/tests/test_sharded_zarr_store_v2.py @@ -76,15 +76,21 @@ async def test_v2_grouped_pyramid_arrays_are_path_aware() -> None: manifest_version=SHARDED_ZARR_V2, ) - level_0 = _pyramid_level(np.arange(16).reshape(2, 2, 4)).chunk( - {"time": 1, "y": 1, "x": 2} - ) - level_1 = _pyramid_level(np.arange(8).reshape(2, 2, 2) + 100).chunk( - {"time": 1, "y": 2, "x": 1} - ) - level_2 = _pyramid_level(np.arange(4).reshape(2, 1, 2) + 200).chunk( - {"time": 1, "y": 1, "x": 1} - ) + level_0 = _pyramid_level(np.arange(16).reshape(2, 2, 4)).chunk({ + "time": 1, + "y": 1, + "x": 2, + }) + level_1 = _pyramid_level(np.arange(8).reshape(2, 2, 2) + 100).chunk({ + "time": 1, + "y": 2, + "x": 1, + }) + level_2 = _pyramid_level(np.arange(4).reshape(2, 1, 2) + 200).chunk({ + "time": 1, + "y": 1, + "x": 1, + }) level_0.to_zarr(store=store, group="0", mode="w", zarr_format=3) level_1.to_zarr(store=store, group="1", mode="a", zarr_format=3) @@ -103,9 +109,7 @@ async def test_v2_grouped_pyramid_arrays_are_path_aware() -> None: assert await store.exists("0/time/c/0") root_cid = await store.flush() - read_store = await ShardedZarrStore.open( - cas=cas, read_only=True, root_cid=root_cid - ) + read_store = await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=root_cid) xr.testing.assert_identical( level_0, xr.open_zarr(store=read_store, group="0").compute() @@ -289,17 +293,15 @@ async def test_v1_root_metadata_chunks_migrate_to_primary_path() -> None: await store.set( "zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [1], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + }).encode() ), ) await store.set("c/0", proto.buffer.from_bytes(b"root-array-chunk")) @@ -352,17 +354,15 @@ async def test_migrated_v1_coordinate_chunks_remain_readable() -> None: await store.set( "lat/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [2], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [2]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [2], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [2]}, + }, + }).encode() ), ) await store.set("lat/c/0", proto.buffer.from_bytes(b"coordinate-chunk")) @@ -399,25 +399,21 @@ async def test_empty_v2_root_reopen_retains_default_sharding_config() -> None: manifest_version=SHARDED_ZARR_V2, ) root_cid = await store.flush() - reopened = await ShardedZarrStore.open( - cas=cas, read_only=False, root_cid=root_cid - ) + reopened = await ShardedZarrStore.open(cas=cas, read_only=False, root_cid=root_cid) proto = zarr.core.buffer.default_buffer_prototype() await reopened.set( "a/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [2], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [2], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + }).encode() ), ) await reopened.set("a/c/0", proto.buffer.from_bytes(b"chunk")) @@ -439,17 +435,15 @@ async def test_v2_c_named_arrays_groups_and_metadata_suffixes() -> None: await array_store.set( "c/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [1], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1]}, + }, + }).encode() ), ) await array_store.set("c/c/0", proto.buffer.from_bytes(b"top-level-c")) @@ -470,17 +464,15 @@ async def test_v2_c_named_arrays_groups_and_metadata_suffixes() -> None: await group_store.set( "c/FPAR/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [1, 1, 1], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1, 1, 1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [1, 1, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1, 1]}, + }, + }).encode() ), ) await group_store.set("c/FPAR/c/0/0/0", proto.buffer.from_bytes(b"group-c")) @@ -502,34 +494,30 @@ async def test_v2_list_dir_can_walk_explicit_chunk_prefixes() -> None: await store.set( "a/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [2, 1], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1, 1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [2, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1]}, + }, + }).encode() ), ) await store.set("a/c/1/0", proto.buffer.from_bytes(b"chunk")) await store.set( "b/zarr.json", proto.buffer.from_bytes( - json.dumps( - { - "zarr_format": 3, - "node_type": "array", - "shape": [1, 1], - "chunk_grid": { - "name": "regular", - "configuration": {"chunk_shape": [1, 1]}, - }, - } - ).encode() + json.dumps({ + "zarr_format": 3, + "node_type": "array", + "shape": [1, 1], + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [1, 1]}, + }, + }).encode() ), ) await store.set("b/c/0/0", proto.buffer.from_bytes(b"other-chunk")) @@ -645,12 +633,16 @@ async def test_converter_discovers_grouped_arrays() -> None: hamt = await HAMT.build(cas=cas, values_are_bytes=True) source_store = ZarrHAMTStore(hamt) - level_0 = _pyramid_level(np.arange(4).reshape(1, 2, 2)).chunk( - {"time": 1, "y": 1, "x": 1} - ) - level_1 = _pyramid_level(np.arange(2).reshape(1, 1, 2) + 10).chunk( - {"time": 1, "y": 1, "x": 1} - ) + level_0 = _pyramid_level(np.arange(4).reshape(1, 2, 2)).chunk({ + "time": 1, + "y": 1, + "x": 1, + }) + level_1 = _pyramid_level(np.arange(2).reshape(1, 1, 2) + 10).chunk({ + "time": 1, + "y": 1, + "x": 1, + }) level_0.to_zarr( store=source_store, group="0", @@ -867,7 +859,11 @@ async def test_v2_invalid_root_and_shard_validation() -> None: await ShardedZarrStore.open(cas=cas, read_only=True, root_cid=str(root_cid)) store = ShardedZarrStore(cas=cas, read_only=True) - store._root_obj = {"manifest_version": SHARDED_ZARR_V2, "metadata": {}, "arrays": {1: {}}} + store._root_obj = { + "manifest_version": SHARDED_ZARR_V2, + "metadata": {}, + "arrays": {1: {}}, + } with pytest.raises(ValueError, match="arrays must map"): store._load_v2_root()