diff --git a/lib/crewai-files/src/crewai_files/cache/upload_cache.py b/lib/crewai-files/src/crewai_files/cache/upload_cache.py index 48cebdfa14..c94e164c7e 100644 --- a/lib/crewai-files/src/crewai_files/cache/upload_cache.py +++ b/lib/crewai-files/src/crewai_files/cache/upload_cache.py @@ -1,4 +1,4 @@ -"""Cache for tracking uploaded files using aiocache.""" +"""Cache for tracking uploaded files using aiocache or ValkeyCache.""" from __future__ import annotations @@ -10,10 +10,11 @@ from datetime import datetime, timezone import hashlib import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol from aiocache import Cache # type: ignore[import-untyped] from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] +from crewai.utilities.cache_config import parse_cache_url from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS from crewai_files.uploaders.factory import ProviderType @@ -51,6 +52,33 @@ def is_expired(self) -> bool: return False return datetime.now(timezone.utc) >= self.expires_at + def to_dict(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict.""" + return { + "file_id": self.file_id, + "provider": self.provider, + "file_uri": self.file_uri, + "content_type": self.content_type, + "uploaded_at": self.uploaded_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CachedUpload: + """Deserialize from a dict.""" + return cls( + file_id=data["file_id"], + provider=data["provider"], + file_uri=data.get("file_uri"), + content_type=data["content_type"], + uploaded_at=datetime.fromisoformat(data["uploaded_at"]), + expires_at=( + datetime.fromisoformat(data["expires_at"]) + if data.get("expires_at") + else None + ), + ) + def _make_key(file_hash: str, provider: str) -> str: """Create a cache key from file hash and provider.""" @@ -58,14 +86,7 @@ def 
_make_key(file_hash: str, provider: str) -> str: def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str: - """Compute SHA-256 hash from streaming chunks. - - Args: - chunks: Iterator of byte chunks. - - Returns: - Hexadecimal hash string. - """ + """Compute SHA-256 hash from streaming chunks.""" hasher = hashlib.sha256() for chunk in chunks: hasher.update(chunk) @@ -73,10 +94,7 @@ def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str: def _compute_file_hash(file: FileInput) -> str: - """Compute SHA-256 hash of file content. - - Uses streaming for FilePath sources to avoid loading large files into memory. - """ + """Compute SHA-256 hash of file content.""" from crewai_files.core.sources import FilePath source = file._file_source @@ -86,10 +104,73 @@ def _compute_file_hash(file: FileInput) -> str: return hashlib.sha256(content).hexdigest() +class CacheBackend(Protocol): + """Protocol for cache backends used by UploadCache.""" + + async def get(self, key: str) -> CachedUpload | None: ... + async def set(self, key: str, value: CachedUpload, ttl: int) -> None: ... + async def delete(self, key: str) -> bool: ... 
+ + +class AiocacheBackend: + """Cache backend backed by aiocache (memory or Redis).""" + + def __init__(self, cache: Cache) -> None: # type: ignore[no-any-unimported] + self._cache = cache + + async def get(self, key: str) -> CachedUpload | None: + result = await self._cache.get(key) + if isinstance(result, CachedUpload): + return result + return None + + async def set(self, key: str, value: CachedUpload, ttl: int) -> None: + await self._cache.set(key, value, ttl=ttl) + + async def delete(self, key: str) -> bool: + result = await self._cache.delete(key) + return bool(result > 0 if isinstance(result, int) else result) + + +class ValkeyCacheBackend: + """Cache backend backed by ValkeyCache (JSON serialization).""" + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: str | None = None, + default_ttl: int | None = None, + ) -> None: + from crewai.memory.storage.valkey_cache import ValkeyCache + + self._cache = ValkeyCache( + host=host, port=port, db=db, password=password, default_ttl=default_ttl + ) + + async def get(self, key: str) -> CachedUpload | None: + data = await self._cache.get(key) + if data is None: + return None + try: + return CachedUpload.from_dict(data) + except (KeyError, ValueError) as e: + logger.warning(f"Failed to deserialize cached upload: {e}") + return None + + async def set(self, key: str, value: CachedUpload, ttl: int) -> None: + await self._cache.set(key, value.to_dict(), ttl=ttl) + + async def delete(self, key: str) -> bool: + await self._cache.delete(key) + return True # ValkeyCache.delete is void + + class UploadCache: - """Async cache for tracking uploaded files using aiocache. + """Async cache for tracking uploaded files. - Supports in-memory caching by default, with optional Redis backend + Supports in-memory caching by default, with optional Redis or Valkey backend for distributed setups. Attributes: @@ -110,7 +191,7 @@ def __init__( Args: ttl: Default TTL in seconds. 
namespace: Cache namespace. - cache_type: Backend type ("memory" or "redis"). + cache_type: Backend type ("memory", "redis", or "valkey"). max_entries: Maximum cache entries (None for unlimited). **cache_kwargs: Additional args for cache backend. """ @@ -120,18 +201,39 @@ def __init__( self._provider_keys: dict[ProviderType, set[str]] = {} self._key_access_order: list[str] = [] - if cache_type == "redis": - self._cache = Cache( - Cache.REDIS, - serializer=PickleSerializer(), - namespace=namespace, - **cache_kwargs, + self._backend: CacheBackend = self._create_backend( + cache_type, namespace, ttl, **cache_kwargs + ) + + @staticmethod + def _create_backend( + cache_type: str, + namespace: str, + ttl: int, + **cache_kwargs: Any, + ) -> CacheBackend: + """Create the appropriate cache backend.""" + if cache_type == "valkey": + conn = parse_cache_url() or {} + return ValkeyCacheBackend( + host=cache_kwargs.get("host", conn.get("host", "localhost")), + port=cache_kwargs.get("port", conn.get("port", 6379)), + db=cache_kwargs.get("db", conn.get("db", 0)), + password=cache_kwargs.get("password", conn.get("password")), + default_ttl=ttl, ) - else: - self._cache = Cache( - serializer=PickleSerializer(), - namespace=namespace, + if cache_type == "redis": + return AiocacheBackend( + Cache( + Cache.REDIS, + serializer=PickleSerializer(), + namespace=namespace, + **cache_kwargs, + ) ) + return AiocacheBackend( + Cache(serializer=PickleSerializer(), namespace=namespace) + ) def _track_key(self, provider: ProviderType, key: str) -> None: """Track a key for a provider (for cleanup) and access order.""" @@ -157,11 +259,9 @@ async def _evict_if_needed(self) -> int: """ if self.max_entries is None: return 0 - current_count = len(self) if current_count < self.max_entries: return 0 - to_evict = max(1, self.max_entries // 10) return await self._evict_oldest(to_evict) @@ -176,31 +276,24 @@ async def _evict_oldest(self, count: int) -> int: """ evicted = 0 keys_to_evict = 
self._key_access_order[:count] - for key in keys_to_evict: - await self._cache.delete(key) + await self._backend.delete(key) self._key_access_order.remove(key) for provider_keys in self._provider_keys.values(): provider_keys.discard(key) evicted += 1 - if evicted > 0: logger.debug(f"Evicted {evicted} oldest cache entries") - return evicted + # ------------------------------------------------------------------ + # Async public API + # ------------------------------------------------------------------ + async def aget( self, file: FileInput, provider: ProviderType ) -> CachedUpload | None: - """Get a cached upload for a file. - - Args: - file: The file to look up. - provider: The provider name. - - Returns: - Cached upload if found and not expired, None otherwise. - """ + """Get a cached upload for a file.""" file_hash = _compute_file_hash(file) return await self.aget_by_hash(file_hash, provider) @@ -217,17 +310,14 @@ async def aget_by_hash( Cached upload if found and not expired, None otherwise. """ key = _make_key(file_hash, provider) - result = await self._cache.get(key) - + result = await self._backend.get(key) if result is None: return None - if isinstance(result, CachedUpload): - if result.is_expired(): - await self._cache.delete(key) - self._untrack_key(provider, key) - return None - return result - return None + if result.is_expired(): + await self._backend.delete(key) + self._untrack_key(provider, key) + return None + return result async def aset( self, @@ -237,18 +327,7 @@ async def aset( file_uri: str | None = None, expires_at: datetime | None = None, ) -> CachedUpload: - """Cache an uploaded file. - - Args: - file: The file that was uploaded. - provider: The provider name. - file_id: Provider-specific file identifier. - file_uri: Optional URI for accessing the file. - expires_at: When the upload expires. - - Returns: - The created cache entry. 
- """ + """Cache an uploaded file.""" file_hash = _compute_file_hash(file) return await self.aset_by_hash( file_hash=file_hash, @@ -282,7 +361,6 @@ async def aset_by_hash( The created cache entry. """ await self._evict_if_needed() - key = _make_key(file_hash, provider) now = datetime.now(timezone.utc) @@ -299,7 +377,7 @@ async def aset_by_hash( if expires_at is not None: ttl = max(0, int((expires_at - now).total_seconds())) - await self._cache.set(key, cached, ttl=ttl) + await self._backend.set(key, cached, ttl=ttl) self._track_key(provider, key) logger.debug(f"Cached upload: {file_id} for provider {provider}") return cached @@ -316,9 +394,7 @@ async def aremove(self, file: FileInput, provider: ProviderType) -> bool: """ file_hash = _compute_file_hash(file) key = _make_key(file_hash, provider) - - result = await self._cache.delete(key) - removed = bool(result > 0 if isinstance(result, int) else result) + removed = await self._backend.delete(key) if removed: self._untrack_key(provider, key) return removed @@ -335,11 +411,10 @@ async def aremove_by_file_id(self, file_id: str, provider: ProviderType) -> bool """ if provider not in self._provider_keys: return False - for key in list(self._provider_keys[provider]): - cached = await self._cache.get(key) - if isinstance(cached, CachedUpload) and cached.file_id == file_id: - await self._cache.delete(key) + cached = await self._backend.get(key) + if cached is not None and cached.file_id == file_id: + await self._backend.delete(key) self._untrack_key(provider, key) return True return False @@ -351,17 +426,13 @@ async def aclear_expired(self) -> int: Number of entries removed. 
""" removed = 0 - for provider, keys in list(self._provider_keys.items()): for key in list(keys): - cached = await self._cache.get(key) - if cached is None or ( - isinstance(cached, CachedUpload) and cached.is_expired() - ): - await self._cache.delete(key) + cached = await self._backend.get(key) + if cached is None or cached.is_expired(): + await self._backend.delete(key) self._untrack_key(provider, key) removed += 1 - if removed > 0: logger.debug(f"Cleared {removed} expired cache entries") return removed @@ -373,9 +444,12 @@ async def aclear(self) -> int: Number of entries cleared. """ count = sum(len(keys) for keys in self._provider_keys.values()) - await self._cache.clear(namespace=self.namespace) + # Delete all tracked keys individually (works for all backends) + for keys in self._provider_keys.values(): + for key in keys: + await self._backend.delete(key) self._provider_keys.clear() - + self._key_access_order.clear() if count > 0: logger.debug(f"Cleared {count} cache entries") return count @@ -391,14 +465,17 @@ async def aget_all_for_provider(self, provider: ProviderType) -> list[CachedUplo """ if provider not in self._provider_keys: return [] - results: list[CachedUpload] = [] for key in list(self._provider_keys[provider]): - cached = await self._cache.get(key) - if isinstance(cached, CachedUpload) and not cached.is_expired(): + cached = await self._backend.get(key) + if cached is not None and not cached.is_expired(): results.append(cached) return results + # ------------------------------------------------------------------ + # Sync wrappers + # ------------------------------------------------------------------ + @staticmethod def _run_sync(coro: Any) -> Any: """Run an async coroutine from sync context without blocking event loop.""" @@ -489,11 +566,7 @@ def __len__(self) -> int: return sum(len(keys) for keys in self._provider_keys.values()) def get_providers(self) -> builtins.set[ProviderType]: - """Get all provider names that have cached entries. 
- - Returns: - Set of provider names. - """ + """Get all provider names that have cached entries.""" return builtins.set(self._provider_keys.keys()) @@ -506,17 +579,7 @@ def get_upload_cache( cache_type: str = "memory", **cache_kwargs: Any, ) -> UploadCache: - """Get or create the default upload cache. - - Args: - ttl: Default TTL in seconds. - namespace: Cache namespace. - cache_type: Backend type ("memory" or "redis"). - **cache_kwargs: Additional args for cache backend. - - Returns: - The upload cache instance. - """ + """Get or create the default upload cache.""" global _default_cache if _default_cache is None: _default_cache = UploadCache( diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index fc48e6661d..ff40b459f6 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -111,6 +111,9 @@ file-processing = [ qdrant-edge = [ "qdrant-edge-py>=0.6.0", ] +valkey = [ + "valkey-glide>=1.3.0", +] [project.scripts] diff --git a/lib/crewai/src/crewai/a2a/utils/agent_card.py b/lib/crewai/src/crewai/a2a/utils/agent_card.py index df5886988e..d3a47e2fef 100644 --- a/lib/crewai/src/crewai/a2a/utils/agent_card.py +++ b/lib/crewai/src/crewai/a2a/utils/agent_card.py @@ -13,8 +13,12 @@ from typing import TYPE_CHECKING from a2a.client.errors import A2AClientHTTPError -from a2a.types import AgentCapabilities, AgentCard, AgentSkill -from aiocache import cached # type: ignore[import-untyped] +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, +) +from aiocache import cached, caches # type: ignore[import-untyped] from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] import httpx @@ -32,6 +36,7 @@ A2AAuthenticationFailedEvent, A2AConnectionErrorEvent, ) +from crewai.utilities.cache_config import get_aiocache_config if TYPE_CHECKING: @@ -40,6 +45,18 @@ from crewai.task import Task +_cache_configured = False + + +def _ensure_cache_configured() -> None: + """Configure aiocache on first use (lazy 
initialization).""" + global _cache_configured + if _cache_configured: + return + caches.set_config(get_aiocache_config()) + _cache_configured = True + + def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str: """Get TLS verify parameter from auth scheme. @@ -191,6 +208,7 @@ async def afetch_agent_card( else: auth_hash = _auth_store.compute_key("none", "") _auth_store.set(auth_hash, auth) + _ensure_cache_configured() agent_card: AgentCard = await _afetch_agent_card_cached( endpoint, auth_hash, timeout ) diff --git a/lib/crewai/src/crewai/a2a/utils/task.py b/lib/crewai/src/crewai/a2a/utils/task.py index 6af935bb35..478c5c5f8a 100644 --- a/lib/crewai/src/crewai/a2a/utils/task.py +++ b/lib/crewai/src/crewai/a2a/utils/task.py @@ -9,9 +9,8 @@ from functools import wraps import json import logging -import os +import threading from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast -from urllib.parse import urlparse from a2a.server.agent_execution import RequestContext from a2a.server.events import EventQueue @@ -38,7 +37,6 @@ from a2a.utils.errors import ServerError from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped] from pydantic import BaseModel -from typing_extensions import TypedDict from crewai.a2a.utils.agent_card import _get_server_config from crewai.a2a.utils.content_type import validate_message_parts @@ -50,12 +48,18 @@ A2AServerTaskStartedEvent, ) from crewai.task import Task +from crewai.utilities.cache_config import ( + get_aiocache_config, + parse_cache_url, + use_valkey_cache, +) from crewai.utilities.pydantic_schema_utils import create_model_from_schema if TYPE_CHECKING: from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry from crewai.agent import Agent + from crewai.memory.storage.valkey_cache import ValkeyCache logger = logging.getLogger(__name__) @@ -64,52 +68,49 @@ T = TypeVar("T") -class RedisCacheConfig(TypedDict, total=False): - """Configuration for aiocache 
Redis backend.""" +# --------------------------------------------------------------------------- +# Lazy cache initialisation +# --------------------------------------------------------------------------- - cache: str - endpoint: str - port: int - db: int - password: str +_task_cache: ValkeyCache | None = None +_cache_initialized = False +_cache_init_lock = threading.Lock() +# Configure aiocache at import time (matches upstream behaviour). +# This is safe — it only touches aiocache, no optional dependencies. +# The Valkey path is deferred to _ensure_task_cache() to avoid importing +# valkey-glide at module level (it may not be installed). +if not use_valkey_cache(): + caches.set_config(get_aiocache_config()) -def _parse_redis_url(url: str) -> RedisCacheConfig: - """Parse a Redis URL into aiocache configuration. - Args: - url: Redis connection URL (e.g., redis://localhost:6379/0). +def _ensure_task_cache() -> None: + """Initialise the Valkey task cache on first use (thread-safe). - Returns: - Configuration dict for aiocache.RedisCache. + For the aiocache path, configuration happens at module level above. + This function only needs to run for the Valkey path. 
""" - parsed = urlparse(url) - config: RedisCacheConfig = { - "cache": "aiocache.RedisCache", - "endpoint": parsed.hostname or "localhost", - "port": parsed.port or 6379, - } - if parsed.path and parsed.path != "/": - try: - config["db"] = int(parsed.path.lstrip("/")) - except ValueError: - pass - if parsed.password: - config["password"] = parsed.password - return config - + global _task_cache, _cache_initialized + if _cache_initialized: + return + + with _cache_init_lock: + if _cache_initialized: + return + + if use_valkey_cache(): + from crewai.memory.storage.valkey_cache import ValkeyCache + + conn = parse_cache_url() or {} + _task_cache = ValkeyCache( + host=conn.get("host", "localhost"), + port=conn.get("port", 6379), + db=conn.get("db", 0), + password=conn.get("password"), + default_ttl=3600, + ) -_redis_url = os.environ.get("REDIS_URL") - -caches.set_config( - { - "default": _parse_redis_url(_redis_url) - if _redis_url - else { - "cache": "aiocache.SimpleMemoryCache", - } - } -) + _cache_initialized = True def cancellable( @@ -130,6 +131,8 @@ def cancellable( @wraps(fn) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: """Wrap function with cancellation monitoring.""" + _ensure_task_cache() + context: RequestContext | None = None for arg in args: if isinstance(arg, RequestContext): @@ -142,10 +145,19 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return await fn(*args, **kwargs) task_id = context.task_id - cache = caches.get("default") - async def poll_for_cancel() -> bool: - """Poll cache for cancellation flag.""" + async def poll_for_cancel_valkey() -> bool: + """Poll ValkeyCache for cancellation flag.""" + while True: + if _task_cache is not None and await _task_cache.get( + f"cancel:{task_id}" + ): + return True + await asyncio.sleep(0.1) + + async def poll_for_cancel_aiocache() -> bool: + """Poll aiocache for cancellation flag.""" + cache = caches.get("default") while True: if await cache.get(f"cancel:{task_id}"): return True @@ 
-153,8 +165,14 @@ async def poll_for_cancel() -> bool: async def watch_for_cancel() -> bool: """Watch for cancellation events via pub/sub or polling.""" + if _task_cache is not None: + # ValkeyCache: use polling (pub/sub not implemented yet) + return await poll_for_cancel_valkey() + + # aiocache: use pub/sub if Redis, otherwise poll + cache = caches.get("default") if isinstance(cache, SimpleMemoryCache): - return await poll_for_cancel() + return await poll_for_cancel_aiocache() try: client = cache.client @@ -168,7 +186,7 @@ async def watch_for_cancel() -> bool: "Cancel watcher Redis error, falling back to polling", extra={"task_id": task_id, "error": str(e)}, ) - return await poll_for_cancel() + return await poll_for_cancel_aiocache() return False execute_task = asyncio.create_task(fn(*args, **kwargs)) @@ -190,7 +208,12 @@ async def watch_for_cancel() -> bool: cancel_watch.cancel() return execute_task.result() finally: - await cache.delete(f"cancel:{task_id}") + # Clean up cancellation flag + if _task_cache is not None: + await _task_cache.delete(f"cancel:{task_id}") + else: + cache = caches.get("default") + await cache.delete(f"cancel:{task_id}") return wrapper @@ -475,6 +498,8 @@ async def cancel( if task_id is None or context_id is None: raise ServerError(InvalidParamsError(message="task_id and context_id required")) + _ensure_task_cache() + if context.current_task and context.current_task.status.state in ( TaskState.completed, TaskState.failed, @@ -482,11 +507,16 @@ async def cancel( ): return context.current_task - cache = caches.get("default") - - await cache.set(f"cancel:{task_id}", True, ttl=3600) - if not isinstance(cache, SimpleMemoryCache): - await cache.client.publish(f"cancel:{task_id}", "cancel") + if _task_cache is not None: + # Use ValkeyCache + await _task_cache.set(f"cancel:{task_id}", True, ttl=3600) + # Note: pub/sub not implemented for ValkeyCache yet, relies on polling + else: + # Use aiocache + cache = caches.get("default") + await 
cache.set(f"cancel:{task_id}", True, ttl=3600) + if not isinstance(cache, SimpleMemoryCache): + await cache.client.publish(f"cancel:{task_id}", "cancel") await event_queue.enqueue_event( TaskStatusUpdateEvent( diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py index acd025d553..ac753d26d0 100644 --- a/lib/crewai/src/crewai/memory/encoding_flow.py +++ b/lib/crewai/src/crewai/memory/encoding_flow.py @@ -18,7 +18,7 @@ from typing import Any from uuid import uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from crewai.flow.flow import Flow, listen, start from crewai.memory.analyze import ( @@ -68,6 +68,31 @@ class ItemState(BaseModel): plan: ConsolidationPlan | None = None result_record: MemoryRecord | None = None + @field_validator("similar_records", "result_record", mode="before") + @classmethod + def ensure_embedding_is_list(cls, v: Any) -> Any: + """Ensure MemoryRecord embeddings are list[float], not bytes.""" + if v is None: + return None + if isinstance(v, list): + # Process list of MemoryRecords + for record in v: + if isinstance(record, MemoryRecord) and isinstance( + record.embedding, bytes + ): + import numpy as np + + arr = np.frombuffer(record.embedding, dtype=np.float32) + record.embedding = [float(x) for x in arr] + return v + if isinstance(v, MemoryRecord) and isinstance(v.embedding, bytes): + # Process single MemoryRecord + import numpy as np + + arr = np.frombuffer(v.embedding, dtype=np.float32) + v.embedding = [float(x) for x in arr] + return v + class EncodingState(BaseModel): """Batch-level state for the encoding flow.""" diff --git a/lib/crewai/src/crewai/memory/storage/valkey_cache.py b/lib/crewai/src/crewai/memory/storage/valkey_cache.py new file mode 100644 index 0000000000..b713655764 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/valkey_cache.py @@ -0,0 +1,189 @@ +"""Valkey-based cache implementation for CrewAI. 
+ +This module provides a simple cache interface using Valkey-GLIDE client +for caching operations with optional TTL support. It replaces Redis usage +in A2A communication, file uploads, and agent card caching. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from glide import GlideClient, GlideClientConfiguration, NodeAddress + + +_logger = logging.getLogger(__name__) + + +class ValkeyCache: + """Simple cache interface using Valkey-GLIDE client. + + Provides get/set/delete/exists operations for caching with optional TTL. + Uses JSON serialization for complex values and lazy client initialization. + + Example: + >>> cache = ValkeyCache(host="localhost", port=6379) + >>> await cache.set("key", {"data": "value"}, ttl=3600) + >>> value = await cache.get("key") + >>> await cache.delete("key") + """ + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: str | None = None, + default_ttl: int | None = None, + ) -> None: + """Initialize Valkey cache. + + Args: + host: Valkey server hostname. + port: Valkey server port. + db: Database number to use. + password: Optional password for authentication. + default_ttl: Default TTL in seconds (None = no expiration). + """ + self._host = host + self._port = port + self._db = db + self._password = password + self._default_ttl = default_ttl + self._client: GlideClient | None = None + + async def _get_client(self) -> GlideClient: + """Get or create Valkey client (lazy initialization). + + Returns: + Initialized GlideClient instance. + + Raises: + RuntimeError: If connection to Valkey fails. + TimeoutError: If connection attempt times out (10 seconds). 
+ """ + import asyncio + + if self._client is None: + host = self._host + port = self._port + db = self._db + try: + from glide import ServerCredentials + + config = GlideClientConfiguration( + addresses=[NodeAddress(host, port)], + database_id=db, + credentials=( + ServerCredentials(password=self._password) + if self._password + else None + ), + ) + + # Add connection timeout (10 seconds) + try: + self._client = await asyncio.wait_for( + GlideClient.create(config), timeout=10.0 + ) + except asyncio.TimeoutError as e: + _logger.error("Connection timeout connecting to Valkey") + raise TimeoutError( + "Connection timeout to Valkey. " + "Ensure Valkey is running and accessible." + ) from e + + _logger.info("Valkey cache client initialized") + except (TimeoutError, RuntimeError): + raise + except Exception as e: + _logger.error( + "Failed to create Valkey cache client: %s", type(e).__name__ + ) + raise RuntimeError( + "Cannot connect to Valkey. Check connection settings." + ) from e + + return self._client + + async def get(self, key: str) -> Any | None: + """Get value from cache. + + Args: + key: Cache key. + + Returns: + Cached value (deserialized from JSON) or None if not found. + """ + client = await self._get_client() + value = await client.get(key) + + if value is None: + return None + + try: + return json.loads(value) + except json.JSONDecodeError: + _logger.warning(f"Failed to deserialize cached value for key: {key}") + return None + + async def set( + self, + key: str, + value: Any, + ttl: int | None = None, + ) -> None: + """Set value in cache. + + Args: + key: Cache key. + value: Value to cache (will be serialized to JSON). + ttl: TTL in seconds (None uses default_ttl, 0 = no expiration). 
+ """ + from glide import ExpirySet, ExpiryType + + client = await self._get_client() + serialized = json.dumps(value) + + ttl_to_use = ttl if ttl is not None else self._default_ttl + + if ttl_to_use and ttl_to_use > 0: + # Set with expiration using SET command with EX option + await client.set( + key, + serialized, + expiry=ExpirySet(ExpiryType.SEC, ttl_to_use), + ) + else: + await client.set(key, serialized) + + async def delete(self, key: str) -> None: + """Delete value from cache. + + Args: + key: Cache key to delete. + """ + client = await self._get_client() + await client.delete([key]) + + async def exists(self, key: str) -> bool: + """Check if key exists in cache. + + Args: + key: Cache key to check. + + Returns: + True if key exists, False otherwise. + """ + client = await self._get_client() + result = await client.exists([key]) + return result > 0 + + async def close(self) -> None: + """Close Valkey client connection.""" + if self._client: + await self._client.close() + self._client = None + _logger.debug("Valkey cache client closed") diff --git a/lib/crewai/src/crewai/memory/storage/valkey_storage.py b/lib/crewai/src/crewai/memory/storage/valkey_storage.py new file mode 100644 index 0000000000..9cc05effb9 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/valkey_storage.py @@ -0,0 +1,1838 @@ +"""Valkey-backed storage for the unified memory system. + +This module provides ValkeyStorage, a distributed storage backend that implements +the StorageBackend protocol using Valkey-GLIDE as the underlying data store. +It supports vector similarity search via Valkey Search module and provides +efficient indexing for scope, category, and metadata filtering. 
+""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +from datetime import datetime +import json +import logging +import threading +from typing import Any + +from glide import ( + BackoffStrategy, + ConfigurationError, + ConnectionError, + DataType, + DistanceMetricType, + Field, + FtCreateOptions, + FtSearchLimit, + FtSearchOptions, + GlideClient, + GlideClientConfiguration, + NodeAddress, + NumericField, + RangeByIndex, + RangeByScore, + ReturnField, + ScoreBoundary, + ServerCredentials, + TagField, + VectorAlgorithm, + VectorField, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + VectorType, + ft, +) +import numpy as np + +from crewai.memory.types import MemoryRecord, ScopeInfo + + +_logger = logging.getLogger(__name__) + + +class ValkeyStorage: + """Valkey-backed storage for the unified memory system. + + Provides distributed, high-performance storage using Valkey-GLIDE client. + Implements the StorageBackend protocol with both sync and async methods. + + This implementation supports standalone Valkey mode only. Cluster mode is + not supported in this version. + + Example: + >>> storage = ValkeyStorage(host="localhost", port=6379) + >>> record = MemoryRecord(content="test", embedding=[0.1, 0.2]) + >>> storage.save([record]) + >>> retrieved = storage.get_record(record.id) + """ + + # ------------------------------------------------------------------ + # Key helpers — single source of truth for Valkey key patterns. + # Note: dynamic parts (scope, category, metadata values) are not + # encoded here because Valkey keys are opaque byte strings and the + # ':' delimiter is only meaningful to our own code. If cross-tenant + # isolation is required, callers should validate inputs before + # passing them to the storage layer. 
    # ------------------------------------------------------------------

    @staticmethod
    def _record_key(record_id: str) -> str:
        # Hash key holding every field of one record (auto-indexed by Valkey Search).
        return f"record:{record_id}"

    @staticmethod
    def _scope_key(scope: str) -> str:
        # Sorted-set key; members are record IDs scored by created_at timestamp.
        return f"scope:{scope}"

    @staticmethod
    def _category_key(category: str) -> str:
        # Plain set of record IDs carrying this category.
        return f"category:{category}"

    @staticmethod
    def _metadata_key(key: str, value: str) -> str:
        # Plain set of record IDs whose metadata contains key == value.
        return f"metadata:{key}:{value}"

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        use_tls: bool = False,
        tls_ca_cert_path: str | None = None,
        tls_client_cert_path: str | None = None,
        tls_client_key_path: str | None = None,
        vector_dim: int = 1536,
        index_algorithm: str = "HNSW",
    ) -> None:
        """Initialize Valkey storage with connection parameters and vector index config.

        Note: This implementation supports standalone Valkey mode only.
        Cluster mode is not supported in this version.

        TLS Support: Basic TLS encryption is supported via ``use_tls=True``.
        Custom CA and client certificates are defined in the GLIDE Rust core
        but not yet exposed in the Python bindings. The ``tls_*_path``
        parameters are accepted for forward compatibility and will be wired
        in once the Python client adds support.

        Args:
            host: Valkey server hostname.
            port: Valkey server port.
            db: Database number to use (standalone mode only).
            password: Optional password for authentication.
            use_tls: Enable TLS/SSL encryption for connections.
            tls_ca_cert_path: Path to CA certificate (forward-compat, not yet wired).
            tls_client_cert_path: Path to client certificate (forward-compat, not yet wired).
            tls_client_key_path: Path to client key (forward-compat, not yet wired).
            vector_dim: Dimension of embedding vectors (default 1536 for OpenAI).
            index_algorithm: Vector index algorithm ("HNSW" or "FLAT").

        Raises:
            NotImplementedError: If any ``tls_*_path`` argument is provided.
        """
        self._host = host
        self._port = port
        self._db = db
        self._password = password
        self._use_tls = use_tls
        self._vector_dim = vector_dim
        self._index_algorithm = index_algorithm
        self._client: GlideClient | None = None
        self._index_created = False
        self._sync_lock = threading.Lock()

        # Reject TLS cert paths until the GLIDE Python client exposes them
        if tls_ca_cert_path or tls_client_cert_path or tls_client_key_path:
            raise NotImplementedError(
                "Custom TLS certificates are not yet supported by the "
                "valkey-glide Python client. Use use_tls=True for basic "
                "TLS with system CA certificates."
            )

        # Write lock for compatibility with memory system
        # Note: Valkey handles concurrency at the server level, so this is a no-op lock
        self._write_lock = threading.RLock()

    async def _get_client(self) -> GlideClient:
        """Get or create Valkey client with lazy initialization.

        Returns:
            Initialized GlideClient instance.

        Raises:
            RuntimeError: If connection to Valkey fails.
            TimeoutError: If connection attempt times out (10 seconds).
        """
        if self._client is None:
            try:
                # Build node address with explicit host and port
                node = NodeAddress(host=self._host, port=self._port)

                # Build configuration
                config = GlideClientConfiguration(
                    addresses=[node],
                    database_id=self._db,
                    use_tls=self._use_tls,
                    credentials=(
                        ServerCredentials(password=self._password)
                        if self._password
                        else None
                    ),
                    request_timeout=2000,  # 2 seconds for FT.SEARCH and other commands
                    reconnect_strategy=BackoffStrategy(
                        num_of_retries=5,
                        factor=200,  # milliseconds
                        exponent_base=2,
                    ),
                )

                # Add connection timeout (10 seconds)
                try:
                    self._client = await asyncio.wait_for(
                        GlideClient.create(config), timeout=10.0
                    )
                except asyncio.TimeoutError as e:
                    _logger.error(
                        f"Connection timeout after 10 seconds to Valkey at {self._host}:{self._port}"
                    )
                    # Not caught by the outer except below (it only catches
                    # glide Configuration/Connection errors), so this
                    # propagates to the caller as TimeoutError.
                    raise TimeoutError(
                        f"Connection timeout to Valkey at {self._host}:{self._port}. "
                        "Ensure Valkey is running and accessible."
                    ) from e

                _logger.info(
                    f"Connected to Valkey at {self._host}:{self._port} (db={self._db}, tls={self._use_tls})"
                )

            except (ConfigurationError, ConnectionError) as e:
                # NOTE: this is glide's ConnectionError (imported at module
                # top), which shadows the builtin of the same name.
                _logger.error(f"Failed to create Valkey client: {e}")
                raise RuntimeError(
                    f"Cannot connect to Valkey at {self._host}:{self._port}"
                ) from e

        return self._client

    @property
    def write_lock(self) -> threading.RLock:
        """Write lock for compatibility with memory system.

        Note: Valkey handles concurrency at the server level with atomic operations,
        so this lock is primarily for API compatibility with other storage backends.
        """
        return self._write_lock

    def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
        """Bridge async operations to sync context.

        Uses a dedicated background thread with a persistent event loop so the
        Valkey client (and its TCP connection) can be reused across calls.

        Concurrent sync callers are serialized via a lock to avoid overloading
        the single-threaded background event loop (e.g. when the encoding flow
        dispatches parallel searches from a ThreadPoolExecutor).

        Args:
            coro: Coroutine to execute.

        Returns:
            Result of the coroutine execution.
        """
        # The lock is held across the whole round-trip, so at most one sync
        # call per instance is in flight on the shared loop at any time.
        with self._sync_lock:
            bg_loop = self._get_or_create_loop()
            future = asyncio.run_coroutine_threadsafe(coro, bg_loop)
            return future.result()

    # ------------------------------------------------------------------
    # Persistent event-loop helpers
    # ------------------------------------------------------------------
    # Class-level: a single background event loop shared by ALL ValkeyStorage
    # instances. This is intentional — the loop is just an I/O scheduler and
    # the glide client handles per-connection state internally.
    # _bg_lock guards loop creation; _sync_lock (instance-level, set in
    # __init__) serialises sync callers so they don't flood the loop.
    # ------------------------------------------------------------------
    _bg_loop: asyncio.AbstractEventLoop | None = None
    _bg_thread: threading.Thread | None = None
    _bg_lock: threading.Lock = threading.Lock()

    @classmethod
    def _get_or_create_loop(cls) -> asyncio.AbstractEventLoop:
        """Return a long-lived event loop running on a background daemon thread."""
        if cls._bg_loop is not None and cls._bg_loop.is_running():
            return cls._bg_loop

        with cls._bg_lock:
            # Double-check after acquiring lock
            if cls._bg_loop is not None and cls._bg_loop.is_running():
                return cls._bg_loop

            # NOTE(review): a freshly started loop may briefly report
            # is_running() == False before run_forever begins, which could
            # let a second caller create a replacement loop and leak the
            # first — confirm whether this window matters in practice.
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=loop.run_forever, daemon=True, name="valkey-io"
            )
            thread.start()
            cls._bg_loop = loop
            cls._bg_thread = thread
            return loop

    async def __aenter__(self) -> ValkeyStorage:
        """Async context manager entry."""
        await self._get_client()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        if self._client:
            await self._client.close()
            self._client = None

    def __del__(self) -> None:
        """Cleanup client connection on deletion."""
        if self._client:
            try:
                bg_loop = type(self)._bg_loop
                if bg_loop is not None and bg_loop.is_running():
                    # Schedule close on the background loop — more reliable than
                    # create_task which can be GC'd before it runs.
                    asyncio.run_coroutine_threadsafe(self._client.close(), bg_loop)
                else:
                    close_result = self._client.close()
                    if asyncio.iscoroutine(close_result):
                        asyncio.run(close_result)
            except Exception as e:
                # Best-effort cleanup; never raise from __del__.
                _logger.debug(f"Error closing client during cleanup: {e}")

    def _embedding_to_bytes(self, embedding: list[float]) -> bytes:
        """Convert embedding list to binary format for Valkey Search.

        Args:
            embedding: List of floats representing the embedding vector.

        Returns:
            Binary representation as float32 array.
        """
        return np.array(embedding, dtype=np.float32).tobytes()

    def _bytes_to_embedding(self, data: bytes) -> list[float]:
        """Convert binary format back to embedding list.

        Args:
            data: Binary data from Valkey.

        Returns:
            List of floats representing the embedding vector.
        """
        arr = np.frombuffer(data, dtype=np.float32)
        return [float(x) for x in arr]

    def _record_to_dict(self, record: MemoryRecord) -> dict[str, str | bytes]:
        """Convert MemoryRecord to Valkey hash fields.

        Args:
            record: Memory record to serialize.

        Returns:
            Dictionary of field names to string/bytes values.

        Raises:
            ValueError: If serialization fails for any field.
        """
        try:
            result: dict[str, str | bytes] = {
                "id": record.id,
                "content": record.content,
                "scope": record.scope,
                "categories": ",".join(record.categories)
                if record.categories
                else "",  # TAG field format
                "metadata": json.dumps(record.metadata),
                "importance": str(record.importance),
                "created_at": record.created_at.isoformat(),
                "last_accessed": record.last_accessed.isoformat(),
                "source": record.source or "",
                "private": "true" if record.private else "false",
            }

            # Add embedding as binary vector field if present
            if record.embedding:
                result["embedding"] = self._embedding_to_bytes(record.embedding)
            else:
                result["embedding"] = b""  # Empty bytes for no embedding

            return result
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize record {record.id}: {e}") from e

    def _dict_to_record(
        self, data: dict[str, Any] | dict[bytes, bytes]
    ) -> MemoryRecord | None:
        """Convert Valkey hash fields to MemoryRecord.

        Args:
            data: Dictionary of field names to values from Valkey (may be bytes or str keys/values).

        Returns:
            Reconstructed MemoryRecord, or None if deserialization fails.
        """
        try:
            # Convert bytes keys/values to strings if needed
            str_data: dict[str, Any] = {}
            for key, value in data.items():
                str_key = key.decode("utf-8") if isinstance(key, bytes) else key

                # Handle value conversion - keep embedding as bytes
                if isinstance(value, bytes):
                    if str_key == "embedding":
                        # Keep embedding as bytes - don't try to decode
                        str_data[str_key] = value
                    else:
                        # Try to decode other fields as UTF-8
                        try:
                            str_data[str_key] = value.decode("utf-8")
                        except UnicodeDecodeError:
                            # Keep as bytes if decode fails
                            str_data[str_key] = value
                else:
                    str_data[str_key] = value

            # Deserialize embedding if present
            embedding: list[float] | None = None
            embedding_data = str_data.get("embedding")
            if embedding_data:
                if isinstance(embedding_data, bytes):
                    if len(embedding_data) > 0:
                        embedding = self._bytes_to_embedding(embedding_data)
                    # else: empty bytes, leave embedding as None
                elif isinstance(embedding_data, str) and embedding_data:
                    # Fallback for string representation
                    try:
                        embedding = json.loads(embedding_data)
                    except json.JSONDecodeError:
                        # Invalid JSON, leave as None
                        pass

            # Parse categories - handle both TAG format (comma-separated) and JSON format
            categories_str = str_data.get("categories", "")
            if categories_str:
                if categories_str.startswith("["):
                    # JSON format (legacy)
                    categories = json.loads(categories_str)
                else:
                    # TAG format (comma-separated)
                    categories = [
                        c.strip() for c in categories_str.split(",") if c.strip()
                    ]
            else:
                categories = []

            return MemoryRecord(
                id=str_data["id"],
                content=str_data["content"],
                scope=str_data["scope"],
                categories=categories,
                metadata=json.loads(str_data["metadata"]),
                importance=float(str_data["importance"]),
                created_at=datetime.fromisoformat(str_data["created_at"]),
                last_accessed=datetime.fromisoformat(str_data["last_accessed"]),
                embedding=embedding,
                source=str_data.get("source") or None,
                private=str_data.get("private", "false").lower() == "true",
            )
        except (KeyError, ValueError, TypeError) as e:
            # Try to get ID from data for error logging
            record_id = "unknown"
            try:
                if data:
                    # Try both bytes and str keys
                    id_value = data.get(b"id") if b"id" in data else data.get("id")  # type: ignore[call-overload]
                    if id_value:
                        record_id = (
                            id_value.decode("utf-8")
                            if isinstance(id_value, bytes)
                            else str(id_value)
                        )
            except Exception as id_error:
                _logger.debug(
                    f"Could not extract record ID for error logging: {id_error}"
                )
            _logger.error(f"Failed to deserialize record {record_id}: {e}")
            return None

    async def _ensure_vector_index(self) -> None:
        """Create Valkey Search vector index if it doesn't exist.

        Creates an index named 'memory_index' on record:* hashes with:
        - Vector field for embeddings (HNSW or FLAT algorithm)
        - TAG fields for scope and categories
        - NUMERIC fields for created_at and importance

        Raises:
            RuntimeError: If Valkey Search module is not available.
+ """ + if self._index_created: + return + + client = await self._get_client() + + try: + # Check if index already exists + existing = await ft.list(client) + names = { + i.decode("utf-8") if isinstance(i, bytes) else str(i) + for i in (existing or []) + } + if "memory_index" in names: + _logger.debug("Vector index 'memory_index' already exists") + self._index_created = True + return + except Exception as e: + _logger.debug("Could not list indexes, will attempt create: %s", e) + + try: + # Build vector field attributes using the concrete subclass + vector_attrs: VectorFieldAttributesHnsw | VectorFieldAttributesFlat + if self._index_algorithm == "HNSW": + algorithm = VectorAlgorithm.HNSW + vector_attrs = VectorFieldAttributesHnsw( + dimensions=self._vector_dim, + distance_metric=DistanceMetricType.COSINE, + type=VectorType.FLOAT32, + ) + else: + algorithm = VectorAlgorithm.FLAT + vector_attrs = VectorFieldAttributesFlat( + dimensions=self._vector_dim, + distance_metric=DistanceMetricType.COSINE, + type=VectorType.FLOAT32, + ) + + # Build schema + schema: list[Field] = [ + VectorField("embedding", algorithm, vector_attrs), + TagField("scope"), + TagField("categories", separator=","), + NumericField("created_at"), + NumericField("importance"), + ] + + # Create index using native ft.create + options = FtCreateOptions(DataType.HASH, prefixes=["record:"]) + await ft.create(client, "memory_index", schema, options) + + _logger.info( + "Created vector index 'memory_index' with %s algorithm (dim=%d)", + self._index_algorithm, + self._vector_dim, + ) + self._index_created = True + + except Exception as e: + error_msg = str(e).lower() + if "unknown command" in error_msg or "ft.create" in error_msg: + raise RuntimeError( + "Valkey Search module is not available. " + "Please ensure Valkey is running with the Search module loaded. " + "Use 'valkey/valkey-bundle:latest' Docker image or install the module separately." 
+ ) from e + raise RuntimeError(f"Failed to create vector index: {e}") from e + + async def _update_indexes( + self, + record_id: str, + scope: str, + categories: list[str], + metadata: dict[str, Any], + timestamp: float, + ) -> None: + """Update all index structures for a record. + + Adds record ID to: + - Scope sorted set with timestamp score + - Category sets for all categories + - Metadata index sets for all metadata key-value pairs + + Args: + record_id: Unique identifier of the record. + scope: Hierarchical scope path (e.g., "/agent/task"). + categories: List of category names. + metadata: Dictionary of metadata key-value pairs. + timestamp: Unix timestamp for scope index score. + """ + client = await self._get_client() + + # Update scope index (sorted set with timestamp score) + # Handle root scope "/" as special case + scope_key = self._scope_key(scope) + await client.zadd(scope_key, {record_id: timestamp}) + + # Update category indexes (sets) + for category in categories: + category_key = self._category_key(category) + await client.sadd(category_key, [record_id]) + + # Update metadata indexes (sets for each key-value pair) + for key, value in metadata.items(): + # Convert value to string for consistent key naming + value_str = str(value) + metadata_key = self._metadata_key(key, value_str) + await client.sadd(metadata_key, [record_id]) + + async def _remove_from_indexes( + self, + record_id: str, + scope: str, + categories: list[str], + metadata: dict[str, Any], + ) -> None: + """Remove record from all index structures. + + Removes record ID from: + - Scope sorted set + - All category sets + - All metadata index sets + + Args: + record_id: Unique identifier of the record. + scope: Hierarchical scope path. + categories: List of category names. + metadata: Dictionary of metadata key-value pairs. 
+ """ + client = await self._get_client() + + # Remove from scope index + scope_key = self._scope_key(scope) + await client.zrem(scope_key, [record_id]) + + # Remove from category indexes + for category in categories: + category_key = self._category_key(category) + await client.srem(category_key, [record_id]) + + # Remove from metadata indexes + for key, value in metadata.items(): + value_str = str(value) + metadata_key = self._metadata_key(key, value_str) + await client.srem(metadata_key, [record_id]) + + async def asave(self, records: list[MemoryRecord]) -> None: + """Save multiple records as a batch. + + Stores record fields in hash structure with key pattern "record:{id}". + Stores embedding as binary vector field in record hash for Valkey Search auto-indexing. + Updates scope sorted set, category sets, and metadata index sets. + + Note: + Operations are issued as individual commands, not wrapped in + MULTI/EXEC. Partial failures are possible under network errors. + + Args: + records: List of memory records to save. + + Raises: + ValueError: If serialization fails for any record. + RuntimeError: If Valkey connection fails. + """ + if not records: + return + + client = await self._get_client() + + # Ensure vector index exists before saving + await self._ensure_vector_index() + + # Build commands for atomic batch execution + for record in records: + record_key = self._record_key(record.id) + + # Convert record to hash fields (includes embedding as bytes) + record_dict = self._record_to_dict(record) + + # Store record hash (Valkey Search will auto-index it) + await client.hset( + record_key, + record_dict, # type: ignore[arg-type] # str keys are valid str|bytes + ) + + # Update all index structures + timestamp = record.created_at.timestamp() + await self._update_indexes( + record.id, + record.scope, + record.categories, + record.metadata, + timestamp, + ) + + def save(self, records: list[MemoryRecord]) -> None: + """Save multiple records atomically (sync wrapper). 
+ + Args: + records: List of memory records to save. + + Raises: + ValueError: If serialization fails for any record. + RuntimeError: If Valkey connection fails or called from async context. + """ + self._run_async(self.asave(records)) + + def get_record(self, record_id: str) -> MemoryRecord | None: + """Retrieve record by ID. + + Fetches record hash from "record:{id}" key and deserializes all fields + including datetime, JSON, and boolean values. + + Args: + record_id: Unique identifier of the record to retrieve. + + Returns: + MemoryRecord if found, None if record doesn't exist or deserialization fails. + """ + result: MemoryRecord | None = self._run_async(self._aget_record(record_id)) + return result + + async def _aget_record(self, record_id: str) -> MemoryRecord | None: + """Retrieve record by ID (async implementation). + + Args: + record_id: Unique identifier of the record to retrieve. + + Returns: + MemoryRecord if found, None if record doesn't exist or deserialization fails. + """ + client = await self._get_client() + record_key = self._record_key(record_id) + + try: + # Fetch all fields from record hash + data = await client.hgetall(record_key) + + if not data: + # Record doesn't exist + return None + + # Deserialize to MemoryRecord + return self._dict_to_record(data) + + except Exception as e: + _logger.error(f"Error retrieving record {record_id}: {e}") + return None + + def update(self, record: MemoryRecord) -> None: + """Update existing record or create new one. + + Preserves created_at timestamp from original record if it exists. + Updates last_accessed timestamp to current time. + Removes record from old indexes and adds to new indexes atomically. + + Args: + record: Memory record to update. + + Raises: + ValueError: If serialization fails. + RuntimeError: If Valkey connection fails or called from async context. 
+ """ + self._run_async(self._aupdate(record)) + + async def _aupdate(self, record: MemoryRecord) -> None: + """Update existing record or create new one (async implementation). + + Args: + record: Memory record to update. + """ + client = await self._get_client() + record_key = self._record_key(record.id) + + # Fetch existing record to preserve created_at and get old index values + existing_data = await client.hgetall(record_key) + + if existing_data: + # Convert bytes to strings for parsing (skip embedding which is binary) + str_data: dict[str, str] = {} + for key, value in existing_data.items(): + str_key = key.decode("utf-8") if isinstance(key, bytes) else key + # Skip embedding field - it's binary data, not UTF-8 + if str_key == "embedding": + continue + # Handle other binary fields gracefully + if isinstance(value, bytes): + try: + str_value = value.decode("utf-8") + except UnicodeDecodeError: + continue # Skip fields that can't be decoded + else: + str_value = value + str_data[str_key] = str_value + + # Preserve created_at from existing record + try: + original_created_at = datetime.fromisoformat(str_data["created_at"]) + record.created_at = original_created_at + except (KeyError, ValueError) as e: + _logger.warning( + f"Could not preserve created_at for record {record.id}: {e}" + ) + + # Update last_accessed to current time + record.last_accessed = datetime.now() + + # Parse old values for index cleanup + try: + old_scope = str_data.get("scope", "") + # Handle both TAG format (comma-separated) and JSON format (legacy) + categories_str = str_data.get("categories", "") + if categories_str.startswith("["): + old_categories = json.loads(categories_str) + else: + old_categories = [ + c.strip() for c in categories_str.split(",") if c.strip() + ] + old_metadata = json.loads(str_data.get("metadata", "{}")) + except (json.JSONDecodeError, ValueError) as e: + _logger.warning( + f"Could not parse old index values for record {record.id}: {e}" + ) + old_scope = "" + 
old_categories = [] + old_metadata = {} + + # Remove from old indexes + await self._remove_from_indexes( + record.id, old_scope, old_categories, old_metadata + ) + + # Convert record to hash fields + record_dict = self._record_to_dict(record) + + # Store updated record hash + await client.hset( + record_key, + record_dict, # type: ignore[arg-type] # str keys are valid str|bytes + ) + + # Add to new indexes + timestamp = record.created_at.timestamp() + await self._update_indexes( + record.id, record.scope, record.categories, record.metadata, timestamp + ) + + async def adelete( + self, + scope_prefix: str | None = None, + categories: list[str] | None = None, + record_ids: list[str] | None = None, + older_than: datetime | None = None, + metadata_filter: dict[str, Any] | None = None, + ) -> int: + """Delete records matching criteria. + + Supports deletion by record_ids, scope_prefix, categories, older_than, metadata_filter. + Multiple criteria are combined with AND logic. + + Note: + Operations are issued as individual commands, not wrapped in + MULTI/EXEC. Partial failures are possible under network errors. + + Args: + scope_prefix: Delete records in scope and subscopes. + categories: Delete records matching any of these categories. + record_ids: List of specific record IDs to delete. + older_than: Delete records created before this datetime. + metadata_filter: Delete records matching metadata key-value pairs. + + Returns: + Count of deleted records. + + Raises: + RuntimeError: If Valkey connection fails. 
+ """ + client = await self._get_client() + + # Step 1: Identify records to delete based on criteria + ids_to_delete: set[str] = set() + + # Filter by record_ids + if record_ids: + ids_to_delete.update(record_ids) + + # Filter by scope_prefix + if scope_prefix is not None: + scope_ids = await self._find_records_by_scope(scope_prefix) + if ids_to_delete: + ids_to_delete &= set(scope_ids) # AND logic + else: + ids_to_delete.update(scope_ids) + + # Filter by categories + if categories: + category_ids = await self._find_records_by_categories(categories) + if ids_to_delete: + ids_to_delete &= set(category_ids) # AND logic + else: + ids_to_delete.update(category_ids) + + # Filter by older_than + if older_than is not None: + old_ids = await self._find_records_older_than(older_than) + if ids_to_delete: + ids_to_delete &= set(old_ids) # AND logic + else: + ids_to_delete.update(old_ids) + + # Filter by metadata + if metadata_filter: + metadata_ids = await self._find_records_by_metadata(metadata_filter) + if ids_to_delete: + ids_to_delete &= set(metadata_ids) # AND logic + else: + ids_to_delete.update(metadata_ids) + + # If no criteria specified, delete nothing + if not ids_to_delete: + return 0 + + # Step 2: Fetch record data to identify which indexes to clean + records_data = await self._fetch_records_for_deletion(list(ids_to_delete)) + + # Step 3: Delete records and clean indexes + for record_id, data in records_data.items(): + record_key = self._record_key(record_id) + + # Delete record hash (Valkey Search auto-removes from vector index) + await client.delete([record_key]) + + # Remove from all index structures + await self._remove_from_indexes( + record_id, data["scope"], data["categories"], data["metadata"] + ) + + return len(records_data) + + def delete( + self, + scope_prefix: str | None = None, + categories: list[str] | None = None, + record_ids: list[str] | None = None, + older_than: datetime | None = None, + metadata_filter: dict[str, Any] | None = None, + ) -> 
int: + """Delete records matching criteria (sync wrapper). + + Args: + scope_prefix: Delete records in scope and subscopes. + categories: Delete records matching any of these categories. + record_ids: List of specific record IDs to delete. + older_than: Delete records created before this datetime. + metadata_filter: Delete records matching metadata key-value pairs. + + Returns: + Count of deleted records. + + Raises: + RuntimeError: If Valkey connection fails or called from async context. + """ + result: int = self._run_async( + self.adelete( + scope_prefix=scope_prefix, + categories=categories, + record_ids=record_ids, + older_than=older_than, + metadata_filter=metadata_filter, + ) + ) + return result + + async def _find_records_by_scope(self, scope_prefix: str) -> list[str]: + """Find all record IDs in scope and subscopes. + + Args: + scope_prefix: Scope path prefix to match. + + Returns: + List of record IDs in matching scopes. + """ + client = await self._get_client() + record_ids: set[str] = set() + + # Scan for all scope keys + cursor: str | bytes = "0" + while True: + result = await client.scan(cursor, match="scope:*", count=1000) + cursor_new: str | bytes = result[0] # type: ignore[assignment] + keys: list[bytes] = result[1] # type: ignore[assignment] + + for key_bytes in keys: + # Extract scope path from key + key_str = ( + key_bytes.decode("utf-8") + if isinstance(key_bytes, bytes) + else key_bytes + ) + scope_path = key_str.split(":", 1)[1] if ":" in key_str else "" + + # Check if scope matches prefix + if scope_path.startswith(scope_prefix): + # Get all record IDs in this scope + scope_key = ( + key_bytes.decode("utf-8") + if isinstance(key_bytes, bytes) + else key_bytes + ) + members_result = await client.zrange(scope_key, RangeByIndex(0, -1)) + # Convert bytes to strings + record_ids.update( + m.decode("utf-8") if isinstance(m, bytes) else str(m) + for m in members_result + ) + + # Check if cursor is 0 (scan complete) + cursor_str = ( + 
                cursor_new.decode("utf-8")
                if isinstance(cursor_new, bytes)
                else cursor_new
            )
            if cursor_str == "0":
                break
            cursor = cursor_new

        return list(record_ids)

    async def _find_records_by_categories(self, categories: list[str]) -> list[str]:
        """Find all record IDs matching any of the categories.

        Args:
            categories: List of category names.

        Returns:
            List of record IDs with any of the categories.
        """
        client = await self._get_client()
        record_ids: set[str] = set()

        for category in categories:
            category_key = self._category_key(category)
            members = await client.smembers(category_key)
            # Convert bytes to strings
            str_members = [
                m.decode("utf-8") if isinstance(m, bytes) else m for m in members
            ]
            record_ids.update(str_members)

        return list(record_ids)

    async def _find_records_older_than(self, older_than: datetime) -> list[str]:
        """Find all record IDs created before the specified datetime.

        Args:
            older_than: Datetime threshold.

        Returns:
            List of record IDs created before older_than.
        """
        client = await self._get_client()
        record_ids: set[str] = set()
        threshold = older_than.timestamp()

        # Scan all scope keys and filter by timestamp
        cursor: str | bytes = "0"
        while True:
            result = await client.scan(cursor, match="scope:*", count=1000)
            cursor_new: str | bytes = result[0]  # type: ignore[assignment]
            keys: list[bytes] = result[1]  # type: ignore[assignment]

            for key_bytes in keys:
                # Get records with score (timestamp) less than threshold
                scope_key = (
                    key_bytes.decode("utf-8")
                    if isinstance(key_bytes, bytes)
                    else key_bytes
                )
                # Score range starts at 0, i.e. the Unix epoch — timestamps
                # before 1970 would fall outside this range.
                members_result = await client.zrange(
                    scope_key,
                    RangeByScore(
                        ScoreBoundary(0),
                        ScoreBoundary(threshold),
                    ),
                )
                # Convert bytes to strings
                record_ids.update(
                    m.decode("utf-8") if isinstance(m, bytes) else str(m)
                    for m in members_result
                )

            # Check if cursor is 0 (scan complete)
            cursor_str = (
                cursor_new.decode("utf-8")
                if isinstance(cursor_new, bytes)
                else cursor_new
            )
            if cursor_str == "0":
                break
            cursor = cursor_new

        return list(record_ids)

    async def _find_records_by_metadata(
        self, metadata_filter: dict[str, Any]
    ) -> list[str]:
        """Find all record IDs matching all metadata criteria (AND logic).

        Args:
            metadata_filter: Dictionary of metadata key-value pairs.

        Returns:
            List of record IDs matching all metadata criteria.
        """
        client = await self._get_client()

        # Get record IDs for each metadata criterion
        metadata_sets: list[set[str]] = []
        for key, value in metadata_filter.items():
            value_str = str(value)
            metadata_key = self._metadata_key(key, value_str)
            members = await client.smembers(metadata_key)
            # Convert bytes to strings
            str_members = {
                m.decode("utf-8") if isinstance(m, bytes) else m for m in members
            }
            metadata_sets.append(str_members)

        # Compute intersection (AND logic)
        if not metadata_sets:
            return []

        result = metadata_sets[0]
        for s in metadata_sets[1:]:
            result &= s

        return list(result)

    async def _fetch_records_for_deletion(
        self, record_ids: list[str]
    ) -> dict[str, dict[str, Any]]:
        """Fetch record data needed for index cleanup.

        Args:
            record_ids: List of record IDs to fetch.

        Returns:
            Dictionary mapping record ID to parsed record data.
        """
        client = await self._get_client()
        records_data: dict[str, dict[str, Any]] = {}

        for record_id in record_ids:
            record_key = self._record_key(record_id)
            data = await client.hgetall(record_key)

            if data:
                # Convert bytes to strings (skip embedding which is binary)
                str_data: dict[str, str] = {}
                for key, value in data.items():
                    str_key = key.decode("utf-8") if isinstance(key, bytes) else key
                    # Skip embedding field - it's binary
                    if str_key == "embedding":
                        continue
                    # Handle other binary fields gracefully
                    if isinstance(value, bytes):
                        try:
                            str_value = value.decode("utf-8")
                        except UnicodeDecodeError:
                            continue  # Skip fields that can't be decoded
                    else:
                        str_value = value
                    str_data[str_key] = str_value

                # Parse categories and metadata for index cleanup
                try:
                    # Parse categories — handle both TAG (comma-separated) and JSON format
                    categories_str = str_data.get("categories", "")
                    if categories_str and categories_str.startswith("["):
                        categories = json.loads(categories_str)
                    elif categories_str:
                        categories = [
                            c.strip() for c in categories_str.split(",") if c.strip()
                        ]
                    else:
                        categories = []

                    parsed_data = {
                        "scope": str_data.get("scope", ""),
                        "categories": categories,
                        "metadata": json.loads(str_data.get("metadata", "{}"))
                        if str_data.get("metadata")
                        else {},
                    }
                    records_data[record_id] = parsed_data
                except (json.JSONDecodeError, ValueError) as e:
                    _logger.warning(
                        f"Could not parse record {record_id} for deletion: {e}"
                    )
                    # Still delete the record, just skip index cleanup
                    records_data[record_id] = {
                        "scope": "",
                        "categories": [],
                        "metadata": {},
                    }

        return records_data

    async def _vector_search(
        self,
        query_embedding: list[float],
        scope_prefix: str | None = None,
        categories: list[str] | None = None,
        metadata_filter: dict[str, Any] | None = None,
        limit: int = 10,
        min_score: float = 0.0,
    ) -> list[tuple[MemoryRecord, float]]:
        """Perform server-side vector search using Valkey Search.

        Uses FT.SEARCH command with KNN query for vector similarity.
        Applies filters for scope, categories, and metadata in the same query.

        Args:
            query_embedding: Embedding vector for the query.
            scope_prefix: Optional scope path prefix to filter results.
            categories: Optional list of categories (OR logic).
            metadata_filter: Optional metadata key-value pairs (AND logic).
            limit: Maximum number of results to return.
            min_score: Minimum similarity score threshold (0.0 to 1.0).

        Returns:
            List of (MemoryRecord, score) tuples ordered by descending score.

        Raises:
            RuntimeError: If Valkey Search module is not available.
+ """ + client = await self._get_client() + + # Ensure vector index exists + await self._ensure_vector_index() + + # Build query components + query_parts: list[str] = [] + + # Scope prefix filter + # Format: @scope:{prefix*} + if scope_prefix: + # Escape special characters in scope prefix + escaped_scope = self._escape_search_query(scope_prefix) + # For root scope "/", match everything + if scope_prefix == "/": + query_parts.append("*") + else: + query_parts.append(f"@scope:{{{escaped_scope}*}}") + + # Category filter (OR logic) + # Format: @categories:{cat1|cat2|cat3} + if categories: + # Escape each category and join with | + escaped_categories = [self._escape_search_query(cat) for cat in categories] + cat_query = "|".join(escaped_categories) + query_parts.append(f"@categories:{{{cat_query}}}") + + # Metadata filters (AND logic) + # Format: @{key}:{value} + if metadata_filter: + for key, value in metadata_filter.items(): + # Escape key and value + escaped_key = self._escape_search_query(key) + escaped_value = self._escape_search_query(str(value)) + query_parts.append(f"@{escaped_key}:{{{escaped_value}}}") + + # Combine filters + filter_query = " ".join(query_parts) if query_parts else "*" + + # Build KNN query with filters + # Format: (filter)=>[KNN limit @field $BLOB AS score] + # Note: Don't wrap single "*" in parentheses + if filter_query == "*": + query = f"{filter_query}=>[KNN {limit} @embedding $BLOB AS score]" + else: + query = f"({filter_query})=>[KNN {limit} @embedding $BLOB AS score]" + + # Prepare embedding blob for PARAMS + embedding_blob = self._embedding_to_bytes(query_embedding) + + # Build FT.SEARCH options + # Note: Vector search results are sorted by distance ascending (nearest first). + # We convert distance to similarity in _parse_search_result and re-sort descending. 
+ return_fields = [ + ReturnField(field_identifier="id"), + ReturnField(field_identifier="content"), + ReturnField(field_identifier="scope"), + ReturnField(field_identifier="categories"), + ReturnField(field_identifier="metadata"), + ReturnField(field_identifier="importance"), + ReturnField(field_identifier="created_at"), + ReturnField(field_identifier="last_accessed"), + ReturnField(field_identifier="source"), + ReturnField(field_identifier="private"), + ReturnField(field_identifier="score"), + ] + + search_options = FtSearchOptions( + return_fields=return_fields, + params={"BLOB": embedding_blob}, + limit=FtSearchLimit(0, limit), + ) + + try: + # Execute native ft.search + result = await ft.search(client, "memory_index", query, search_options) + + # Native ft.search returns: [count, {key1: {fields...}, key2: {fields...}}] + if not result or not isinstance(result, list) or len(result) < 1: + return [] + + # First element is total count + total_count_raw = result[0] + if isinstance(total_count_raw, (int, str)): + total_count = int(total_count_raw) if total_count_raw else 0 + else: + total_count = 0 + if total_count == 0: + return [] + + # Parse documents from dict format + records: list[tuple[MemoryRecord, float]] = [] + if len(result) > 1 and isinstance(result[1], dict): + docs_dict = result[1] + for doc_fields in docs_dict.values(): + field_dict = self._normalize_field_dict(doc_fields) + parsed = self._parse_search_result(field_dict, min_score) + if parsed is not None: + records.append(parsed) + + # Sort by score descending (should already be sorted, but ensure) + records.sort(key=lambda x: x[1], reverse=True) + + return records + + except Exception as e: + error_msg = str(e).lower() + if "unknown command" in error_msg or "ft.search" in error_msg: + raise RuntimeError( + "Valkey Search module is not available. " + "Please ensure Valkey is running with the Search module loaded." 
+ ) from e + _logger.error(f"Vector search failed: {e}") + raise + + # ------------------------------------------------------------------ + # Search result parsing helpers + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_field_dict(raw: dict[Any, Any]) -> dict[str, Any]: + """Convert a raw field dict (possibly bytes keys/values) to str keys. + + Embedding values are kept as bytes; all other bytes values are decoded + to UTF-8 (falling back to raw bytes on decode errors). + """ + out: dict[str, Any] = {} + for key, value in raw.items(): + str_key = key.decode("utf-8") if isinstance(key, bytes) else str(key) + if isinstance(value, bytes): + if str_key == "embedding": + out[str_key] = value + else: + try: + out[str_key] = value.decode("utf-8") + except UnicodeDecodeError: + out[str_key] = value + else: + out[str_key] = value + return out + + def _parse_search_result( + self, + field_dict: dict[str, Any], + min_score: float, + ) -> tuple[MemoryRecord, float] | None: + """Extract score, apply min_score filter, and deserialize a search hit. + + Score is converted from cosine distance ([0, 2]) to similarity ([0, 1]) + and clamped to that range. + + Returns: + (MemoryRecord, score) or None if filtered out or deserialization fails. + """ + # Extract score — Valkey Search returns cosine distance + score = 0.0 + for score_key in ("__score", "score"): + if score_key in field_dict: + distance = float(field_dict[score_key]) + score = max(0.0, min(1.0, 1.0 - (distance / 2.0))) + break + + if score < min_score: + return None + + record = self._dict_to_record(field_dict) + if record is None: + return None + return (record, score) + + def _escape_search_query(self, text: str) -> str: + """Escape special characters in Valkey Search query. + + Valkey Search uses special characters: , . < > { } [ ] " ' : ; ! @ # $ % ^ & * ( ) - + = ~ | + + Args: + text: Text to escape. + + Returns: + Escaped text safe for use in search queries. 
+ """ + # Characters that need escaping in Valkey Search queries. + # Note: both '=' and '>' are escaped individually, so the KNN + # clause delimiter '=>' becomes '\=\>' and cannot be injected. + special_chars = r",.<>{}[]\"':;!@#$%^&*()-+=~|" + for char in special_chars: + text = text.replace(char, f"\\{char}") + return text + + async def asearch( + self, + query_embedding: list[float], + scope_prefix: str | None = None, + categories: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + limit: int = 10, + min_score: float = 0.0, + ) -> list[tuple[MemoryRecord, float]]: + """Search for memories by vector similarity (async). + + Uses Valkey Search module for server-side vector similarity computation. + Applies filters for scope, categories, and metadata in the same query. + + Args: + query_embedding: Embedding vector for the query. + scope_prefix: Optional scope path prefix to filter results. + categories: Optional list of categories (OR logic). + metadata_filter: Optional metadata key-value pairs (AND logic). + limit: Maximum number of results to return. + min_score: Minimum similarity score threshold (0.0 to 1.0). + + Returns: + List of (MemoryRecord, score) tuples ordered by relevance (descending score). + + Raises: + RuntimeError: If Valkey Search module is not available. + """ + return await self._vector_search( + query_embedding, + scope_prefix, + categories, + metadata_filter, + limit, + min_score, + ) + + def search( + self, + query_embedding: list[float], + scope_prefix: str | None = None, + categories: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + limit: int = 10, + min_score: float = 0.0, + ) -> list[tuple[MemoryRecord, float]]: + """Search for memories by vector similarity (sync wrapper). + + Uses Valkey Search module for server-side vector similarity computation. + Applies filters for scope, categories, and metadata in the same query. + + Args: + query_embedding: Embedding vector for the query. 
+ scope_prefix: Optional scope path prefix to filter results. + categories: Optional list of categories (OR logic). + metadata_filter: Optional metadata key-value pairs (AND logic). + limit: Maximum number of results to return. + min_score: Minimum similarity score threshold (0.0 to 1.0). + + Returns: + List of (MemoryRecord, score) tuples ordered by relevance (descending score). + + Raises: + RuntimeError: If Valkey Search module is not available or called from async context. + """ + result: list[tuple[MemoryRecord, float]] = self._run_async( + self.asearch( + query_embedding, + scope_prefix, + categories, + metadata_filter, + limit, + min_score, + ) + ) + return result + + def list_records( + self, + scope_prefix: str | None = None, + limit: int = 200, + offset: int = 0, + ) -> list[MemoryRecord]: + """List records in a scope, newest first. + + Uses scope sorted set ZRANGE with REV flag for newest-first ordering. + Supports scope_prefix filtering and pagination via limit and offset. + + Args: + scope_prefix: Optional scope path prefix to filter by. + limit: Maximum number of records to return (default 200). + offset: Number of records to skip for pagination (default 0). + + Returns: + List of MemoryRecord, ordered by created_at descending (newest first). + """ + result: list[MemoryRecord] = self._run_async( + self._alist_records(scope_prefix, limit, offset) + ) + return result + + async def _alist_records( + self, + scope_prefix: str | None = None, + limit: int = 200, + offset: int = 0, + ) -> list[MemoryRecord]: + """List records in a scope, newest first (async implementation). + + Args: + scope_prefix: Optional scope path prefix to filter by. + limit: Maximum number of records to return. + offset: Number of records to skip for pagination. + + Returns: + List of MemoryRecord, ordered by created_at descending. 
+ """ + client = await self._get_client() + + # Find all record IDs in scope(s) + if scope_prefix is not None: + # Get records from matching scopes + record_ids = await self._find_records_by_scope(scope_prefix) + else: + # Get all records from all scopes + record_ids = [] + cursor: str | bytes = "0" + while True: + result = await client.scan(cursor, match="scope:*", count=1000) + cursor_new: str | bytes = result[0] # type: ignore[assignment] + keys: list[bytes] = result[1] # type: ignore[assignment] + + for key_bytes in keys: + # Get all record IDs in this scope + scope_key = ( + key_bytes.decode("utf-8") + if isinstance(key_bytes, bytes) + else key_bytes + ) + members_result = await client.zrange(scope_key, RangeByIndex(0, -1)) + record_ids.extend( + m.decode("utf-8") if isinstance(m, bytes) else str(m) + for m in members_result + ) + + # Check if cursor is 0 (scan complete) + cursor_str = ( + cursor_new.decode("utf-8") + if isinstance(cursor_new, bytes) + else cursor_new + ) + if cursor_str == "0": + break + cursor = cursor_new + + # Fetch records and sort by created_at descending + records: list[MemoryRecord] = [] + for record_id in record_ids: + record = await self._aget_record(record_id) + if record: + records.append(record) + + # Sort by created_at descending (newest first) + records.sort(key=lambda r: r.created_at, reverse=True) + + # Apply pagination + return records[offset : offset + limit] + + def get_scope_info(self, scope: str) -> ScopeInfo: + """Get information about a scope. + + Counts records in scope and subscopes using sorted set cardinality. + Extracts categories used within scope. + Finds oldest and newest record timestamps. + Lists immediate child scope paths. + + Args: + scope: The scope path. + + Returns: + ScopeInfo with record count, categories, date range, child scopes. 
+ """ + result: ScopeInfo = self._run_async(self._aget_scope_info(scope)) + return result + + async def _aget_scope_info(self, scope: str) -> ScopeInfo: + """Get information about a scope (async implementation). + + Args: + scope: The scope path. + + Returns: + ScopeInfo with record count, categories, date range, child scopes. + """ + # Normalize scope path + scope = scope.rstrip("/") or "/" + prefix = scope if scope != "/" else "" + + # Find all record IDs in scope and subscopes + record_ids = await self._find_records_by_scope(prefix or "/") + + if not record_ids: + return ScopeInfo( + path=scope, + record_count=0, + categories=[], + oldest_record=None, + newest_record=None, + child_scopes=[], + ) + + # Fetch records to extract categories and timestamps + categories_set: set[str] = set() + oldest: datetime | None = None + newest: datetime | None = None + + for record_id in record_ids: + record = await self._aget_record(record_id) + if record: + # Collect categories + categories_set.update(record.categories) + + # Track oldest and newest timestamps + if oldest is None or record.created_at < oldest: + oldest = record.created_at + if newest is None or record.created_at > newest: + newest = record.created_at + + # Find immediate child scopes + child_scopes = await self._alist_scopes(scope) + + return ScopeInfo( + path=scope, + record_count=len(record_ids), + categories=sorted(categories_set), + oldest_record=oldest, + newest_record=newest, + child_scopes=child_scopes, + ) + + def list_scopes(self, parent: str = "/") -> list[str]: + """List immediate child scopes under a parent path. + + Defaults to root scope "/" when no parent specified. + Parses scope paths from scope sorted set keys. + Returns only immediate children, not grandchildren. + + Args: + parent: Parent scope path (default root "/"). + + Returns: + List of immediate child scope paths in sorted order. 
+ """ + result: list[str] = self._run_async(self._alist_scopes(parent)) + return result + + async def _alist_scopes(self, parent: str = "/") -> list[str]: + """List immediate child scopes under a parent path (async implementation). + + Args: + parent: Parent scope path (default root "/"). + + Returns: + List of immediate child scope paths in sorted order. + """ + client = await self._get_client() + + # Normalize parent path + parent = parent.rstrip("/") or "" + prefix = (parent + "/") if parent else "/" + + # Scan for all scope keys + children: set[str] = set() + cursor: str | bytes = "0" + while True: + result = await client.scan(cursor, match="scope:*", count=1000) + cursor_new: str | bytes = result[0] # type: ignore[assignment] + keys: list[bytes] = result[1] # type: ignore[assignment] + + for key_bytes in keys: + # Extract scope path from key + key_str = ( + key_bytes.decode("utf-8") + if isinstance(key_bytes, bytes) + else key_bytes + ) + scope_path = key_str.split(":", 1)[1] if ":" in key_str else "" + + # Check if scope is a child of parent + if scope_path.startswith(prefix) and scope_path != ( + prefix.rstrip("/") or "/" + ): + # Extract the immediate child component + rest = scope_path[len(prefix) :] + first_component = rest.split("/", 1)[0] + if first_component: + child_path = prefix + first_component + children.add(child_path) + + # Check if cursor is 0 (scan complete) + cursor_str = ( + cursor_new.decode("utf-8") + if isinstance(cursor_new, bytes) + else cursor_new + ) + if cursor_str == "0": + break + cursor = cursor_new + + return sorted(children) + + def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: + """List categories and their counts within a scope. + + Supports filtering by scope_prefix. + Computes counts by measuring category set cardinality. + Returns global category counts when scope_prefix is None. + + Args: + scope_prefix: Optional scope to limit to (None = global). 
+ + Returns: + Mapping of category name to record count. + """ + result: dict[str, int] = self._run_async(self._alist_categories(scope_prefix)) + return result + + async def _alist_categories( + self, scope_prefix: str | None = None + ) -> dict[str, int]: + """List categories and their counts within a scope (async implementation). + + Args: + scope_prefix: Optional scope to limit to (None = global). + + Returns: + Mapping of category name to record count. + """ + client = await self._get_client() + + if scope_prefix is not None: + # Get records in scope and count their categories + record_ids = await self._find_records_by_scope(scope_prefix) + counts: dict[str, int] = {} + + for record_id in record_ids: + record = await self._aget_record(record_id) + if record: + for category in record.categories: + counts[category] = counts.get(category, 0) + 1 + + return counts + # Global category counts - scan all category sets + counts = {} + cursor: str | bytes = "0" + while True: + result = await client.scan(cursor, match="category:*", count=1000) + cursor_new: str | bytes = result[0] # type: ignore[assignment] + keys: list[bytes] = result[1] # type: ignore[assignment] + + for key_bytes in keys: + # Extract category name from key + key_str = ( + key_bytes.decode("utf-8") + if isinstance(key_bytes, bytes) + else key_bytes + ) + category_name = key_str.split(":", 1)[1] if ":" in key_str else "" + + if category_name: + # Get cardinality of category set + category_key = self._category_key(category_name) + count = await client.scard(category_key) + counts[category_name] = int(count) if count else 0 + + # Check if cursor is 0 (scan complete) + cursor_str = ( + cursor_new.decode("utf-8") + if isinstance(cursor_new, bytes) + else cursor_new + ) + if cursor_str == "0": + break + cursor = cursor_new + + return counts + + def count(self, scope_prefix: str | None = None) -> int: + """Count records in scope (and subscopes). + + Uses scope sorted set cardinality for efficient counting. 
+ Supports scope_prefix filtering. + Returns total count across all scopes when scope_prefix is None. + + Args: + scope_prefix: Optional scope path (None = all). + + Returns: + Number of records. + """ + result: int = self._run_async(self._acount(scope_prefix)) + return result + + async def _acount(self, scope_prefix: str | None = None) -> int: + """Count records in scope (and subscopes) (async implementation). + + Args: + scope_prefix: Optional scope path (None = all). + + Returns: + Number of records. + """ + if scope_prefix is None or scope_prefix.strip("/") == "": + # Count all records across all scopes + record_ids = await self._find_records_by_scope("/") + return len(set(record_ids)) # Use set to deduplicate + # Count records in specific scope and subscopes + record_ids = await self._find_records_by_scope(scope_prefix) + return len(set(record_ids)) # Use set to deduplicate + + def reset(self, scope_prefix: str | None = None) -> None: + """Reset (delete all) memories in scope. + + Deletes all records in scope and subscopes when scope_prefix provided. + Deletes all records across all scopes when scope_prefix is None. + Removes all index structures atomically. + + Args: + scope_prefix: Optional scope path (None = reset all). + """ + self._run_async(self._areset(scope_prefix)) + + async def _areset(self, scope_prefix: str | None = None) -> None: + """Reset (delete all) memories in scope (async implementation). + + Args: + scope_prefix: Optional scope path (None = reset all). 
+ """ + # Use delete with scope_prefix to remove all records + await self.adelete(scope_prefix=scope_prefix) diff --git a/lib/crewai/src/crewai/memory/types.py b/lib/crewai/src/crewai/memory/types.py index e787b569d0..fc37027cfd 100644 --- a/lib/crewai/src/crewai/memory/types.py +++ b/lib/crewai/src/crewai/memory/types.py @@ -2,13 +2,17 @@ from __future__ import annotations +import concurrent.futures from datetime import datetime +import logging from typing import Any from uuid import uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator +_logger = logging.getLogger(__name__) + # When searching the vector store, we ask for more results than the caller # requested so that post-search steps (composite scoring, deduplication, # category filtering) have enough candidates to fill the final result set. @@ -57,6 +61,23 @@ class MemoryRecord(BaseModel): repr=False, description="Vector embedding for semantic search. Excluded from serialization to save tokens.", ) + + @field_validator("embedding", mode="before") + @classmethod + def validate_embedding(cls, v: Any) -> list[float] | None: + """Ensure embedding is always list[float] or None, never bytes.""" + if v is None: + return None + if isinstance(v, bytes): + # Convert bytes to list[float] if needed + import numpy as np + + if len(v) == 0: + return None + arr = np.frombuffer(v, dtype=np.float32) + return [float(x) for x in arr] + return [float(x) for x in v] + source: str | None = Field( default=None, description=( @@ -304,7 +325,11 @@ def embed_text(embedder: Any, text: str) -> list[float]: """ if not text or not text.strip(): return [] + + # Just call the embedder directly - the blocking issue needs to be fixed + # at a higher level (making Memory.recall() async) result = embedder([text]) + if not result: return [] first = result[0] @@ -315,6 +340,11 @@ def embed_text(embedder: Any, text: str) -> list[float]: return list(first) +# Reusable thread pool for running embedder 
calls from sync context +# when an async event loop is already running. +_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]: """Embed multiple texts in a single API call. @@ -328,6 +358,8 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]: Returns: List of embeddings, one per input text. Empty texts produce empty lists. """ + import asyncio + if not texts: return [] # Filter out empty texts, remembering their positions @@ -337,7 +369,23 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]: if not valid: return [[] for _ in texts] - result = embedder([t for _, t in valid]) + # Check if we're in an async context + result: Any + try: + asyncio.get_running_loop() + # We're in an async context, but this is a sync function + # Run embedder in thread pool to avoid blocking the event loop + try: + result = _EMBED_POOL.submit(embedder, [t for _, t in valid]).result( + timeout=30 + ) + except concurrent.futures.TimeoutError: + _logger.warning("Embedder timed out after 30s, returning empty embeddings") + return [[] for _ in texts] + except RuntimeError: + # Not in async context, run directly + result = embedder([t for _, t in valid]) + embeddings: list[list[float]] = [[] for _ in texts] for (orig_idx, _), emb in zip(valid, result, strict=False): if hasattr(emb, "tolist"): diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index d879bace0c..93827bac18 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -5,6 +5,7 @@ from concurrent.futures import Future, ThreadPoolExecutor import contextvars from datetime import datetime +import logging import threading import time from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -36,6 +37,9 @@ from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec +_logger = 
logging.getLogger(__name__) + + if TYPE_CHECKING: from chromadb.utils.embedding_functions.openai_embedding_function import ( OpenAIEmbeddingFunction, @@ -211,6 +215,17 @@ def model_post_init(self, __context: Any) -> None: from crewai.memory.storage.lancedb_storage import LanceDBStorage self._storage = LanceDBStorage() + elif self.storage == "valkey": + from crewai.memory.storage.valkey_storage import ValkeyStorage + from crewai.utilities.cache_config import parse_cache_url + + conn = parse_cache_url() or {} + self._storage = ValkeyStorage( + host=conn.get("host", "localhost"), + port=conn.get("port", 6379), + db=conn.get("db", 0), + password=conn.get("password"), + ) else: from crewai.memory.storage.lancedb_storage import LanceDBStorage @@ -316,16 +331,60 @@ def _on_save_done(self, future: Future[Any]) -> None: except Exception: # noqa: S110 pass # swallow everything during shutdown - def drain_writes(self) -> None: + def drain_writes(self, timeout_per_save: float = 60.0) -> None: """Block until all pending background saves have completed. Called automatically by ``recall()`` and should be called by the crew at shutdown to ensure no saves are lost. + + Args: + timeout_per_save: Maximum seconds to wait per save operation. + Default 60s. If a save times out, logs warning + but continues to avoid blocking crew completion. """ with self._pending_lock: pending = list(self._pending_saves) - for future in pending: - future.result() # blocks until done; re-raises exceptions + + if pending: + _logger.debug( + "[DRAIN_WRITES] Waiting for %d pending saves...", len(pending) + ) + + failed_saves = 0 + for i, future in enumerate(pending): + try: + _logger.debug( + "[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending) + ) + future.result(timeout=timeout_per_save) + _logger.debug( + "[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending) + ) + except TimeoutError: # noqa: PERF203 + failed_saves += 1 + _logger.warning( + "[DRAIN_WRITES] Save %d/%d timed out after %ss. 
" + "This save will be abandoned. Consider increasing timeout or checking " + "LLM/embedder performance.", + i + 1, + len(pending), + timeout_per_save, + ) + # Don't raise - just log and continue to avoid blocking crew completion + except Exception as e: + failed_saves += 1 + _logger.error( + "[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e + ) + # Don't raise - just log and continue + + if failed_saves > 0: + _logger.warning( + "[DRAIN_WRITES] %d/%d saves failed or timed out. " + "Some memories may not have been persisted.", + failed_saves, + len(pending), + ) def close(self) -> None: """Drain pending saves, flush storage, and shut down the background thread pool.""" diff --git a/lib/crewai/src/crewai/utilities/cache_config.py b/lib/crewai/src/crewai/utilities/cache_config.py new file mode 100644 index 0000000000..d13e74383d --- /dev/null +++ b/lib/crewai/src/crewai/utilities/cache_config.py @@ -0,0 +1,66 @@ +"""Shared cache configuration helpers for Valkey/Redis URL parsing.""" + +from __future__ import annotations + +import logging +import os +from typing import Any +from urllib.parse import urlparse + + +_logger = logging.getLogger(__name__) + + +def parse_cache_url() -> dict[str, Any] | None: + """Parse VALKEY_URL or REDIS_URL from environment. + + Priority: VALKEY_URL > REDIS_URL. + + Returns: + Dict with host, port, db, password keys, or None if no URL is set. + """ + url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL") + if not url: + return None + parsed = urlparse(url) + return { + "host": parsed.hostname or "localhost", + "port": parsed.port or 6379, + "db": ( + int(parsed.path.lstrip("/")) if parsed.path and parsed.path != "/" else 0 + ), + "password": parsed.password, + } + + +def get_aiocache_config() -> dict[str, Any]: + """Build an aiocache configuration dict from environment. + + Uses VALKEY_URL or REDIS_URL (both are Redis-wire-compatible) to + configure ``aiocache.RedisCache``. 
Falls back to + ``aiocache.SimpleMemoryCache`` when neither variable is set. + + Returns: + Configuration dict suitable for ``aiocache.caches.set_config()``. + """ + conn = parse_cache_url() + if conn is not None: + return { + "default": { + "cache": "aiocache.RedisCache", + "endpoint": conn["host"], + "port": conn["port"], + "db": conn.get("db", 0), + "password": conn.get("password"), + } + } + return { + "default": { + "cache": "aiocache.SimpleMemoryCache", + } + } + + +def use_valkey_cache() -> bool: + """Return True if VALKEY_URL is set in the environment.""" + return bool(os.environ.get("VALKEY_URL")) diff --git a/lib/crewai/tests/memory/storage/test_valkey_cache.py b/lib/crewai/tests/memory/storage/test_valkey_cache.py new file mode 100644 index 0000000000..7534368aa4 --- /dev/null +++ b/lib/crewai/tests/memory/storage/test_valkey_cache.py @@ -0,0 +1,498 @@ +"""Tests for ValkeyCache implementation.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crewai.memory.storage.valkey_cache import ValkeyCache + + +@pytest.fixture +def mock_glide_client() -> AsyncMock: + """Create a mock GlideClient for testing.""" + client = AsyncMock() + client.get = AsyncMock() + client.set = AsyncMock() + client.delete = AsyncMock() + client.exists = AsyncMock() + client.close = AsyncMock() + return client + + +@pytest.fixture +def valkey_cache(mock_glide_client: AsyncMock) -> ValkeyCache: + """Create a ValkeyCache instance with mocked client.""" + cache = ValkeyCache(host="localhost", port=6379, db=0) + + # Mock the client creation to return our mock + async def mock_create_client() -> AsyncMock: + cache._client = mock_glide_client + return mock_glide_client + + cache._get_client = mock_create_client # type: ignore[method-assign] + return cache + + +class TestValkeyCacheBasicOperations: + """Tests for basic ValkeyCache operations (get/set/delete/exists).""" + + @pytest.mark.asyncio + async def 
test_set_and_get_string_value( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting and getting a string value.""" + # Mock get to return serialized string + mock_glide_client.get.return_value = json.dumps("test_value") + + # Set value + await valkey_cache.set("test_key", "test_value") + + # Verify set was called + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert call_args[0][0] == "test_key" + assert call_args[0][1] == json.dumps("test_value") + + # Get value + result = await valkey_cache.get("test_key") + + # Verify get was called and result is correct + mock_glide_client.get.assert_called_once_with("test_key") + assert result == "test_value" + + @pytest.mark.asyncio + async def test_set_and_get_dict_value( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting and getting a dictionary value.""" + test_dict = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]} + mock_glide_client.get.return_value = json.dumps(test_dict) + + # Set value + await valkey_cache.set("dict_key", test_dict) + + # Verify set was called with serialized dict + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert call_args[0][0] == "dict_key" + assert call_args[0][1] == json.dumps(test_dict) + + # Get value + result = await valkey_cache.get("dict_key") + + # Verify result matches original dict + assert result == test_dict + + @pytest.mark.asyncio + async def test_set_and_get_list_value( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting and getting a list value.""" + test_list = [1, "two", 3.0, {"nested": "dict"}] + mock_glide_client.get.return_value = json.dumps(test_list) + + await valkey_cache.set("list_key", test_list) + result = await valkey_cache.get("list_key") + + assert result == test_list + + @pytest.mark.asyncio + async def test_get_nonexistent_key_returns_none( + 
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test getting a non-existent key returns None.""" + mock_glide_client.get.return_value = None + + result = await valkey_cache.get("nonexistent_key") + + assert result is None + mock_glide_client.get.assert_called_once_with("nonexistent_key") + + @pytest.mark.asyncio + async def test_delete_key( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test deleting a key.""" + await valkey_cache.delete("test_key") + + mock_glide_client.delete.assert_called_once_with(["test_key"]) + + @pytest.mark.asyncio + async def test_exists_returns_true_for_existing_key( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test exists returns True for existing key.""" + mock_glide_client.exists.return_value = 1 + + result = await valkey_cache.exists("existing_key") + + assert result is True + mock_glide_client.exists.assert_called_once_with(["existing_key"]) + + @pytest.mark.asyncio + async def test_exists_returns_false_for_nonexistent_key( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test exists returns False for non-existent key.""" + mock_glide_client.exists.return_value = 0 + + result = await valkey_cache.exists("nonexistent_key") + + assert result is False + mock_glide_client.exists.assert_called_once_with(["nonexistent_key"]) + + +class TestValkeyCacheTTL: + """Tests for ValkeyCache TTL functionality.""" + + @pytest.mark.asyncio + async def test_set_with_explicit_ttl( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting a value with explicit TTL.""" + await valkey_cache.set("ttl_key", "value", ttl=3600) + + # Verify set was called with expiry + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert call_args[0][0] == "ttl_key" + assert call_args[0][1] == json.dumps("value") + assert "expiry" in call_args[1] + + 
@pytest.mark.asyncio + async def test_set_with_default_ttl( + self, mock_glide_client: AsyncMock + ) -> None: + """Test setting a value with default TTL from constructor.""" + cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800) + + async def mock_create_client() -> AsyncMock: + cache._client = mock_glide_client + return mock_glide_client + + cache._get_client = mock_create_client # type: ignore[method-assign] + + await cache.set("default_ttl_key", "value") + + # Verify set was called with default TTL + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert "expiry" in call_args[1] + + @pytest.mark.asyncio + async def test_set_without_ttl( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting a value without TTL (no expiration).""" + await valkey_cache.set("no_ttl_key", "value") + + # Verify set was called without expiry + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert call_args[0][0] == "no_ttl_key" + assert call_args[0][1] == json.dumps("value") + # Should not have expiry parameter + assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None + + @pytest.mark.asyncio + async def test_set_with_zero_ttl_no_expiration( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting a value with TTL=0 means no expiration.""" + await valkey_cache.set("zero_ttl_key", "value", ttl=0) + + # Verify set was called without expiry + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None + + @pytest.mark.asyncio + async def test_explicit_ttl_overrides_default( + self, mock_glide_client: AsyncMock + ) -> None: + """Test explicit TTL overrides default TTL.""" + cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800) + + async def mock_create_client() -> AsyncMock: + 
cache._client = mock_glide_client + return mock_glide_client + + cache._get_client = mock_create_client # type: ignore[method-assign] + + await cache.set("override_key", "value", ttl=7200) + + # Verify set was called with explicit TTL (7200), not default (1800) + mock_glide_client.set.assert_called_once() + call_args = mock_glide_client.set.call_args + assert "expiry" in call_args[1] + + +class TestValkeyCacheJSONSerialization: + """Tests for ValkeyCache JSON serialization edge cases.""" + + @pytest.mark.asyncio + async def test_serialize_none_value( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing None value.""" + mock_glide_client.get.return_value = json.dumps(None) + + await valkey_cache.set("none_key", None) + result = await valkey_cache.get("none_key") + + assert result is None + + @pytest.mark.asyncio + async def test_serialize_boolean_values( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing boolean values.""" + mock_glide_client.get.side_effect = [ + json.dumps(True), + json.dumps(False), + ] + + await valkey_cache.set("true_key", True) + await valkey_cache.set("false_key", False) + + result_true = await valkey_cache.get("true_key") + result_false = await valkey_cache.get("false_key") + + assert result_true is True + assert result_false is False + + @pytest.mark.asyncio + async def test_serialize_numeric_values( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing numeric values (int, float).""" + mock_glide_client.get.side_effect = [ + json.dumps(42), + json.dumps(3.14159), + json.dumps(0), + json.dumps(-100), + ] + + await valkey_cache.set("int_key", 42) + await valkey_cache.set("float_key", 3.14159) + await valkey_cache.set("zero_key", 0) + await valkey_cache.set("negative_key", -100) + + assert await valkey_cache.get("int_key") == 42 + assert await valkey_cache.get("float_key") == 3.14159 + assert await 
valkey_cache.get("zero_key") == 0 + assert await valkey_cache.get("negative_key") == -100 + + @pytest.mark.asyncio + async def test_serialize_empty_collections( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing empty collections.""" + mock_glide_client.get.side_effect = [ + json.dumps([]), + json.dumps({}), + json.dumps(""), + ] + + await valkey_cache.set("empty_list", []) + await valkey_cache.set("empty_dict", {}) + await valkey_cache.set("empty_string", "") + + assert await valkey_cache.get("empty_list") == [] + assert await valkey_cache.get("empty_dict") == {} + assert await valkey_cache.get("empty_string") == "" + + @pytest.mark.asyncio + async def test_serialize_nested_structures( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing deeply nested structures.""" + nested_data = { + "level1": { + "level2": { + "level3": [1, 2, {"level4": "deep"}] + } + }, + "list": [{"a": 1}, {"b": 2}] + } + mock_glide_client.get.return_value = json.dumps(nested_data) + + await valkey_cache.set("nested_key", nested_data) + result = await valkey_cache.get("nested_key") + + assert result == nested_data + + @pytest.mark.asyncio + async def test_deserialize_invalid_json_returns_none( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test deserializing invalid JSON returns None and logs warning.""" + mock_glide_client.get.return_value = "invalid json {{" + + with patch("crewai.memory.storage.valkey_cache._logger") as mock_logger: + result = await valkey_cache.get("invalid_key") + + assert result is None + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_serialize_unicode_strings( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing unicode strings.""" + unicode_data = "Hello 世界 🌍 Привет" + mock_glide_client.get.return_value = json.dumps(unicode_data) + + await 
valkey_cache.set("unicode_key", unicode_data) + result = await valkey_cache.get("unicode_key") + + assert result == unicode_data + + +class TestValkeyCacheConnectionManagement: + """Tests for ValkeyCache connection management.""" + + @pytest.mark.asyncio + async def test_lazy_client_initialization(self) -> None: + """Test client is initialized lazily on first use.""" + cache = ValkeyCache(host="localhost", port=6379) + + # Client should be None initially + assert cache._client is None + + # Mock GlideClient.create + with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide: + mock_client = AsyncMock() + mock_glide.create = AsyncMock(return_value=mock_client) + mock_client.get = AsyncMock(return_value=None) + + # First operation should initialize client + await cache.get("test_key") + + # Client should now be initialized + assert cache._client is not None + mock_glide.create.assert_called_once() + + @pytest.mark.asyncio + async def test_client_reuse_across_operations( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test client is reused across multiple operations.""" + mock_glide_client.get.return_value = json.dumps("value") + mock_glide_client.exists.return_value = 1 + + # Perform multiple operations + await valkey_cache.get("key1") + await valkey_cache.set("key2", "value2") + await valkey_cache.exists("key3") + await valkey_cache.delete("key4") + + # _get_client should return the same client instance + client1 = await valkey_cache._get_client() + client2 = await valkey_cache._get_client() + assert client1 is client2 + + @pytest.mark.asyncio + async def test_close_connection( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test closing the client connection.""" + # Initialize client + await valkey_cache._get_client() + assert valkey_cache._client is not None + + # Close connection + await valkey_cache.close() + + # Verify close was called and client is None + 
mock_glide_client.close.assert_called_once() + assert valkey_cache._client is None + + @pytest.mark.asyncio + async def test_connection_error_raises_runtime_error(self) -> None: + """Test connection error raises RuntimeError with descriptive message.""" + cache = ValkeyCache(host="invalid-host", port=9999) + + with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide: + mock_glide.create = AsyncMock(side_effect=Exception("Connection refused")) + + with pytest.raises(RuntimeError) as exc_info: + await cache._get_client() + + assert "Cannot connect to Valkey" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_authentication_with_password(self) -> None: + """Test client initialization with password authentication.""" + cache = ValkeyCache( + host="localhost", + port=6379, + password="secret_password" + ) + + with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide: + mock_client = AsyncMock() + mock_glide.create = AsyncMock(return_value=mock_client) + + await cache._get_client() + + # Verify GlideClient.create was called with credentials + mock_glide.create.assert_called_once() + config = mock_glide.create.call_args[0][0] + assert hasattr(config, "credentials") + + +class TestValkeyCacheEdgeCases: + """Tests for ValkeyCache edge cases and error conditions.""" + + @pytest.mark.asyncio + async def test_set_with_special_characters_in_key( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test setting values with special characters in key.""" + special_keys = [ + "key:with:colons", + "key/with/slashes", + "key-with-dashes", + "key_with_underscores", + "key.with.dots", + ] + + for key in special_keys: + await valkey_cache.set(key, "value") + mock_glide_client.set.assert_called() + + @pytest.mark.asyncio + async def test_large_value_serialization( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test serializing large values.""" + large_list = list(range(10000)) 
+ mock_glide_client.get.return_value = json.dumps(large_list) + + await valkey_cache.set("large_key", large_list) + result = await valkey_cache.get("large_key") + + assert result == large_list + + @pytest.mark.asyncio + async def test_concurrent_operations( + self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock + ) -> None: + """Test concurrent cache operations.""" + import asyncio + + mock_glide_client.get.return_value = json.dumps("value") + + # Perform concurrent operations + tasks = [ + valkey_cache.set(f"key{i}", f"value{i}") + for i in range(10) + ] + await asyncio.gather(*tasks) + + # Verify all operations completed + assert mock_glide_client.set.call_count == 10 diff --git a/lib/crewai/tests/memory/storage/test_valkey_storage.py b/lib/crewai/tests/memory/storage/test_valkey_storage.py new file mode 100644 index 0000000000..ea1ff1bbe4 --- /dev/null +++ b/lib/crewai/tests/memory/storage/test_valkey_storage.py @@ -0,0 +1,3074 @@ +"""Tests for ValkeyStorage save operation.""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from crewai.memory.storage.valkey_storage import ValkeyStorage +from crewai.memory.types import MemoryRecord + + +@pytest.fixture +def mock_glide_client() -> AsyncMock: + """Create a mock GlideClient for testing.""" + client = AsyncMock() + client.hset = AsyncMock(return_value=1) + client.zrange = AsyncMock(return_value=[]) + client.zadd = AsyncMock() + client.sadd = AsyncMock() + client.hgetall = AsyncMock(return_value={}) + client.close = AsyncMock() + return client + + +@pytest.fixture +def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage: + """Create a ValkeyStorage instance with mocked client.""" + storage = ValkeyStorage(host="localhost", port=6379, db=0) + + # Mock the client creation to return our mock + async def mock_create_client() -> AsyncMock: + storage._client = mock_glide_client + 
return mock_glide_client + + storage._get_client = mock_create_client # type: ignore[method-assign] + return storage + + +class TestValkeyStorageSave: + """Tests for ValkeyStorage save operation.""" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_single_record_with_all_fields( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a single record with all fields populated.""" + # Create a record with all fields + record = MemoryRecord( + id="test-id-123", + content="Test memory content", + scope="/agent/task", + categories=["planning", "execution"], + metadata={"agent_id": "agent-1", "priority": "high"}, + importance=0.8, + created_at=datetime(2024, 1, 1, 12, 0, 0), + last_accessed=datetime(2024, 1, 1, 12, 0, 0), + embedding=[0.1, 0.2, 0.3, 0.4], + source="test-source", + private=True, + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + # Save the record + await valkey_storage.asave([record]) + + # Verify ft.list was called to check index + mock_ft_list.assert_called_once() + + # Verify HSET was called with correct record data + mock_glide_client.hset.assert_called() + hset_call = mock_glide_client.hset.call_args + assert hset_call[0][0] == "record:test-id-123" # key + hset_dict = hset_call[0][1] # field_value_map dict + + assert hset_dict["id"] == "test-id-123" + assert hset_dict["content"] == "Test memory content" + assert hset_dict["scope"] == "/agent/task" + assert hset_dict["source"] == "test-source" + assert hset_dict["private"] == "true" + assert hset_dict["importance"] == "0.8" + assert "embedding" in hset_dict + assert isinstance(hset_dict["embedding"], bytes) + + # Verify scope index was updated + mock_glide_client.zadd.assert_called_once() + zadd_call = mock_glide_client.zadd.call_args + assert zadd_call[0][0] == "scope:/agent/task" + assert "test-id-123" in zadd_call[0][1] + + # Verify category 
indexes were updated + assert mock_glide_client.sadd.call_count >= 2 + sadd_calls = [call[0] for call in mock_glide_client.sadd.call_args_list] + category_calls = [call for call in sadd_calls if call[0].startswith("category:")] + assert len(category_calls) == 2 + assert any("category:planning" in str(call) for call in category_calls) + assert any("category:execution" in str(call) for call in category_calls) + + # Verify metadata indexes were updated + metadata_calls = [call for call in sadd_calls if call[0].startswith("metadata:")] + assert len(metadata_calls) == 2 + assert any("metadata:agent_id:agent-1" in str(call) for call in metadata_calls) + assert any("metadata:priority:high" in str(call) for call in metadata_calls) + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_multiple_records_in_batch( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving multiple records in a single batch.""" + records = [ + MemoryRecord( + id=f"record-{i}", + content=f"Content {i}", + scope="/test", + embedding=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i], + ) + for i in range(3) + ] + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave(records) + + # Verify HSET was called for each record + assert mock_glide_client.hset.call_count == 3 + + # Verify each record was stored + hset_calls = mock_glide_client.hset.call_args_list + for i in range(3): + record_key = f"record:record-{i}" + assert any(call[0][0] == record_key for call in hset_calls) + + # Verify scope index was updated for all records + assert mock_glide_client.zadd.call_count == 3 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_record_with_empty_categories_and_metadata( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a record with empty 
categories and metadata.""" + record = MemoryRecord( + id="empty-fields-record", + content="Content with no categories or metadata", + scope="/test", + categories=[], + metadata={}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave([record]) + + # Verify record was saved + mock_glide_client.hset.assert_called_once() + + # Verify no category or metadata index updates + sadd_calls = mock_glide_client.sadd.call_args_list + # Should have no calls since categories and metadata are empty + assert len(sadd_calls) == 0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_record_without_embedding( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a record without an embedding.""" + record = MemoryRecord( + id="no-embedding-record", + content="Content without embedding", + scope="/test", + embedding=None, + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave([record]) + + # Verify record was saved + mock_glide_client.hset.assert_called_once() + + # Verify embedding field is empty bytes + hset_call = mock_glide_client.hset.call_args + hset_dict = hset_call[0][1] # field_value_map dict + + assert "embedding" in hset_dict + assert hset_dict["embedding"] == b"" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_record_with_none_source( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a record with None source.""" + record = MemoryRecord( + id="none-source-record", + content="Content with None source", + scope="/test", + source=None, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await 
valkey_storage.asave([record]) + + # Verify record was saved + mock_glide_client.hset.assert_called_once() + + # Verify source field is empty string + hset_call = mock_glide_client.hset.call_args + hset_dict = hset_call[0][1] # field_value_map dict + + assert hset_dict["source"] == "" + + @pytest.mark.asyncio + async def test_save_empty_list_does_nothing( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that saving an empty list does nothing.""" + await valkey_storage.asave([]) + + # Verify no operations were performed + mock_glide_client.hset.assert_not_called() + mock_glide_client.zadd.assert_not_called() + mock_glide_client.sadd.assert_not_called() + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.FtCreateOptions") + @patch("crewai.memory.storage.valkey_storage.VectorField") + @patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw") + @patch("crewai.memory.storage.valkey_storage.ft.create") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_creates_vector_index_if_not_exists( + self, mock_ft_list, mock_ft_create, mock_vector_attrs, mock_vector_field, mock_ft_create_options, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that save creates vector index if it doesn't exist.""" + record = MemoryRecord( + id="test-record", + content="Test content", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.info to fail (index doesn't exist), then ft.create succeeds + mock_ft_list.return_value = [] + mock_ft_create.return_value = "OK" + + await valkey_storage.asave([record]) + + # Verify ft.create was called + mock_ft_create.assert_called_once() + + # Verify ft.create was called with correct index name + create_args = mock_ft_create.call_args + assert create_args[0][1] == "memory_index" + + @pytest.mark.asyncio + async def test_save_error_handling_for_serialization_failure( + self, valkey_storage: ValkeyStorage, 
mock_glide_client: AsyncMock + ) -> None: + """Test error handling when serialization fails.""" + # Create a record with a field that will cause serialization to fail + record = MemoryRecord( + id="bad-record", + content="Test content", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock _record_to_dict to raise an error + with patch.object( + valkey_storage, + "_record_to_dict", + side_effect=ValueError("Serialization failed"), + ): + with pytest.raises(ValueError, match="Serialization failed"): + await valkey_storage.asave([record]) + + @patch("crewai.memory.storage.valkey_storage.ft.list") + def test_save_sync_wrapper( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync save wrapper calls async implementation.""" + record = MemoryRecord( + id="sync-test-record", + content="Test content", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + # Call sync save + valkey_storage.save([record]) + + # Verify async operations were called + mock_glide_client.hset.assert_called_once() + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_with_special_characters_in_metadata( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a record with special characters in metadata values.""" + record = MemoryRecord( + id="special-chars-record", + content="Test content", + scope="/test", + metadata={ + "key:with:colons": "value:with:colons", + "key with spaces": "value with spaces", + "key/with/slashes": "value/with/slashes", + }, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave([record]) + + # Verify metadata indexes were created with special characters + sadd_calls = 
mock_glide_client.sadd.call_args_list + metadata_calls = [call[0][0] for call in sadd_calls if call[0][0].startswith("metadata:")] + + assert len(metadata_calls) == 3 + assert any("key:with:colons:value:with:colons" in call for call in metadata_calls) + assert any("key with spaces:value with spaces" in call for call in metadata_calls) + assert any("key/with/slashes:value/with/slashes" in call for call in metadata_calls) + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_with_numeric_metadata_values( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test saving a record with numeric metadata values.""" + record = MemoryRecord( + id="numeric-metadata-record", + content="Test content", + scope="/test", + metadata={ + "count": 42, + "score": 3.14, + "is_active": True, + }, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave([record]) + + # Verify metadata indexes were created with string-converted values + sadd_calls = mock_glide_client.sadd.call_args_list + metadata_calls = [call[0][0] for call in sadd_calls if call[0][0].startswith("metadata:")] + + assert len(metadata_calls) == 3 + assert any("metadata:count:42" in call for call in metadata_calls) + assert any("metadata:score:3.14" in call for call in metadata_calls) + assert any("metadata:is_active:True" in call for call in metadata_calls) + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_save_preserves_datetime_precision( + self, mock_ft_list, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that datetime fields are serialized with proper precision.""" + created_at = datetime(2024, 1, 15, 10, 30, 45, 123456) + last_accessed = datetime(2024, 1, 15, 11, 45, 30, 654321) + + record = MemoryRecord( + 
id="datetime-precision-record", + content="Test content", + scope="/test", + created_at=created_at, + last_accessed=last_accessed, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Mock ft.list to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + + await valkey_storage.asave([record]) + + # Verify datetime fields are in ISO format + mock_glide_client.hset.assert_called_once() + hset_call = mock_glide_client.hset.call_args + hset_dict = hset_call[0][1] # field_value_map dict + + assert hset_dict["created_at"] == created_at.isoformat() + assert hset_dict["last_accessed"] == last_accessed.isoformat() + + + +class TestValkeyStorageGetRecord: + """Tests for ValkeyStorage get_record operation.""" + + @pytest.mark.asyncio + async def test_retrieve_existing_record_with_all_fields( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving an existing record with all fields populated.""" + # Mock HGETALL to return a complete record + mock_glide_client.hgetall.return_value = { + "id": "test-record-123", + "content": "Test memory content", + "scope": "/agent/task", + "categories": '["planning", "execution"]', + "metadata": '{"agent_id": "agent-1", "priority": "high"}', + "importance": "0.8", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T13:00:00", + "embedding": valkey_storage._embedding_to_bytes([0.1, 0.2, 0.3, 0.4]), + "source": "test-source", + "private": "true", + } + + # Retrieve the record + record = await valkey_storage._aget_record("test-record-123") + + # Verify HGETALL was called with correct key + mock_glide_client.hgetall.assert_called_once_with("record:test-record-123") + + # Verify all fields are correctly deserialized + assert record is not None + assert record.id == "test-record-123" + assert record.content == "Test memory content" + assert record.scope == "/agent/task" + assert record.categories == ["planning", "execution"] + assert record.metadata == {"agent_id": "agent-1", 
"priority": "high"} + assert record.importance == 0.8 + assert record.created_at == datetime(2024, 1, 1, 12, 0, 0) + assert record.last_accessed == datetime(2024, 1, 1, 13, 0, 0) + # Check embedding with approximate comparison (float32 precision) + assert record.embedding is not None + assert len(record.embedding) == 4 + for i, expected in enumerate([0.1, 0.2, 0.3, 0.4]): + assert abs(record.embedding[i] - expected) < 1e-6 + assert record.source == "test-source" + assert record.private is True + + @pytest.mark.asyncio + async def test_retrieve_non_existent_record_returns_none( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a non-existent record returns None.""" + # Mock HGETALL to return empty dict (record doesn't exist) + mock_glide_client.hgetall.return_value = {} + + # Retrieve non-existent record + record = await valkey_storage._aget_record("non-existent-id") + + # Verify HGETALL was called + mock_glide_client.hgetall.assert_called_once_with("record:non-existent-id") + + # Verify None is returned + assert record is None + + @pytest.mark.asyncio + async def test_retrieve_record_with_empty_embedding( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with empty embedding.""" + # Mock HGETALL to return record with empty embedding + mock_glide_client.hgetall.return_value = { + "id": "no-embedding-record", + "content": "Content without embedding", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", # Empty bytes + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("no-embedding-record") + + # Verify record is retrieved with None embedding + assert record is not None + assert record.id == "no-embedding-record" + assert record.embedding is None + + 
@pytest.mark.asyncio + async def test_retrieve_record_with_none_source( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with None source.""" + # Mock HGETALL to return record with empty source + mock_glide_client.hgetall.return_value = { + "id": "no-source-record", + "content": "Content without source", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", # Empty string + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("no-source-record") + + # Verify record is retrieved with None source + assert record is not None + assert record.source is None + + @pytest.mark.asyncio + async def test_retrieve_record_with_false_private_flag( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with private=false.""" + # Mock HGETALL to return record with private=false + mock_glide_client.hgetall.return_value = { + "id": "public-record", + "content": "Public content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("public-record") + + # Verify private flag is False + assert record is not None + assert record.private is False + + @pytest.mark.asyncio + async def test_retrieve_record_with_empty_categories_and_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with empty categories and metadata.""" + # Mock HGETALL to return record with empty lists/dicts + mock_glide_client.hgetall.return_value = { + "id": "minimal-record", + "content": "Minimal content", + 
"scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("minimal-record") + + # Verify empty collections are preserved + assert record is not None + assert record.categories == [] + assert record.metadata == {} + + @pytest.mark.asyncio + async def test_deserialization_of_datetime_fields( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deserialization of datetime fields with microseconds.""" + # Mock HGETALL with datetime including microseconds + mock_glide_client.hgetall.return_value = { + "id": "datetime-record", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-15T10:30:45.123456", + "last_accessed": "2024-01-15T11:45:30.654321", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("datetime-record") + + # Verify datetime fields are correctly parsed + assert record is not None + assert record.created_at == datetime(2024, 1, 15, 10, 30, 45, 123456) + assert record.last_accessed == datetime(2024, 1, 15, 11, 45, 30, 654321) + + @pytest.mark.asyncio + async def test_deserialization_of_float_importance( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deserialization of float importance value.""" + # Mock HGETALL with various float formats + mock_glide_client.hgetall.return_value = { + "id": "float-record", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.123456789", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve 
the record + record = await valkey_storage._aget_record("float-record") + + # Verify float is correctly parsed + assert record is not None + assert abs(record.importance - 0.123456789) < 1e-9 + + @pytest.mark.asyncio + async def test_deserialization_of_json_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deserialization of JSON categories array.""" + # Mock HGETALL with multiple categories + mock_glide_client.hgetall.return_value = { + "id": "categories-record", + "content": "Test content", + "scope": "/test", + "categories": '["planning", "execution", "review", "analysis"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("categories-record") + + # Verify categories are correctly parsed + assert record is not None + assert record.categories == ["planning", "execution", "review", "analysis"] + + @pytest.mark.asyncio + async def test_deserialization_of_json_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deserialization of JSON metadata object.""" + # Mock HGETALL with complex metadata + mock_glide_client.hgetall.return_value = { + "id": "metadata-record", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": '{"agent_id": "agent-1", "count": 42, "score": 3.14, "active": true}', + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("metadata-record") + + # Verify metadata is correctly parsed + assert record is not None + assert record.metadata == { + "agent_id": "agent-1", + "count": 42, + "score": 3.14, + "active": True, + } + + 
@pytest.mark.asyncio + async def test_deserialization_of_binary_embedding( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deserialization of binary embedding vector.""" + # Create a test embedding + test_embedding = [0.1, 0.2, 0.3, 0.4, 0.5] + embedding_bytes = valkey_storage._embedding_to_bytes(test_embedding) + + # Mock HGETALL with binary embedding + mock_glide_client.hgetall.return_value = { + "id": "embedding-record", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": embedding_bytes, + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("embedding-record") + + # Verify embedding is correctly deserialized + assert record is not None + assert record.embedding is not None + assert len(record.embedding) == 5 + for i, val in enumerate(test_embedding): + assert abs(record.embedding[i] - val) < 1e-6 + + @pytest.mark.asyncio + async def test_handling_of_malformed_json_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of non-JSON categories uses TAG fallback.""" + # Mock HGETALL with non-JSON categories (treated as TAG format) + mock_glide_client.hgetall.return_value = { + "id": "malformed-categories", + "content": "Test content", + "scope": "/test", + "categories": "not valid json [", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("malformed-categories") + + # TAG fallback: comma-split produces the raw string as a single category + assert record is not None + assert record.id == "malformed-categories" + assert record.categories == ["not valid 
json ["] + mock_glide_client.hgetall.assert_called_once() + + @pytest.mark.asyncio + async def test_handling_of_malformed_json_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of malformed JSON in metadata field.""" + # Mock HGETALL with invalid JSON + mock_glide_client.hgetall.return_value = { + "id": "malformed-metadata", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{invalid json}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("malformed-metadata") + + # Verify None is returned and error is logged + assert record is None + + @pytest.mark.asyncio + async def test_handling_of_invalid_datetime_format( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of invalid datetime format.""" + # Mock HGETALL with invalid datetime + mock_glide_client.hgetall.return_value = { + "id": "invalid-datetime", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "not a valid datetime", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("invalid-datetime") + + # Verify None is returned and error is logged + assert record is None + + @pytest.mark.asyncio + async def test_handling_of_invalid_importance_value( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of invalid importance value.""" + # Mock HGETALL with non-numeric importance + mock_glide_client.hgetall.return_value = { + "id": "invalid-importance", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + 
"importance": "not a number", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("invalid-importance") + + # Verify None is returned and error is logged + assert record is None + + @pytest.mark.asyncio + async def test_handling_of_missing_required_fields( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of missing required fields.""" + # Mock HGETALL with missing fields + mock_glide_client.hgetall.return_value = { + "id": "incomplete-record", + "content": "Test content", + # Missing scope, categories, metadata, etc. + } + + # Retrieve the record + record = await valkey_storage._aget_record("incomplete-record") + + # Verify None is returned and error is logged + assert record is None + + @pytest.mark.asyncio + async def test_handling_of_connection_error( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test handling of connection error during retrieval.""" + # Mock HGETALL to raise connection error + mock_glide_client.hgetall.side_effect = Exception("Connection failed") + + # Retrieve the record + record = await valkey_storage._aget_record("test-record") + + # Verify None is returned and error is logged + assert record is None + + def test_get_record_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync get_record wrapper calls async implementation.""" + # Mock HGETALL to return a record + mock_glide_client.hgetall.return_value = { + "id": "sync-test-record", + "content": "Test content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Call sync get_record + record = 
valkey_storage.get_record("sync-test-record") + + # Verify async operation was called + mock_glide_client.hgetall.assert_called_once_with("record:sync-test-record") + assert record is not None + assert record.id == "sync-test-record" + + @pytest.mark.asyncio + async def test_retrieve_record_with_special_characters_in_content( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with special characters in content.""" + # Mock HGETALL with special characters + mock_glide_client.hgetall.return_value = { + "id": "special-chars-record", + "content": "Content with special chars: \n\t\"quotes\" 'apostrophes' & symbols", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("special-chars-record") + + # Verify special characters are preserved + assert record is not None + assert "\n" in record.content + assert "\t" in record.content + assert '"quotes"' in record.content + assert "'apostrophes'" in record.content + + @pytest.mark.asyncio + async def test_retrieve_record_with_unicode_content( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test retrieving a record with unicode content.""" + # Mock HGETALL with unicode characters + mock_glide_client.hgetall.return_value = { + "id": "unicode-record", + "content": "Unicode content: 你好 مرحبا שלום 🚀 ñ é ü", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Retrieve the record + record = await valkey_storage._aget_record("unicode-record") + + # Verify unicode is preserved + assert record is not None + assert "你好" in 
record.content + assert "🚀" in record.content + + + +class TestValkeyStorageUpdate: + """Tests for ValkeyStorage update operation.""" + + @pytest.mark.asyncio + async def test_update_existing_record_preserves_created_at( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating an existing record preserves created_at timestamp.""" + original_created_at = datetime(2024, 1, 1, 10, 0, 0) + original_last_accessed = datetime(2024, 1, 1, 11, 0, 0) + + # Mock HGETALL to return existing record + mock_glide_client.hgetall.return_value = { + "id": "existing-record", + "content": "Original content", + "scope": "/original/scope", + "categories": '["old-category"]', + "metadata": '{"old_key": "old_value"}', + "importance": "0.5", + "created_at": original_created_at.isoformat(), + "last_accessed": original_last_accessed.isoformat(), + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with different created_at + updated_record = MemoryRecord( + id="existing-record", + content="Updated content", + scope="/updated/scope", + categories=["new-category"], + metadata={"new_key": "new_value"}, + importance=0.8, + created_at=datetime(2024, 2, 1, 10, 0, 0), # Different created_at + last_accessed=datetime(2024, 2, 1, 11, 0, 0), + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify HGETALL was called to fetch existing record + mock_glide_client.hgetall.assert_called_once_with("record:existing-record") + + # Verify HSET was called with updated data + mock_glide_client.hset.assert_called() + hset_call = mock_glide_client.hset.call_args + assert hset_call[0][0] == "record:existing-record" # key + hset_dict = hset_call[0][1] # field_value_map dict + + # Verify created_at was preserved from original + assert hset_dict["created_at"] == original_created_at.isoformat() + + # Verify other fields were updated + assert hset_dict["content"] == "Updated 
content" + assert hset_dict["scope"] == "/updated/scope" + assert hset_dict["importance"] == "0.8" + + # Verify last_accessed was updated to current time (not the one in updated_record) + last_accessed_dt = datetime.fromisoformat(hset_dict["last_accessed"]) + assert last_accessed_dt > original_last_accessed + + @pytest.mark.asyncio + async def test_update_non_existent_record_creates_new_one( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating a non-existent record creates a new one.""" + # Mock HGETALL to return empty dict (record doesn't exist) + mock_glide_client.hgetall.return_value = {} + + # Create new record + new_record = MemoryRecord( + id="new-record", + content="New content", + scope="/new/scope", + categories=["new-category"], + metadata={"key": "value"}, + importance=0.7, + created_at=datetime(2024, 1, 1, 10, 0, 0), + last_accessed=datetime(2024, 1, 1, 11, 0, 0), + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update (create) the record + await valkey_storage._aupdate(new_record) + + # Verify HGETALL was called + mock_glide_client.hgetall.assert_called_once_with("record:new-record") + + # Verify HSET was called to create the record + mock_glide_client.hset.assert_called_once() + + # Verify new indexes were created + mock_glide_client.zadd.assert_called_once() + assert mock_glide_client.sadd.call_count == 2 # 1 category + 1 metadata + + @pytest.mark.asyncio + async def test_update_maintains_index_consistency( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that update maintains index consistency.""" + # Mock HGETALL to return existing record + mock_glide_client.hgetall.return_value = { + "id": "indexed-record", + "content": "Original content", + "scope": "/original", + "categories": '["cat1", "cat2"]', + "metadata": '{"key1": "value1", "key2": "value2"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + 
"embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with same categories and metadata + updated_record = MemoryRecord( + id="indexed-record", + content="Updated content", + scope="/original", + categories=["cat1", "cat2"], + metadata={"key1": "value1", "key2": "value2"}, + importance=0.8, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify old indexes were removed + mock_glide_client.zrem.assert_called_once_with("scope:/original", ["indexed-record"]) + + # Verify old category indexes were removed + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) == 2 + + # Verify old metadata indexes were removed + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) == 2 + + # Verify new indexes were added + mock_glide_client.zadd.assert_called_once() + + # Verify new category indexes were added + sadd_calls = [call for call in mock_glide_client.sadd.call_args_list] + category_sadd_calls = [call for call in sadd_calls if "category:" in str(call[0])] + assert len(category_sadd_calls) == 2 + + # Verify new metadata indexes were added + metadata_sadd_calls = [call for call in sadd_calls if "metadata:" in str(call[0])] + assert len(metadata_sadd_calls) == 2 + + @pytest.mark.asyncio + async def test_update_removes_from_old_scope_index( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating scope removes record from old scope index.""" + # Mock HGETALL to return existing record with old scope + mock_glide_client.hgetall.return_value = { + "id": "scope-change-record", + "content": "Content", + "scope": "/old/scope", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": 
"2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with new scope + updated_record = MemoryRecord( + id="scope-change-record", + content="Content", + scope="/new/scope", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from old scope index + mock_glide_client.zrem.assert_called_once_with( + "scope:/old/scope", ["scope-change-record"] + ) + + # Verify added to new scope index + zadd_call = mock_glide_client.zadd.call_args + assert zadd_call[0][0] == "scope:/new/scope" + assert "scope-change-record" in zadd_call[0][1] + + @pytest.mark.asyncio + async def test_update_removes_from_old_category_indexes( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating categories removes record from old category indexes.""" + # Mock HGETALL to return existing record with old categories + mock_glide_client.hgetall.return_value = { + "id": "category-change-record", + "content": "Content", + "scope": "/test", + "categories": '["old-cat1", "old-cat2", "shared-cat"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with new categories (one shared, two new) + updated_record = MemoryRecord( + id="category-change-record", + content="Content", + scope="/test", + categories=["new-cat1", "new-cat2", "shared-cat"], + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from all old category indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) == 3 + + # Verify removed from old-cat1, old-cat2, and shared-cat + srem_keys = 
[call[0][0] for call in category_srem_calls] + assert "category:old-cat1" in srem_keys + assert "category:old-cat2" in srem_keys + assert "category:shared-cat" in srem_keys + + # Verify added to all new category indexes + sadd_calls = [call for call in mock_glide_client.sadd.call_args_list] + category_sadd_calls = [call for call in sadd_calls if "category:" in str(call[0])] + assert len(category_sadd_calls) == 3 + + # Verify added to new-cat1, new-cat2, and shared-cat + sadd_keys = [call[0][0] for call in category_sadd_calls] + assert "category:new-cat1" in sadd_keys + assert "category:new-cat2" in sadd_keys + assert "category:shared-cat" in sadd_keys + + @pytest.mark.asyncio + async def test_update_removes_from_old_metadata_indexes( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating metadata removes record from old metadata indexes.""" + # Mock HGETALL to return existing record with old metadata + mock_glide_client.hgetall.return_value = { + "id": "metadata-change-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": '{"old_key1": "old_value1", "old_key2": "old_value2", "shared_key": "old_value"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with new metadata + updated_record = MemoryRecord( + id="metadata-change-record", + content="Content", + scope="/test", + metadata={"new_key1": "new_value1", "new_key2": "new_value2", "shared_key": "new_value"}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from all old metadata indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) == 3 + + # Verify removed from old metadata 
keys + srem_keys = [call[0][0] for call in metadata_srem_calls] + assert "metadata:old_key1:old_value1" in srem_keys + assert "metadata:old_key2:old_value2" in srem_keys + assert "metadata:shared_key:old_value" in srem_keys + + # Verify added to all new metadata indexes + sadd_calls = [call for call in mock_glide_client.sadd.call_args_list] + metadata_sadd_calls = [call for call in sadd_calls if "metadata:" in str(call[0])] + assert len(metadata_sadd_calls) == 3 + + # Verify added to new metadata keys + sadd_keys = [call[0][0] for call in metadata_sadd_calls] + assert "metadata:new_key1:new_value1" in sadd_keys + assert "metadata:new_key2:new_value2" in sadd_keys + assert "metadata:shared_key:new_value" in sadd_keys + + @pytest.mark.asyncio + async def test_update_with_empty_categories_removes_all_old_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating to empty categories removes all old category indexes.""" + # Mock HGETALL to return existing record with categories + mock_glide_client.hgetall.return_value = { + "id": "remove-categories-record", + "content": "Content", + "scope": "/test", + "categories": '["cat1", "cat2", "cat3"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with empty categories + updated_record = MemoryRecord( + id="remove-categories-record", + content="Content", + scope="/test", + categories=[], + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from all old category indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) == 3 + + # Verify no new category indexes were added + sadd_calls = [call for call 
in mock_glide_client.sadd.call_args_list] + category_sadd_calls = [call for call in sadd_calls if "category:" in str(call[0])] + assert len(category_sadd_calls) == 0 + + @pytest.mark.asyncio + async def test_update_with_empty_metadata_removes_all_old_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test updating to empty metadata removes all old metadata indexes.""" + # Mock HGETALL to return existing record with metadata + mock_glide_client.hgetall.return_value = { + "id": "remove-metadata-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": '{"key1": "value1", "key2": "value2"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with empty metadata + updated_record = MemoryRecord( + id="remove-metadata-record", + content="Content", + scope="/test", + metadata={}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from all old metadata indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) == 2 + + # Verify no new metadata indexes were added + sadd_calls = [call for call in mock_glide_client.sadd.call_args_list] + metadata_sadd_calls = [call for call in sadd_calls if "metadata:" in str(call[0])] + assert len(metadata_sadd_calls) == 0 + + @pytest.mark.asyncio + async def test_update_handles_malformed_old_data_gracefully( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test update handles malformed old data gracefully.""" + # Mock HGETALL to return record with malformed JSON + mock_glide_client.hgetall.return_value = { + "id": "malformed-record", + "content": "Content", + "scope": "/test", 
+ "categories": "not valid json", + "metadata": "{invalid json}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record + updated_record = MemoryRecord( + id="malformed-record", + content="Updated content", + scope="/test", + categories=["new-cat"], + metadata={"new_key": "new_value"}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update should not raise an error + await valkey_storage._aupdate(updated_record) + + # Verify HSET was called (update proceeded despite malformed old data) + mock_glide_client.hset.assert_called_once() + + # Verify new indexes were added + mock_glide_client.zadd.assert_called_once() + assert mock_glide_client.sadd.call_count >= 2 + + @pytest.mark.asyncio + async def test_update_handles_missing_created_at_gracefully( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test update handles missing created_at in old record gracefully.""" + # Mock HGETALL to return record without created_at + mock_glide_client.hgetall.return_value = { + "id": "no-created-at-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + # Missing created_at + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with created_at + updated_record = MemoryRecord( + id="no-created-at-record", + content="Updated content", + scope="/test", + created_at=datetime(2024, 2, 1, 10, 0, 0), + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update should not raise an error + await valkey_storage._aupdate(updated_record) + + # Verify HSET was called + mock_glide_client.hset.assert_called_once() + + # Verify created_at from updated_record was used (since old one was missing) + hset_call = mock_glide_client.hset.call_args + hset_dict = hset_call[0][1] # field_value_map dict + + # 
Should use the created_at from updated_record since old one was missing + assert "created_at" in hset_dict + + @pytest.mark.asyncio + async def test_update_with_numeric_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test update with numeric metadata values converts to strings.""" + # Mock HGETALL to return existing record + mock_glide_client.hgetall.return_value = { + "id": "numeric-metadata-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": '{"count": 10, "score": 5.5}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Create updated record with different numeric metadata + updated_record = MemoryRecord( + id="numeric-metadata-record", + content="Content", + scope="/test", + metadata={"count": 20, "score": 7.5, "active": True}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify removed from old metadata indexes with string-converted values + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + srem_keys = [call[0][0] for call in metadata_srem_calls] + assert "metadata:count:10" in srem_keys + assert "metadata:score:5.5" in srem_keys + + # Verify added to new metadata indexes with string-converted values + sadd_calls = [call for call in mock_glide_client.sadd.call_args_list] + metadata_sadd_calls = [call for call in sadd_calls if "metadata:" in str(call[0])] + sadd_keys = [call[0][0] for call in metadata_sadd_calls] + assert "metadata:count:20" in sadd_keys + assert "metadata:score:7.5" in sadd_keys + assert "metadata:active:True" in sadd_keys + + def test_update_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync update 
wrapper calls async implementation.""" + # Mock HGETALL to return empty dict (new record) + mock_glide_client.hgetall.return_value = {} + + # Create record + record = MemoryRecord( + id="sync-update-record", + content="Content", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Call sync update + valkey_storage.update(record) + + # Verify async operations were called + mock_glide_client.hgetall.assert_called_once_with("record:sync-update-record") + + @pytest.mark.asyncio + async def test_update_preserves_embedding( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that update preserves embedding correctly.""" + # Mock HGETALL to return existing record + mock_glide_client.hgetall.return_value = { + "id": "embedding-update-record", + "content": "Original content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": valkey_storage._embedding_to_bytes([0.1, 0.2, 0.3, 0.4]), + "source": "", + "private": "false", + } + + # Create updated record with new embedding + new_embedding = [0.5, 0.6, 0.7, 0.8] + updated_record = MemoryRecord( + id="embedding-update-record", + content="Updated content", + scope="/test", + embedding=new_embedding, + ) + + # Update the record + await valkey_storage._aupdate(updated_record) + + # Verify HSET was called with new embedding + mock_glide_client.hset.assert_called_once() + hset_call = mock_glide_client.hset.call_args + hset_dict = hset_call[0][1] # field_value_map dict + + # Verify embedding was updated + assert "embedding" in hset_dict + # Deserialize and check values + deserialized_embedding = valkey_storage._bytes_to_embedding(hset_dict["embedding"]) + assert len(deserialized_embedding) == 4 + for i, val in enumerate(new_embedding): + assert abs(deserialized_embedding[i] - val) < 1e-6 + + +class TestValkeyStorageDelete: + """Tests for ValkeyStorage delete 
operation.""" + + @pytest.mark.asyncio + async def test_delete_by_record_ids( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records by specific record IDs.""" + # Mock record data for deletion + mock_glide_client.hgetall.side_effect = [ + { + "id": "record-1", + "content": "Content 1", + "scope": "/test", + "categories": '["cat1", "cat2"]', + "metadata": '{"key1": "value1"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-2", + "content": "Content 2", + "scope": "/test", + "categories": '["cat1"]', + "metadata": '{"key1": "value2"}', + "importance": "0.6", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete by record IDs + count = await valkey_storage.adelete(record_ids=["record-1", "record-2"]) + + # Verify correct count returned + assert count == 2 + + # Verify records were deleted + delete_calls = [call for call in mock_glide_client.delete.call_args_list] + assert len(delete_calls) == 2 + + # Verify records were removed from scope indexes + zrem_calls = [call for call in mock_glide_client.zrem.call_args_list] + assert len(zrem_calls) == 2 + assert any("scope:/test" in str(call) for call in zrem_calls) + + # Verify records were removed from category indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) >= 2 # At least cat1 and cat2 + + # Verify records were removed from metadata indexes + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) >= 2 # At least key1:value1 and key1:value2 + + @pytest.mark.asyncio + async def test_delete_by_scope_prefix( + self, 
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records by scope prefix.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", # cursor (as bytes) + [b"scope:/agent/task1", b"scope:/agent/task2", b"scope:/other"], + ) + + # Mock zrange calls (used by _find_records_by_scope) + mock_glide_client.zrange.side_effect = [ + ["record-1", "record-2"], # zrange scope:/agent/task1 + ["record-3"], # zrange scope:/agent/task2 + [], # zrange scope:/other (not matched by prefix) + ] + + # Mock record data (for _fetch_records_for_deletion) + mock_glide_client.hgetall.side_effect = [ + { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/agent/task1", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + + { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/agent/task1", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/agent/task2", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete by scope prefix + count = await valkey_storage.adelete(scope_prefix="/agent") + + # Verify correct count returned (3 records in /agent scopes) + assert count == 3 + + # Verify scan was called to find scope keys + mock_glide_client.scan.assert_called() + + # Verify zrange was called to get record IDs + assert mock_glide_client.zrange.call_count >= 2 + + # Verify records were deleted + assert 
mock_glide_client.delete.call_count == 3 + + + @pytest.mark.asyncio + async def test_delete_by_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records by categories.""" + # Mock smembers to return record IDs for categories + mock_glide_client.smembers.side_effect = [ + {"record-1", "record-2", "record-3"}, # category:planning + {"record-2", "record-3", "record-4"}, # category:execution + ] + + # Mock sunion to return the union (records with ANY category) + mock_glide_client.sunion.return_value = {"record-1", "record-2", "record-3", "record-4"} + + # Mock record data + mock_glide_client.hgetall.side_effect = [ + { + "id": "record-1", + "content": "Content 1", + "scope": "/test", + "categories": '["planning"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-2", + "content": "Content 2", + "scope": "/test", + "categories": '["planning", "execution"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + + { + "id": "record-3", + "content": "Content 3", + "scope": "/test", + "categories": '["planning", "execution"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-4", + "content": "Content 4", + "scope": "/test", + "categories": '["execution"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete by categories (OR logic - any record with planning OR execution) + count = await
valkey_storage.adelete(categories=["planning", "execution"]) + + # Verify correct count returned + assert count == 4 + + # Verify records were deleted + assert mock_glide_client.delete.call_count == 4 + + + @pytest.mark.asyncio + async def test_delete_by_older_than( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records older than a specific datetime.""" + cutoff_date = datetime(2024, 1, 15, 0, 0, 0) + cutoff_timestamp = cutoff_date.timestamp() + + # Mock scan to return all scope keys + mock_glide_client.scan.return_value = ( + b"0", # cursor + [b"scope:/test"], + ) + + # Mock zrange for ZRANGEBYSCORE to return old records + mock_glide_client.zrange.return_value = ["record-1", "record-2"] + + # Mock record data + mock_glide_client.hgetall.side_effect = [ + { + b"id": b"record-1", + b"content": b"Old content 1", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + { + b"id": b"record-2", + b"content": b"Old content 2", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-10T10:00:00", + b"last_accessed": b"2024-01-10T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + ] + + # Delete records older than cutoff + count = await valkey_storage.adelete(older_than=cutoff_date) + + # Verify correct count returned + assert count == 2 + + # Verify scan was called + mock_glide_client.scan.assert_called() + + # Verify zrange was called for score-based range query + mock_glide_client.zrange.assert_called_once() + + # Verify records were deleted + assert mock_glide_client.delete.call_count == 2 + + @pytest.mark.asyncio + async def test_delete_by_metadata_filter( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: 
+ """Test deleting records by metadata filter.""" + # Mock smembers to return records matching each metadata criterion + mock_glide_client.smembers.side_effect = [ + {"record-1", "record-2", "record-3"}, # metadata:agent_id:agent-1 + {"record-1", "record-2"}, # metadata:priority:high + ] + + # Mock record data (only record-1 and record-2 match both criteria) + mock_glide_client.hgetall.side_effect = [ + { + "id": "record-1", + "content": "Content 1", + "scope": "/test", + "categories": "[]", + "metadata": '{"agent_id": "agent-1", "priority": "high"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-2", + "content": "Content 2", + "scope": "/test", + "categories": "[]", + "metadata": '{"agent_id": "agent-1", "priority": "high"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete by metadata filter (AND logic - both criteria must match) + count = await valkey_storage.adelete( + metadata_filter={"agent_id": "agent-1", "priority": "high"} + ) + + # Verify correct count returned (only records matching both criteria) + assert count == 2 + + # Verify smembers was called for each metadata criterion + assert mock_glide_client.smembers.call_count == 2 + + # Verify records were deleted + assert mock_glide_client.delete.call_count == 2 + + @pytest.mark.asyncio + async def test_delete_with_combined_criteria( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records with combined criteria (AND logic).""" + # Mock scan for scope filtering + mock_glide_client.scan.return_value = ( + b"0", # cursor (as bytes) + [b"scope:/agent/task1", b"scope:/agent/task2"], + ) + + # Mock zrange calls (used by _find_records_by_scope) + mock_glide_client.zrange.side_effect = [ + 
["record-1", "record-2", "record-3"], # zrange scope:/agent/task1 + ["record-4"], # zrange scope:/agent/task2 + ] + + # Mock smembers for category filtering (returns records with planning category) + # Only record-1 and record-2 have planning category (not record-4) + mock_glide_client.smembers.return_value = {"record-1", "record-2"} + + # The AND logic will intersect scope records (1,2,3,4) with category records (1,2) + # Result: record-1 and record-2 (both in /agent scope AND have planning category) + # Mock record data for the 2 matching records (for _fetch_records_for_deletion) + mock_glide_client.hgetall.side_effect = [ + { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/agent/task1", + b"categories": b'["planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/agent/task1", + b"categories": b'["planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + ] + + # Mock delete, zrem, srem operations + mock_glide_client.delete.return_value = 1 + mock_glide_client.zrem.return_value = 1 + mock_glide_client.srem.return_value = 1 + + # Delete with combined criteria: scope_prefix AND categories + count = await valkey_storage.adelete( + scope_prefix="/agent", categories=["planning"] + ) + + # Verify correct count (only records in /agent scope AND with planning category) + assert count == 2 + + # Verify both scope and category filtering were used + mock_glide_client.scan.assert_called() + mock_glide_client.smembers.assert_called() + + # Verify only matching records were deleted + assert mock_glide_client.delete.call_count == 2 + + @pytest.mark.asyncio + async def 
test_delete_returns_correct_count( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that delete returns the correct count of deleted records.""" + # Mock record data + mock_glide_client.hgetall.side_effect = [ + { + "id": "record-1", + "content": "Content 1", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + + { + "id": "record-2", + "content": "Content 2", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-3", + "content": "Content 3", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete 3 records + count = await valkey_storage.adelete( + record_ids=["record-1", "record-2", "record-3"] + ) + + # Verify count is exactly 3 + assert count == 3 + + @pytest.mark.asyncio + async def test_delete_with_no_matching_records_returns_zero( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that delete returns 0 when no records match criteria.""" + # Mock scan to return no matching scopes + mock_glide_client.scan.return_value = (b"0", []) + + # Delete with scope that doesn't exist + count = await valkey_storage.adelete(scope_prefix="/nonexistent") + + # Verify count is 0 + assert count == 0 + + # Verify no delete operations were performed + mock_glide_client.delete.assert_not_called() + + + @pytest.mark.asyncio + async def test_delete_removes_from_all_indexes( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> 
None: + """Test that delete removes records from all index structures.""" + # Mock record with multiple categories and metadata + mock_glide_client.hgetall.return_value = { + "id": "indexed-record", + "content": "Content", + "scope": "/agent/task", + "categories": '["cat1", "cat2", "cat3"]', + "metadata": '{"key1": "value1", "key2": "value2", "key3": "value3"}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Delete the record + count = await valkey_storage.adelete(record_ids=["indexed-record"]) + + # Verify record was deleted + assert count == 1 + mock_glide_client.delete.assert_called_once_with(["record:indexed-record"]) + + # Verify removed from scope index + mock_glide_client.zrem.assert_called_once_with( + "scope:/agent/task", ["indexed-record"] + ) + + # Verify removed from all category indexes + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) == 3 + + category_keys = [call[0][0] for call in category_srem_calls] + assert "category:cat1" in category_keys + assert "category:cat2" in category_keys + assert "category:cat3" in category_keys + + # Verify removed from all metadata indexes + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) == 3 + + metadata_keys = [call[0][0] for call in metadata_srem_calls] + assert "metadata:key1:value1" in metadata_keys + assert "metadata:key2:value2" in metadata_keys + assert "metadata:key3:value3" in metadata_keys + + + @pytest.mark.asyncio + async def test_delete_with_empty_categories_list( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete with empty categories list removes no category indexes.""" + # Mock record with no categories + 
mock_glide_client.hgetall.return_value = { + "id": "no-categories-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Delete the record + count = await valkey_storage.adelete(record_ids=["no-categories-record"]) + + # Verify record was deleted + assert count == 1 + + # Verify no category index removals + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + category_srem_calls = [call for call in srem_calls if "category:" in str(call)] + assert len(category_srem_calls) == 0 + + @pytest.mark.asyncio + async def test_delete_with_empty_metadata_dict( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete with empty metadata dict removes no metadata indexes.""" + # Mock record with no metadata + mock_glide_client.hgetall.return_value = { + "id": "no-metadata-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Delete the record + count = await valkey_storage.adelete(record_ids=["no-metadata-record"]) + + + # Verify record was deleted + assert count == 1 + + # Verify no metadata index removals + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + assert len(metadata_srem_calls) == 0 + + @pytest.mark.asyncio + async def test_delete_with_numeric_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete with numeric metadata values converts to strings.""" + # Mock record with numeric metadata + mock_glide_client.hgetall.return_value = { 
+ "id": "numeric-metadata-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": '{"count": 42, "score": 3.14, "active": true}', + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Delete the record + count = await valkey_storage.adelete(record_ids=["numeric-metadata-record"]) + + # Verify record was deleted + assert count == 1 + + # Verify metadata indexes were removed with string-converted values + srem_calls = [call for call in mock_glide_client.srem.call_args_list] + metadata_srem_calls = [call for call in srem_calls if "metadata:" in str(call)] + metadata_keys = [call[0][0] for call in metadata_srem_calls] + + assert "metadata:count:42" in metadata_keys + assert "metadata:score:3.14" in metadata_keys + assert "metadata:active:True" in metadata_keys + + @pytest.mark.asyncio + async def test_delete_handles_missing_record_data_gracefully( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete handles missing record data gracefully.""" + # Mock hgetall to return empty dict (record doesn't exist) + mock_glide_client.hgetall.return_value = {} + + # Delete non-existent record + count = await valkey_storage.adelete(record_ids=["non-existent-record"]) + + + # Verify count is 0 (record not found) + assert count == 0 + + # Verify no delete operations were performed + mock_glide_client.delete.assert_not_called() + + @pytest.mark.asyncio + async def test_delete_with_no_criteria_returns_zero( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete with no criteria specified returns 0.""" + # Delete with no criteria + count = await valkey_storage.adelete() + + # Verify count is 0 + assert count == 0 + + # Verify no operations were performed + mock_glide_client.delete.assert_not_called() + mock_glide_client.scan.assert_not_called() + + 
@pytest.mark.asyncio + async def test_delete_with_malformed_record_data( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete handles malformed record data gracefully.""" + # Mock record with malformed JSON + mock_glide_client.hgetall.return_value = { + "id": "malformed-record", + "content": "Content", + "scope": "/test", + "categories": "not valid json", + "metadata": "{invalid json}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + # Delete should not raise an error + count = await valkey_storage.adelete(record_ids=["malformed-record"]) + + # Verify record was still deleted (best effort) + assert count == 1 + mock_glide_client.delete.assert_called_once() + + def test_delete_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync delete wrapper calls async implementation.""" + # Mock record data + mock_glide_client.hgetall.return_value = { + "id": "sync-delete-record", + "content": "Content", + "scope": "/test", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + } + + + # Call sync delete + count = valkey_storage.delete(record_ids=["sync-delete-record"]) + + # Verify async operation was called + assert count == 1 + mock_glide_client.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_with_special_characters_in_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test delete with special characters in scope path.""" + # Mock scan to return scope with special characters + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/agent:task/sub-task"], + ) + + # Mock zrange to return record IDs + 
mock_glide_client.zrange.return_value = ["record-1"] + + # Mock record data + mock_glide_client.hgetall.return_value = { + b"id": b"record-1", + b"content": b"Content", + b"scope": b"/agent:task/sub-task", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + } + + # Delete by scope with special characters + count = await valkey_storage.adelete(scope_prefix="/agent:task") + + # Verify record was deleted + assert count == 1 + mock_glide_client.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_multiple_records_in_single_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting multiple records in a single scope.""" + # Mock scan to return one scope + mock_glide_client.scan.return_value = (b"0", [b"scope:/test"]) + + # Mock zrange to return multiple record IDs + mock_glide_client.zrange.return_value = [ + "record-1", + "record-2", + "record-3", + "record-4", + "record-5", + ] + + + # Mock record data for all records + mock_glide_client.hgetall.side_effect = [ + { + b"id": f"record-{i}".encode(), + b"content": f"Content {i}".encode(), + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + } + for i in range(1, 6) + ] + + # Delete all records in scope + count = await valkey_storage.adelete(scope_prefix="/test") + + # Verify all 5 records were deleted + assert count == 5 + assert mock_glide_client.delete.call_count == 5 + + @pytest.mark.asyncio + async def test_delete_with_root_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test deleting records with root scope '/'.""" + # Mock scan to return all scopes 
+ mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/", b"scope:/agent", b"scope:/task"], + ) + + # Mock zrange calls + mock_glide_client.zrange.side_effect = [ + ["record-1"], # zrange scope:/ + ["record-2"], # zrange scope:/agent + ["record-3"], # zrange scope:/task + ] + + # Mock record data + mock_glide_client.hgetall.side_effect = [ + { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + + { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/agent", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + { + "id": "record-3", + "content": "Content 3", + "scope": "/task", + "categories": "[]", + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T10:00:00", + "last_accessed": "2024-01-01T11:00:00", + "embedding": b"", + "source": "", + "private": "false", + }, + ] + + # Delete all records (root scope matches all) + count = await valkey_storage.adelete(scope_prefix="/") + + # Verify all records were deleted + assert count == 3 + assert mock_glide_client.delete.call_count == 3 + + @pytest.mark.asyncio + async def test_delete_preserves_unmatched_records( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that delete only removes matching records, not all records.""" + # Mock scan to return multiple scopes + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/agent", b"scope:/task"], + ) + + # Mock zrange - only /agent scope matches prefix + mock_glide_client.zrange.side_effect = [ + ["record-1", "record-2"], # zrange scope:/agent (matches) + [], # zrange scope:/task (doesn't match 
prefix, but still scanned) + ] + + # Mock record data only for matching records + mock_glide_client.hgetall.side_effect = [ + { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/agent", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + + { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/agent", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T11:00:00", + b"embedding": b"", + b"source": b"", + b"private": b"false", + }, + ] + + # Delete only records in /agent scope + count = await valkey_storage.adelete(scope_prefix="/agent") + + # Verify only 2 records were deleted (not records in /task) + assert count == 2 + assert mock_glide_client.delete.call_count == 2 + + # Verify only /agent records were deleted + delete_calls = [call[0][0][0] for call in mock_glide_client.delete.call_args_list] + assert "record:record-1" in delete_calls + assert "record:record-2" in delete_calls + + + +class TestValkeyStorageIndexing: + """Tests for ValkeyStorage indexing system (_update_indexes and _remove_from_indexes).""" + + @pytest.mark.asyncio + async def test_update_indexes_with_simple_scope_path( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test scope index updates with a simple scope path.""" + record_id = "test-record-123" + scope = "/agent/task" + categories = ["planning"] + metadata = {"agent_id": "agent-1"} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify scope index was updated + mock_glide_client.zadd.assert_called_once_with( + "scope:/agent/task", {record_id: timestamp} + ) + + # Verify category index was updated + 
mock_glide_client.sadd.assert_any_call("category:planning", [record_id]) + + # Verify metadata index was updated + mock_glide_client.sadd.assert_any_call("metadata:agent_id:agent-1", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_root_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test scope index updates with root scope '/'.""" + record_id = "root-record" + scope = "/" + categories: list[str] = [] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify root scope index was created correctly + mock_glide_client.zadd.assert_called_once_with( + "scope:/", {record_id: timestamp} + ) + + # Verify no category or metadata indexes were created + assert mock_glide_client.sadd.call_count == 0 + + @pytest.mark.asyncio + async def test_update_indexes_with_nested_scope_path( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test scope index updates with deeply nested scope path.""" + record_id = "nested-record" + scope = "/agent/task/subtask/step" + categories: list[str] = [] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify nested scope index was created correctly + mock_glide_client.zadd.assert_called_once_with( + "scope:/agent/task/subtask/step", {record_id: timestamp} + ) + + @pytest.mark.asyncio + async def test_update_indexes_with_scope_containing_special_characters( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test scope index updates with special characters in scope path.""" + record_id = "special-scope-record" + scope = "/agent:123/task-456/step_789" + categories: list[str] = [] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, 
scope, categories, metadata, timestamp + ) + + # Verify scope with special characters is handled correctly + mock_glide_client.zadd.assert_called_once_with( + "scope:/agent:123/task-456/step_789", {record_id: timestamp} + ) + + @pytest.mark.asyncio + async def test_update_indexes_with_multiple_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test category index updates with multiple categories.""" + record_id = "multi-category-record" + scope = "/test" + categories = ["planning", "execution", "review", "analysis"] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify all category indexes were updated + assert mock_glide_client.sadd.call_count == 4 + mock_glide_client.sadd.assert_any_call("category:planning", [record_id]) + mock_glide_client.sadd.assert_any_call("category:execution", [record_id]) + mock_glide_client.sadd.assert_any_call("category:review", [record_id]) + mock_glide_client.sadd.assert_any_call("category:analysis", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_empty_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test category index updates with empty categories list.""" + record_id = "no-categories-record" + scope = "/test" + categories: list[str] = [] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify scope index was updated + mock_glide_client.zadd.assert_called_once() + + # Verify no category indexes were created + assert mock_glide_client.sadd.call_count == 0 + + @pytest.mark.asyncio + async def test_update_indexes_with_categories_containing_special_characters( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test category index updates with special 
characters in category names.""" + record_id = "special-category-record" + scope = "/test" + categories = ["category:with:colons", "category-with-dashes", "category_with_underscores"] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify all category indexes were created with special characters preserved + assert mock_glide_client.sadd.call_count == 3 + mock_glide_client.sadd.assert_any_call("category:category:with:colons", [record_id]) + mock_glide_client.sadd.assert_any_call("category:category-with-dashes", [record_id]) + mock_glide_client.sadd.assert_any_call("category:category_with_underscores", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_string_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with string values.""" + record_id = "string-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "agent_id": "agent-1", + "task_type": "planning", + "status": "active", + } + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify all metadata indexes were created + assert mock_glide_client.sadd.call_count == 3 + mock_glide_client.sadd.assert_any_call("metadata:agent_id:agent-1", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:task_type:planning", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:status:active", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_numeric_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with numeric values (converted to strings).""" + record_id = "numeric-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "count": 42, + "score": 3.14159, + 
"priority": 1, + } + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify metadata values are converted to strings + assert mock_glide_client.sadd.call_count == 3 + mock_glide_client.sadd.assert_any_call("metadata:count:42", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:score:3.14159", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:priority:1", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_boolean_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with boolean values (converted to strings).""" + record_id = "boolean-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "is_active": True, + "is_complete": False, + } + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify boolean values are converted to strings + assert mock_glide_client.sadd.call_count == 2 + mock_glide_client.sadd.assert_any_call("metadata:is_active:True", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:is_complete:False", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_empty_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with empty metadata dict.""" + record_id = "no-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata: dict[str, str] = {} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify scope index was updated + mock_glide_client.zadd.assert_called_once() + + # Verify no metadata indexes were created + assert mock_glide_client.sadd.call_count == 0 + + @pytest.mark.asyncio + async def 
test_update_indexes_with_metadata_containing_special_characters( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with special characters in keys and values.""" + record_id = "special-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "key:with:colons": "value:with:colons", + "key-with-dashes": "value-with-dashes", + "key_with_underscores": "value_with_underscores", + "key with spaces": "value with spaces", + } + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify all metadata indexes were created with special characters preserved + assert mock_glide_client.sadd.call_count == 4 + mock_glide_client.sadd.assert_any_call( + "metadata:key:with:colons:value:with:colons", [record_id] + ) + mock_glide_client.sadd.assert_any_call( + "metadata:key-with-dashes:value-with-dashes", [record_id] + ) + mock_glide_client.sadd.assert_any_call( + "metadata:key_with_underscores:value_with_underscores", [record_id] + ) + mock_glide_client.sadd.assert_any_call( + "metadata:key with spaces:value with spaces", [record_id] + ) + + @pytest.mark.asyncio + async def test_update_indexes_with_mixed_data_types_in_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test metadata index updates with mixed data types.""" + record_id = "mixed-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "string_key": "string_value", + "int_key": 123, + "float_key": 45.67, + "bool_key": True, + } + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify all metadata indexes were created with proper type conversion + assert mock_glide_client.sadd.call_count == 4 + mock_glide_client.sadd.assert_any_call("metadata:string_key:string_value", [record_id]) + 
mock_glide_client.sadd.assert_any_call("metadata:int_key:123", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:float_key:45.67", [record_id]) + mock_glide_client.sadd.assert_any_call("metadata:bool_key:True", [record_id]) + + @pytest.mark.asyncio + async def test_update_indexes_with_all_fields_populated( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test index updates with scope, categories, and metadata all populated.""" + record_id = "full-record" + scope = "/agent/task" + categories = ["planning", "execution"] + metadata = {"agent_id": "agent-1", "priority": "high"} + timestamp = 1704067200.0 + + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Verify scope index was updated + mock_glide_client.zadd.assert_called_once_with( + "scope:/agent/task", {record_id: timestamp} + ) + + # Verify all indexes were updated (2 categories + 2 metadata = 4 sadd calls) + assert mock_glide_client.sadd.call_count == 4 + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_simple_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record from indexes with simple scope.""" + record_id = "test-record-123" + scope = "/agent/task" + categories = ["planning"] + metadata = {"agent_id": "agent-1"} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify record was removed from scope index + mock_glide_client.zrem.assert_called_once_with("scope:/agent/task", [record_id]) + + # Verify record was removed from category index + mock_glide_client.srem.assert_any_call("category:planning", [record_id]) + + # Verify record was removed from metadata index + mock_glide_client.srem.assert_any_call("metadata:agent_id:agent-1", [record_id]) + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_root_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock 
+ ) -> None: + """Test removing record from indexes with root scope '/'.""" + record_id = "root-record" + scope = "/" + categories: list[str] = [] + metadata: dict[str, str] = {} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify record was removed from root scope index + mock_glide_client.zrem.assert_called_once_with("scope:/", [record_id]) + + # Verify no category or metadata removals + assert mock_glide_client.srem.call_count == 0 + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_multiple_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record from multiple category indexes.""" + record_id = "multi-category-record" + scope = "/test" + categories = ["planning", "execution", "review"] + metadata: dict[str, str] = {} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify record was removed from all category indexes + assert mock_glide_client.srem.call_count == 3 + mock_glide_client.srem.assert_any_call("category:planning", [record_id]) + mock_glide_client.srem.assert_any_call("category:execution", [record_id]) + mock_glide_client.srem.assert_any_call("category:review", [record_id]) + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_empty_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record with empty categories list.""" + record_id = "no-categories-record" + scope = "/test" + categories: list[str] = [] + metadata: dict[str, str] = {} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify scope removal + mock_glide_client.zrem.assert_called_once() + + # Verify no category removals + assert mock_glide_client.srem.call_count == 0 + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_multiple_metadata_entries( + self, valkey_storage: 
ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record from multiple metadata indexes.""" + record_id = "multi-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "agent_id": "agent-1", + "task_type": "planning", + "priority": "high", + } + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify record was removed from all metadata indexes + assert mock_glide_client.srem.call_count == 3 + mock_glide_client.srem.assert_any_call("metadata:agent_id:agent-1", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:task_type:planning", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:priority:high", [record_id]) + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_empty_metadata( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record with empty metadata dict.""" + record_id = "no-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata: dict[str, str] = {} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify scope removal + mock_glide_client.zrem.assert_called_once() + + # Verify no metadata removals + assert mock_glide_client.srem.call_count == 0 + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_numeric_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record with numeric metadata values (converted to strings).""" + record_id = "numeric-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "count": 42, + "score": 3.14, + } + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify metadata values are converted to strings for removal + assert mock_glide_client.srem.call_count == 2 + mock_glide_client.srem.assert_any_call("metadata:count:42", [record_id]) + 
mock_glide_client.srem.assert_any_call("metadata:score:3.14", [record_id]) + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_boolean_metadata_values( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record with boolean metadata values (converted to strings).""" + record_id = "boolean-metadata-record" + scope = "/test" + categories: list[str] = [] + metadata = { + "is_active": True, + "is_complete": False, + } + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify boolean values are converted to strings for removal + assert mock_glide_client.srem.call_count == 2 + mock_glide_client.srem.assert_any_call("metadata:is_active:True", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:is_complete:False", [record_id]) + + @pytest.mark.asyncio + async def test_remove_from_indexes_with_all_fields_populated( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test removing record from all index structures.""" + record_id = "full-record" + scope = "/agent/task" + categories = ["planning", "execution"] + metadata = {"agent_id": "agent-1", "priority": "high"} + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify scope removal + mock_glide_client.zrem.assert_called_once_with("scope:/agent/task", [record_id]) + + # Verify all removals (2 categories + 2 metadata = 4 srem calls) + assert mock_glide_client.srem.call_count == 4 + + @pytest.mark.asyncio + async def test_remove_from_indexes_cleans_all_structures( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that remove_from_indexes cleans all index structures completely.""" + record_id = "cleanup-record" + scope = "/agent/task/subtask" + categories = ["planning", "execution", "review"] + metadata = { + "agent_id": "agent-1", + "task_type": "analysis", + "priority": 5, + "active": 
True, + } + + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Verify scope index cleanup + mock_glide_client.zrem.assert_called_once_with( + "scope:/agent/task/subtask", [record_id] + ) + + # Verify category index cleanup (3 categories) + mock_glide_client.srem.assert_any_call("category:planning", [record_id]) + mock_glide_client.srem.assert_any_call("category:execution", [record_id]) + mock_glide_client.srem.assert_any_call("category:review", [record_id]) + + # Verify metadata index cleanup (4 metadata entries) + mock_glide_client.srem.assert_any_call("metadata:agent_id:agent-1", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:task_type:analysis", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:priority:5", [record_id]) + mock_glide_client.srem.assert_any_call("metadata:active:True", [record_id]) + + # Verify total number of removals (3 categories + 4 metadata = 7 srem calls) + assert mock_glide_client.srem.call_count == 7 + + @pytest.mark.asyncio + async def test_update_then_remove_indexes_consistency( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that update and remove operations use consistent key naming.""" + record_id = "consistency-record" + scope = "/test/scope" + categories = ["cat1", "cat2"] + metadata = {"key1": "value1", "key2": 123} + timestamp = 1704067200.0 + + # Update indexes + await valkey_storage._update_indexes( + record_id, scope, categories, metadata, timestamp + ) + + # Capture the keys used in update + zadd_key = mock_glide_client.zadd.call_args[0][0] + sadd_keys = [call[0][0] for call in mock_glide_client.sadd.call_args_list] + + # Reset mocks + mock_glide_client.reset_mock() + + # Remove from indexes + await valkey_storage._remove_from_indexes( + record_id, scope, categories, metadata + ) + + # Capture the keys used in remove + zrem_key = mock_glide_client.zrem.call_args[0][0] + srem_keys = [call[0][0] for call in 
mock_glide_client.srem.call_args_list] + + # Verify scope keys match + assert zadd_key == zrem_key + + # Verify category and metadata keys match + assert set(sadd_keys) == set(srem_keys) diff --git a/lib/crewai/tests/memory/storage/test_valkey_storage_errors.py b/lib/crewai/tests/memory/storage/test_valkey_storage_errors.py new file mode 100644 index 0000000000..244e09b854 --- /dev/null +++ b/lib/crewai/tests/memory/storage/test_valkey_storage_errors.py @@ -0,0 +1,267 @@ +"""Tests for ValkeyStorage error handling.""" + +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from crewai.memory.storage.valkey_storage import ValkeyStorage +from crewai.memory.types import MemoryRecord + + +@pytest.fixture +def mock_glide_client() -> AsyncMock: + """Create a mock GlideClient for testing.""" + client = AsyncMock() + client.hset = AsyncMock(return_value=1) + client.zrange = AsyncMock(return_value=[]) + client.zadd = AsyncMock() + client.sadd = AsyncMock() + client.hgetall = AsyncMock(return_value={}) + client.close = AsyncMock() + return client + + +@pytest.fixture +def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage: + """Create a ValkeyStorage instance with mocked client.""" + storage = ValkeyStorage(host="localhost", port=6379, db=0) + + # Mock the client creation to return our mock + async def mock_create_client() -> AsyncMock: + storage._client = mock_glide_client + return mock_glide_client + + storage._get_client = mock_create_client # type: ignore[method-assign] + return storage + + +class TestSerializationErrors: + """Tests for serialization error handling.""" + + def test_serialization_error_raises_descriptive_exception( + self, valkey_storage: ValkeyStorage + ) -> None: + """Test that serialization errors raise descriptive ValueError.""" + # Create a record with non-serializable metadata + record = 
MemoryRecord( + id="test-id", + content="test content", + scope="/test", + categories=["test"], + metadata={"bad_key": object()}, # Non-serializable object + importance=0.5, + created_at=datetime.now(), + last_accessed=datetime.now(), + embedding=[0.1, 0.2, 0.3], + ) + + # Should raise ValueError with descriptive message + with pytest.raises(ValueError, match="Failed to serialize record test-id"): + valkey_storage._record_to_dict(record) + + def test_serialization_error_includes_cause( + self, valkey_storage: ValkeyStorage + ) -> None: + """Test that serialization error includes the original exception as cause.""" + # Create a mock record that will fail during JSON serialization + # We need to bypass Pydantic validation, so we'll patch json.dumps + record = MemoryRecord( + id="test-id-2", + content="test content", + scope="/test", + categories=["valid"], + metadata={"key": "value"}, + importance=0.5, + created_at=datetime.now(), + last_accessed=datetime.now(), + embedding=[0.1, 0.2, 0.3], + ) + + # Patch json.dumps to raise an error + with patch("json.dumps", side_effect=TypeError("Cannot serialize")): + with pytest.raises(ValueError) as exc_info: + valkey_storage._record_to_dict(record) + + # Verify the exception has a cause + assert exc_info.value.__cause__ is not None + assert isinstance(exc_info.value.__cause__, TypeError) + + +class TestDeserializationErrors: + """Tests for deserialization error handling.""" + + def test_deserialization_error_logs_and_returns_none( + self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that deserialization errors log error and return None.""" + # Create malformed data (missing required fields) + malformed_data = { + "id": "test-id", + "content": "test content", + # Missing scope, categories, metadata, etc. 
+ } + + # Should return None and log error + result = valkey_storage._dict_to_record(malformed_data) + + assert result is None + assert "Failed to deserialize record test-id" in caplog.text + + def test_deserialization_with_invalid_json_categories_uses_tag_fallback( + self, valkey_storage: ValkeyStorage + ) -> None: + """Test that non-JSON categories fall back to TAG (comma-separated) parsing.""" + # Create data with non-JSON categories string + data = { + "id": "test-id-json", + "content": "test content", + "scope": "/test", + "categories": "not valid json [", # Not JSON, treated as TAG format + "metadata": "{}", + "importance": "0.5", + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "source": "", + "private": "false", + } + + result = valkey_storage._dict_to_record(data) + + # TAG fallback: comma-split produces the raw string as a single category + assert result is not None + assert result.id == "test-id-json" + assert result.categories == ["not valid json ["] + + def test_deserialization_with_invalid_datetime_returns_none( + self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that invalid datetime format returns None.""" + # Create data with invalid datetime + invalid_data = { + "id": "test-id-datetime", + "content": "test content", + "scope": "/test", + "categories": '["test"]', + "metadata": "{}", + "importance": "0.5", + "created_at": "not a datetime", # Invalid datetime + "last_accessed": "2024-01-01T12:00:00", + "source": "", + "private": "false", + } + + result = valkey_storage._dict_to_record(invalid_data) + + assert result is None + assert "Failed to deserialize record test-id-datetime" in caplog.text + + def test_deserialization_with_invalid_float_returns_none( + self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that invalid float importance returns None.""" + # Create data with invalid float + invalid_data = { + "id": "test-id-float", + 
"content": "test content", + "scope": "/test", + "categories": '["test"]', + "metadata": "{}", + "importance": "not a float", # Invalid float + "created_at": "2024-01-01T12:00:00", + "last_accessed": "2024-01-01T12:00:00", + "source": "", + "private": "false", + } + + result = valkey_storage._dict_to_record(invalid_data) + + assert result is None + assert "Failed to deserialize record test-id-float" in caplog.text + + def test_deserialization_with_bytes_keys_uses_tag_fallback( + self, valkey_storage: ValkeyStorage + ) -> None: + """Test that deserialization handles bytes keys with non-JSON categories via TAG fallback.""" + # Create data with bytes keys (as returned by Valkey) + bytes_data = { + b"id": b"test-id-bytes", + b"content": b"test content", + b"scope": b"/test", + b"categories": b"invalid json [", # Not JSON, treated as TAG format + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T12:00:00", + b"last_accessed": b"2024-01-01T12:00:00", + } + + result = valkey_storage._dict_to_record(bytes_data) + + # TAG fallback: comma-split produces the raw string as a single category + assert result is not None + assert result.id == "test-id-bytes" + assert result.categories == ["invalid json ["] + + +class TestRetryBehaviorIntegration: + """Integration tests demonstrating retry behavior patterns.""" + + @pytest.mark.asyncio + async def test_mock_client_operation_with_retry_pattern( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test demonstrating how retry would work with client operations.""" + from glide import ClosingError + + # Mock a client operation that fails once + mock_glide_client.hgetall.side_effect = [ + ClosingError("Connection lost"), + { + b"id": b"test-id", + b"content": b"test content", + b"scope": b"/test", + b"categories": b'["test"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T12:00:00", + b"last_accessed": b"2024-01-01T12:00:00", + b"source": b"", + 
b"private": b"false", + b"embedding": b"", + }, + ] + + # First call fails, second succeeds + with pytest.raises(ClosingError): + await mock_glide_client.hgetall("record:test-id") + + # Second call succeeds + result = await mock_glide_client.hgetall("record:test-id") + assert result is not None + + @pytest.mark.asyncio + async def test_serialization_error_not_retried( + self, valkey_storage: ValkeyStorage + ) -> None: + """Test that serialization errors are not retried (they're not connection errors).""" + # Create a record with non-serializable data + record = MemoryRecord( + id="test-id", + content="test content", + scope="/test", + categories=["test"], + metadata={"bad": object()}, + importance=0.5, + created_at=datetime.now(), + last_accessed=datetime.now(), + embedding=[0.1, 0.2, 0.3], + ) + + # Serialization error should not be retried + with pytest.raises(ValueError, match="Failed to serialize"): + valkey_storage._record_to_dict(record) diff --git a/lib/crewai/tests/memory/storage/test_valkey_storage_scope.py b/lib/crewai/tests/memory/storage/test_valkey_storage_scope.py new file mode 100644 index 0000000000..4afd2c4ef8 --- /dev/null +++ b/lib/crewai/tests/memory/storage/test_valkey_storage_scope.py @@ -0,0 +1,1110 @@ +"""Tests for ValkeyStorage scope operations.""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from crewai.memory.storage.valkey_storage import ValkeyStorage +from crewai.memory.types import MemoryRecord, ScopeInfo + + +@pytest.fixture +def mock_glide_client() -> AsyncMock: + """Create a mock GlideClient for testing.""" + client = AsyncMock() + client.hset = AsyncMock(return_value=1) + client.zrange = AsyncMock(return_value=[]) + client.zadd = AsyncMock() + client.sadd = AsyncMock() + client.zrem = AsyncMock() + client.srem = AsyncMock() + client.hgetall = AsyncMock(return_value={}) + client.scan = AsyncMock() + client.smembers = 
AsyncMock(return_value=[]) + client.scard = AsyncMock(return_value=0) + client.delete = AsyncMock() + client.close = AsyncMock() + return client + + +@pytest.fixture +def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage: + """Create a ValkeyStorage instance with mocked client.""" + storage = ValkeyStorage(host="localhost", port=6379, db=0) + + # Mock the client creation to return our mock + async def mock_create_client() -> AsyncMock: + storage._client = mock_glide_client + return mock_glide_client + + storage._get_client = mock_create_client # type: ignore[method-assign] + return storage + + +class TestValkeyStorageListRecords: + """Tests for list_records operation.""" + + @pytest.mark.asyncio + async def test_list_records_returns_newest_first( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that list_records returns records ordered by created_at descending.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", # cursor + [b"scope:/test"], # keys + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3"], # ZRANGE response + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + if key == "record:record-1": + return { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-2": + return { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-02T10:00:00", + b"last_accessed": b"2024-01-02T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == 
"record:record-3": + return { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-03T10:00:00", + b"last_accessed": b"2024-01-03T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + return {} + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List records + records = await valkey_storage._alist_records(scope_prefix="/test") + + # Verify records are ordered newest first + assert len(records) == 3 + assert records[0].id == "record-3" # Newest + assert records[1].id == "record-2" + assert records[2].id == "record-1" # Oldest + + @pytest.mark.asyncio + async def test_list_records_with_pagination_limit_only( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test pagination with limit only (no offset).""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3", b"record-4", b"record-5"], + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + record_id = key.split(":")[-1] + day = int(record_id.split("-")[-1]) + return { + b"id": record_id.encode(), + b"content": f"Content {day}".encode(), + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": f"2024-01-0{day}T10:00:00".encode(), + b"last_accessed": f"2024-01-0{day}T10:00:00".encode(), + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List records with limit only + records = await valkey_storage._alist_records(scope_prefix="/test", limit=3) + + # Verify limit works (take first 3) + assert len(records) == 3 + assert records[0].id == "record-5" # Newest + assert records[1].id == 
"record-4" + assert records[2].id == "record-3" + + @pytest.mark.asyncio + async def test_list_records_with_pagination_offset_only( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test pagination with offset only (default limit).""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3"], + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + record_id = key.split(":")[-1] + day = int(record_id.split("-")[-1]) + return { + b"id": record_id.encode(), + b"content": f"Content {day}".encode(), + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": f"2024-01-0{day}T10:00:00".encode(), + b"last_accessed": f"2024-01-0{day}T10:00:00".encode(), + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List records with offset only + records = await valkey_storage._alist_records(scope_prefix="/test", offset=1) + + # Verify offset works (skip first 1) + assert len(records) == 2 + assert records[0].id == "record-2" + assert records[1].id == "record-1" + + @pytest.mark.asyncio + async def test_list_records_with_pagination_limit_and_offset( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test pagination with both limit and offset.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3", b"record-4", b"record-5"], + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + record_id = key.split(":")[-1] + day = int(record_id.split("-")[-1]) + return 
{ + b"id": record_id.encode(), + b"content": f"Content {day}".encode(), + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": f"2024-01-0{day}T10:00:00".encode(), + b"last_accessed": f"2024-01-0{day}T10:00:00".encode(), + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List records with pagination + records = await valkey_storage._alist_records( + scope_prefix="/test", limit=2, offset=1 + ) + + # Verify pagination works (skip 1, take 2) + assert len(records) == 2 + assert records[0].id == "record-4" # Second newest + assert records[1].id == "record-3" # Third newest + + @pytest.mark.asyncio + async def test_list_records_with_large_offset( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test pagination with offset beyond available records.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + record_id = key.split(":")[-1] + day = int(record_id.split("-")[-1]) + return { + b"id": record_id.encode(), + b"content": f"Content {day}".encode(), + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": f"2024-01-0{day}T10:00:00".encode(), + b"last_accessed": f"2024-01-0{day}T10:00:00".encode(), + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List records with large offset + records = await valkey_storage._alist_records(scope_prefix="/test", offset=10) + + # Verify empty list when offset exceeds available records + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_list_records_empty_scope( 
+ self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_records returns empty list for empty scope.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/empty"], + ) + + # Mock ZRANGE to return no record IDs + mock_glide_client.zrange.side_effect = [ + [], # No records + ] + + # List records + records = await valkey_storage._alist_records(scope_prefix="/empty") + + # Verify empty list + assert len(records) == 0 + + def test_list_records_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync list_records wrapper calls async implementation.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1"], + ] + + # Mock hgetall to return record data + mock_glide_client.hgetall.return_value = { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + + # Call sync wrapper + records = valkey_storage.list_records(scope_prefix="/test") + + # Verify it works + assert len(records) == 1 + assert records[0].id == "record-1" + + +class TestValkeyStorageGetScopeInfo: + """Tests for get_scope_info operation.""" + + @pytest.mark.asyncio + async def test_get_scope_info_returns_accurate_counts( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that get_scope_info returns accurate record counts and metadata.""" + # Mock scan to return scope keys + mock_glide_client.scan.side_effect = [ + (b"0", [b"scope:/test", b"scope:/test/sub"]), # First scan + (b"0", [b"scope:/test", b"scope:/test/sub"]), # Second scan for child 
scopes + ] + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], # Records in /test + [b"record-3"], # Records in /test/sub + ] + + # Mock hgetall to return record data + def mock_hgetall(key: str) -> dict[bytes, bytes]: + if key == "record:record-1": + return { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b'["planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-2": + return { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/test", + b"categories": b'["execution"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-03T10:00:00", + b"last_accessed": b"2024-01-03T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-3": + return { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/test/sub", + b"categories": b'["planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-02T10:00:00", + b"last_accessed": b"2024-01-02T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + return {} + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # Get scope info + info = await valkey_storage._aget_scope_info("/test") + + # Verify scope info + assert info.path == "/test" + assert info.record_count == 3 # All records in /test and subscopes + assert set(info.categories) == {"execution", "planning"} + assert info.oldest_record == datetime(2024, 1, 1, 10, 0, 0) + assert info.newest_record == datetime(2024, 1, 3, 10, 0, 0) + assert "/test/sub" in info.child_scopes + + @pytest.mark.asyncio + async def test_get_scope_info_returns_accurate_timestamps( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test 
that get_scope_info returns accurate oldest and newest timestamps.""" + # Mock scan to return scope keys + mock_glide_client.scan.side_effect = [ + (b"0", [b"scope:/test"]), # First scan + (b"0", [b"scope:/test"]), # Second scan for child scopes + ] + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3"], + ] + + # Mock hgetall to return record data with different timestamps + def mock_hgetall(key: str) -> dict[bytes, bytes]: + if key == "record:record-1": + return { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-15T10:00:00", + b"last_accessed": b"2024-01-15T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-2": + return { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", # Oldest + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-3": + return { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/test", + b"categories": b"[]", + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-20T10:00:00", # Newest + b"last_accessed": b"2024-01-20T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + return {} + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # Get scope info + info = await valkey_storage._aget_scope_info("/test") + + # Verify timestamps + assert info.oldest_record == datetime(2024, 1, 1, 10, 0, 0) + assert info.newest_record == datetime(2024, 1, 20, 10, 0, 0) + + @pytest.mark.asyncio + async def test_get_scope_info_empty_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test 
get_scope_info returns empty info for empty scope.""" + # Mock scan to return no matching scopes + mock_glide_client.scan.return_value = (b"0", []) + + # Get scope info for empty scope + info = await valkey_storage._aget_scope_info("/empty") + + # Verify empty scope info + assert info.path == "/empty" + assert info.record_count == 0 + assert info.categories == [] + assert info.oldest_record is None + assert info.newest_record is None + assert info.child_scopes == [] + + @pytest.mark.asyncio + async def test_get_scope_info_with_multiple_categories( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test get_scope_info aggregates categories from all records.""" + # Mock scan to return scope keys + mock_glide_client.scan.side_effect = [ + (b"0", [b"scope:/test"]), # First scan + (b"0", [b"scope:/test"]), # Second scan for child scopes + ] + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3"], + ] + + # Mock hgetall to return record data with various categories + def mock_hgetall(key: str) -> dict[bytes, bytes]: + if key == "record:record-1": + return { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b'["planning", "execution"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-2": + return { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/test", + b"categories": b'["review", "planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-02T10:00:00", + b"last_accessed": b"2024-01-02T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-3": + return { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/test", + b"categories": b'["analysis"]', + 
b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-03T10:00:00", + b"last_accessed": b"2024-01-03T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + return {} + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # Get scope info + info = await valkey_storage._aget_scope_info("/test") + + # Verify all unique categories are collected and sorted + assert set(info.categories) == {"analysis", "execution", "planning", "review"} + assert info.categories == ["analysis", "execution", "planning", "review"] # Sorted + + def test_get_scope_info_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync get_scope_info wrapper calls async implementation.""" + # Mock scan to return no matching scopes + mock_glide_client.scan.return_value = (b"0", []) + + # Call sync wrapper + info = valkey_storage.get_scope_info("/test") + + # Verify it works + assert info.path == "/test" + assert info.record_count == 0 + + +class TestValkeyStorageListScopes: + """Tests for list_scopes operation.""" + + @pytest.mark.asyncio + async def test_list_scopes_returns_immediate_children_only( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that list_scopes returns only immediate children, not grandchildren.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [ + b"scope:/agent", + b"scope:/agent/task", + b"scope:/agent/task/subtask", + b"scope:/crew", + ], + ) + + # List scopes under root + scopes = await valkey_storage._alist_scopes("/") + + # Verify only immediate children are returned + assert len(scopes) == 2 + assert "/agent" in scopes + assert "/crew" in scopes + assert "/agent/task" not in scopes # Grandchild not included + + @pytest.mark.asyncio + async def test_list_scopes_with_parent( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_scopes with specific 
parent path.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [ + b"scope:/agent", + b"scope:/agent/task", + b"scope:/agent/task/subtask", + b"scope:/agent/memory", + ], + ) + + # List scopes under /agent + scopes = await valkey_storage._alist_scopes("/agent") + + # Verify only immediate children of /agent are returned + assert len(scopes) == 2 + assert "/agent/task" in scopes + assert "/agent/memory" in scopes + assert "/agent/task/subtask" not in scopes # Grandchild not included + + @pytest.mark.asyncio + async def test_list_scopes_returns_sorted_order( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that list_scopes returns scopes in sorted order.""" + # Mock scan to return scope keys in random order + mock_glide_client.scan.return_value = ( + b"0", + [ + b"scope:/zebra", + b"scope:/alpha", + b"scope:/beta", + b"scope:/gamma", + ], + ) + + # List scopes under root + scopes = await valkey_storage._alist_scopes("/") + + # Verify scopes are sorted + assert scopes == ["/alpha", "/beta", "/gamma", "/zebra"] + + @pytest.mark.asyncio + async def test_list_scopes_empty_parent( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_scopes returns empty list when parent has no children.""" + # Mock scan to return scope keys that don't match parent + mock_glide_client.scan.return_value = ( + b"0", + [ + b"scope:/agent", + b"scope:/crew", + ], + ) + + # List scopes under /other (no children) + scopes = await valkey_storage._alist_scopes("/other") + + # Verify empty list + assert len(scopes) == 0 + + @pytest.mark.asyncio + async def test_list_scopes_with_deep_hierarchy( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_scopes with deep scope hierarchy.""" + # Mock scan to return scope keys with deep nesting + mock_glide_client.scan.return_value = ( + b"0", + [ + b"scope:/a", + b"scope:/a/b", + 
b"scope:/a/b/c", + b"scope:/a/b/c/d", + b"scope:/a/x", + ], + ) + + # List scopes under /a/b + scopes = await valkey_storage._alist_scopes("/a/b") + + # Verify only immediate children are returned + assert len(scopes) == 1 + assert "/a/b/c" in scopes + assert "/a/b/c/d" not in scopes # Grandchild not included + + def test_list_scopes_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync list_scopes wrapper calls async implementation.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/agent", b"scope:/crew"], + ) + + # Call sync wrapper + scopes = valkey_storage.list_scopes("/") + + # Verify it works + assert len(scopes) == 2 + assert "/agent" in scopes + assert "/crew" in scopes + + +class TestValkeyStorageListCategories: + """Tests for list_categories operation.""" + + @pytest.mark.asyncio + async def test_list_categories_global_returns_accurate_counts( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_categories returns accurate global category counts.""" + # Mock scan to return category keys + mock_glide_client.scan.return_value = ( + b"0", + [b"category:planning", b"category:execution", b"category:review"], + ) + + # Mock scard to return category counts + def mock_scard(key: str) -> int: + if key == "category:planning": + return 5 + elif key == "category:execution": + return 3 + elif key == "category:review": + return 2 + return 0 + + mock_glide_client.scard.side_effect = mock_scard + + # List categories globally + categories = await valkey_storage._alist_categories(scope_prefix=None) + + # Verify category counts + assert categories == {"planning": 5, "execution": 3, "review": 2} + + @pytest.mark.asyncio + async def test_list_categories_with_scope_returns_accurate_counts( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_categories with scope filtering returns 
accurate counts.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2", b"record-3"], + ] + + # Mock hgetall to return record data with categories + def mock_hgetall(key: str) -> dict[bytes, bytes]: + if key == "record:record-1": + return { + b"id": b"record-1", + b"content": b"Content 1", + b"scope": b"/test", + b"categories": b'["planning", "execution"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-01T10:00:00", + b"last_accessed": b"2024-01-01T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-2": + return { + b"id": b"record-2", + b"content": b"Content 2", + b"scope": b"/test", + b"categories": b'["planning"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-02T10:00:00", + b"last_accessed": b"2024-01-02T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + elif key == "record:record-3": + return { + b"id": b"record-3", + b"content": b"Content 3", + b"scope": b"/test", + b"categories": b'["execution"]', + b"metadata": b"{}", + b"importance": b"0.5", + b"created_at": b"2024-01-03T10:00:00", + b"last_accessed": b"2024-01-03T10:00:00", + b"source": b"", + b"private": b"false", + b"embedding": b"", + } + return {} + + mock_glide_client.hgetall.side_effect = mock_hgetall + + # List categories in scope + categories = await valkey_storage._alist_categories(scope_prefix="/test") + + # Verify category counts + assert categories == {"planning": 2, "execution": 2} + + @pytest.mark.asyncio + async def test_list_categories_empty_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_categories returns empty dict for empty scope.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + 
[b"scope:/empty"], + ) + + # Mock ZRANGE to return no record IDs + mock_glide_client.zrange.side_effect = [ + [], # No records + ] + + # List categories in empty scope + categories = await valkey_storage._alist_categories(scope_prefix="/empty") + + # Verify empty dict + assert categories == {} + + @pytest.mark.asyncio + async def test_list_categories_global_empty( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test list_categories returns empty dict when no categories exist.""" + # Mock scan to return no category keys + mock_glide_client.scan.return_value = (b"0", []) + + # List categories globally + categories = await valkey_storage._alist_categories(scope_prefix=None) + + # Verify empty dict + assert categories == {} + + def test_list_categories_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync list_categories wrapper calls async implementation.""" + # Mock scan to return category keys + mock_glide_client.scan.return_value = ( + b"0", + [b"category:planning"], + ) + + # Mock scard to return category count + mock_glide_client.scard.return_value = 5 + + # Call sync wrapper + categories = valkey_storage.list_categories(scope_prefix=None) + + # Verify it works + assert categories == {"planning": 5} + + +class TestValkeyStorageCount: + """Tests for count operation.""" + + @pytest.mark.asyncio + async def test_count_all_records_returns_correct_total( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test count returns correct total count across all scopes.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test1", b"scope:/test2"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], # /test1 + [b"record-3", b"record-4", b"record-5"], # /test2 + ] + + # Count all records + count = await 
valkey_storage._acount(scope_prefix=None) + + # Verify total count + assert count == 5 + + @pytest.mark.asyncio + async def test_count_with_scope_returns_correct_total( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test count with scope filtering returns correct total.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test", b"scope:/test/sub"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], # /test + [b"record-3"], # /test/sub + ] + + # Count records in scope + count = await valkey_storage._acount(scope_prefix="/test") + + # Verify count includes subscopes + assert count == 3 + + @pytest.mark.asyncio + async def test_count_empty_scope( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test count returns 0 for empty scope.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/empty"], + ) + + # Mock ZRANGE to return no record IDs + mock_glide_client.zrange.side_effect = [ + [], # No records + ] + + # Count records in empty scope + count = await valkey_storage._acount(scope_prefix="/empty") + + # Verify count is 0 + assert count == 0 + + @pytest.mark.asyncio + async def test_count_deduplicates_records( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test count deduplicates records that appear in multiple scopes.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test1", b"scope:/test2"], + ) + + # Mock ZRANGE to return overlapping record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], # /test1 + [b"record-2", b"record-3"], # /test2 (record-2 appears in both) + ] + + # Count all records + count = await valkey_storage._acount(scope_prefix=None) + + # Verify count deduplicates (3 unique records, not 4) + assert count 
== 3 + + def test_count_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync count wrapper calls async implementation.""" + # Mock scan to return scope keys + mock_glide_client.scan.return_value = ( + b"0", + [b"scope:/test"], + ) + + # Mock ZRANGE to return record IDs + mock_glide_client.zrange.side_effect = [ + [b"record-1", b"record-2"], + ] + + # Call sync wrapper + count = valkey_storage.count(scope_prefix="/test") + + # Verify it works + assert count == 2 + + +class TestValkeyStorageReset: + """Tests for reset operation.""" + + @pytest.mark.asyncio + async def test_reset_clears_all_records( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test reset delegates to adelete to clear all records.""" + # Mock adelete to track if it was called + original_adelete = valkey_storage.adelete + adelete_called = False + adelete_args = None + + async def mock_adelete(*args: object, **kwargs: object) -> int: + nonlocal adelete_called, adelete_args + adelete_called = True + adelete_args = kwargs + return 0 + + valkey_storage.adelete = mock_adelete # type: ignore[method-assign] + + # Reset all records + await valkey_storage._areset(scope_prefix=None) + + # Verify adelete was called with correct arguments + assert adelete_called + assert adelete_args == {"scope_prefix": None} + + # Restore original method + valkey_storage.adelete = original_adelete # type: ignore[method-assign] + + @pytest.mark.asyncio + async def test_reset_with_scope_clears_scope_records( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test reset with scope delegates to adelete with scope_prefix.""" + # Mock adelete to track if it was called + original_adelete = valkey_storage.adelete + adelete_called = False + adelete_args = None + + async def mock_adelete(*args: object, **kwargs: object) -> int: + nonlocal adelete_called, adelete_args + adelete_called = True + adelete_args 
= kwargs + return 0 + + valkey_storage.adelete = mock_adelete # type: ignore[method-assign] + + # Reset records in scope + await valkey_storage._areset(scope_prefix="/test") + + # Verify adelete was called with correct arguments + assert adelete_called + assert adelete_args == {"scope_prefix": "/test"} + + # Restore original method + valkey_storage.adelete = original_adelete # type: ignore[method-assign] + + def test_reset_sync_wrapper( + self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync reset wrapper calls async implementation.""" + # Mock adelete to track if it was called + original_adelete = valkey_storage.adelete + adelete_called = False + + async def mock_adelete(*args: object, **kwargs: object) -> int: + nonlocal adelete_called + adelete_called = True + return 0 + + valkey_storage.adelete = mock_adelete # type: ignore[method-assign] + + # Call sync wrapper + valkey_storage.reset(scope_prefix="/test") + + # Verify adelete was called + assert adelete_called + + # Restore original method + valkey_storage.adelete = original_adelete # type: ignore[method-assign] diff --git a/lib/crewai/tests/memory/storage/test_valkey_storage_search.py b/lib/crewai/tests/memory/storage/test_valkey_storage_search.py new file mode 100644 index 0000000000..73ee989206 --- /dev/null +++ b/lib/crewai/tests/memory/storage/test_valkey_storage_search.py @@ -0,0 +1,998 @@ +"""Tests for ValkeyStorage vector search operation.""" + +from __future__ import annotations + +import json +from datetime import datetime +from unittest.mock import AsyncMock, patch +from uuid import uuid4 + +import pytest + +from crewai.memory.storage.valkey_storage import ValkeyStorage +from crewai.memory.types import MemoryRecord + + +@pytest.fixture +def mock_glide_client() -> AsyncMock: + """Create a mock GlideClient for testing.""" + client = AsyncMock() + client.hset = AsyncMock(return_value=1) + client.zrange = AsyncMock(return_value=[]) + client.zadd = AsyncMock() 
+ client.sadd = AsyncMock() + client.hgetall = AsyncMock(return_value={}) + client.close = AsyncMock() + return client + + +@pytest.fixture +def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage: + """Create a ValkeyStorage instance with mocked client.""" + storage = ValkeyStorage(host="localhost", port=6379, db=0) + + # Mock the client creation to return our mock + async def mock_create_client() -> AsyncMock: + storage._client = mock_glide_client + return mock_glide_client + + storage._get_client = mock_create_client # type: ignore[method-assign] + return storage + + +def create_mock_ft_search_response( + records: list[tuple[MemoryRecord, float]] +) -> list[int | dict[str, dict[str, str]]]: + """Create a mock FT.SEARCH response in native dict format. + + Args: + records: List of (MemoryRecord, score) tuples to include in response. + + Returns: + Mock FT.SEARCH response in the native format: + [total_count, {doc_key: {field: value, ...}, ...}] + """ + if not records: + return [0] + + docs: dict[str, dict[str, str]] = {} + + for record, score in records: + doc_key = f"record:{record.id}" + + # Build field dict + fields: dict[str, str] = {} + fields["id"] = record.id + fields["content"] = record.content + fields["scope"] = record.scope + fields["categories"] = json.dumps(record.categories) + fields["metadata"] = json.dumps(record.metadata) + fields["importance"] = str(record.importance) + fields["created_at"] = record.created_at.isoformat() + fields["last_accessed"] = record.last_accessed.isoformat() + fields["source"] = record.source or "" + fields["private"] = "true" if record.private else "false" + + # Add score (Valkey Search returns cosine distance, not similarity) + # Convert similarity to distance: distance = 2 * (1 - similarity) + distance = 2.0 * (1.0 - score) + fields["score"] = str(distance) + + # Add embedding if present + if record.embedding: + fields["embedding"] = json.dumps(record.embedding) + + docs[doc_key] = fields + + return 
[len(records), docs] + + +class TestValkeyStorageVectorSearch: + """Tests for ValkeyStorage vector search operation.""" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_no_filters_returns_all_records( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with no filters returns all records.""" + # Create test records + record1 = MemoryRecord( + id="record-1", + content="First test record", + scope="/test", + categories=["cat1"], + metadata={"key": "value1"}, + importance=0.8, + created_at=datetime(2024, 1, 1, 10, 0, 0), + last_accessed=datetime(2024, 1, 1, 11, 0, 0), + embedding=[0.1, 0.2, 0.3, 0.4], + ) + record2 = MemoryRecord( + id="record-2", + content="Second test record", + scope="/test", + categories=["cat2"], + metadata={"key": "value2"}, + importance=0.6, + created_at=datetime(2024, 1, 2, 10, 0, 0), + last_accessed=datetime(2024, 1, 2, 11, 0, 0), + embedding=[0.2, 0.3, 0.4, 0.5], + ) + + # Mock FT.INFO to simulate index exists + mock_ft_list.return_value = [b"memory_index"] + # Mock FT.SEARCH to return both records + mock_ft_search.return_value = create_mock_ft_search_response([ + (record1, 0.95), + (record2, 0.85), + ]) + + # Perform search with no filters + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify ft.search was called + mock_ft_search.assert_called_once() + + # Verify query contains only KNN part (no filters) + call_args = mock_ft_search.call_args + query = call_args[0][2] # 3rd positional arg: query string + assert "*=>[KNN 10 @embedding $BLOB AS score]" in query + assert "@scope" not in query + assert "@categories" not in query + + # Verify results + assert len(results) == 2 + assert results[0][0].id == "record-1" + assert results[0][1] == 0.95 + assert 
results[1][0].id == "record-2" + assert results[1][1] == 0.85 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_scope_filter_only( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with scope filter only.""" + record1 = MemoryRecord( + id="record-1", + content="Record in scope", + scope="/agent/task", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + scope_prefix="/agent", + limit=10 + ) + + # Verify query contains scope filter + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "(@scope:{/agent*})=>[KNN 10 @embedding $BLOB AS score]" in query + + # Verify results + assert len(results) == 1 + assert results[0][0].id == "record-1" + assert results[0][0].scope == "/agent/task" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_category_filter_only( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with category filter only.""" + record1 = MemoryRecord( + id="record-1", + content="Record with planning category", + scope="/test", + categories=["planning"], + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.88)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + categories=["planning", "execution"], + limit=10 + ) + + # Verify 
query contains category filter with OR logic + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "(@categories:{planning|execution})=>[KNN 10 @embedding $BLOB AS score]" in query + + # Verify results + assert len(results) == 1 + assert results[0][0].id == "record-1" + assert "planning" in results[0][0].categories + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_metadata_filter_only( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with metadata filter only.""" + record1 = MemoryRecord( + id="record-1", + content="Record with metadata", + scope="/test", + metadata={"agent_id": "agent-1", "priority": "high"}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.92)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + metadata_filter={"agent_id": "agent-1", "priority": "high"}, + limit=10 + ) + + # Verify query contains metadata filters (AND logic) + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query + assert "@priority:{high}" in query + assert "=>[KNN 10 @embedding $BLOB AS score]" in query + + # Verify results + assert len(results) == 1 + assert results[0][0].id == "record-1" + assert results[0][0].metadata["agent_id"] == "agent-1" + assert results[0][0].metadata["priority"] == "high" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_combined_filters( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, 
mock_glide_client: AsyncMock + ) -> None: + """Test search with combined filters (scope + categories + metadata).""" + record1 = MemoryRecord( + id="record-1", + content="Record matching all filters", + scope="/agent/task", + categories=["planning"], + metadata={"agent_id": "agent-1"}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.93)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + scope_prefix="/agent", + categories=["planning"], + metadata_filter={"agent_id": "agent-1"}, + limit=10 + ) + + # Verify query contains all filters + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@scope:{/agent*}" in query + assert "@categories:{planning}" in query + assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query + assert "=>[KNN 10 @embedding $BLOB AS score]" in query + + # Verify results + assert len(results) == 1 + assert results[0][0].id == "record-1" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_respects_limit_parameter( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search respects limit parameter.""" + records = [ + ( + MemoryRecord( + id=f"record-{i}", + content=f"Record {i}", + scope="/test", + embedding=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i], + ), + 0.9 - (i * 0.1) + ) + for i in range(1, 6) + ] + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response(records[:3]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=3) + + # Verify KNN limit in query + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "=>[KNN 
3 @embedding $BLOB AS score]" in query + + # Verify results respect limit + assert len(results) == 3 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_respects_min_score_parameter( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search respects min_score parameter.""" + record1 = MemoryRecord( + id="record-1", + content="High score record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + record2 = MemoryRecord( + id="record-2", + content="Medium score record", + scope="/test", + embedding=[0.2, 0.3, 0.4, 0.5], + ) + record3 = MemoryRecord( + id="record-3", + content="Low score record", + scope="/test", + embedding=[0.3, 0.4, 0.5, 0.6], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([ + (record1, 0.95), + (record2, 0.75), + (record3, 0.55), + ]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + limit=10, + min_score=0.7 + ) + + # Verify only records with score >= 0.7 are returned + assert len(results) == 2 + assert results[0][0].id == "record-1" + assert results[0][1] == 0.95 + assert results[1][0].id == "record-2" + assert results[1][1] == 0.75 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_returns_results_ordered_by_descending_score( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search returns results ordered by descending score.""" + record1 = MemoryRecord( + id="record-1", + content="Medium score", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + record2 = MemoryRecord( + id="record-2", + 
content="Highest score", + scope="/test", + embedding=[0.2, 0.3, 0.4, 0.5], + ) + record3 = MemoryRecord( + id="record-3", + content="Lowest score", + scope="/test", + embedding=[0.3, 0.4, 0.5, 0.6], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([ + (record1, 0.75), + (record2, 0.95), + (record3, 0.55), + ]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify results are ordered by descending score + assert len(results) == 3 + assert results[0][0].id == "record-2" + assert results[0][1] == 0.95 + assert results[1][0].id == "record-1" + assert results[1][1] == 0.75 + assert results[2][0].id == "record-3" + assert results[2][1] == 0.55 + + # Verify scores are in descending order + for i in range(len(results) - 1): + assert results[i][1] >= results[i + 1][1] + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_empty_results( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with no matching results.""" + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = [0] # Total count = 0 + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify empty results + assert len(results) == 0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_special_characters_in_scope( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with special characters in scope prefix.""" + record1 = MemoryRecord( + id="record-1", + content="Record 
with special scope", + scope="/agent:task-1", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + scope_prefix="/agent:task", + limit=10 + ) + + # Verify query contains escaped scope + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@scope:{/agent\\:task*}" in query or "@scope:{/agent:task*}" in query + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_special_characters_in_categories( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with special characters in categories.""" + record1 = MemoryRecord( + id="record-1", + content="Record with special category", + scope="/test", + categories=["plan:execute"], + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + categories=["plan:execute"], + limit=10 + ) + + # Verify query contains escaped category + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@categories:{plan\\:execute}" in query or "@categories:{plan:execute}" in query + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_numeric_metadata_values( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with numeric metadata values.""" + record1 = 
MemoryRecord( + id="record-1", + content="Record with numeric metadata", + scope="/test", + metadata={"count": 42, "score": 3.14}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + metadata_filter={"count": 42, "score": 3.14}, + limit=10 + ) + + # Verify query contains string-converted metadata values + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@count:{42}" in query + assert "@score:{3" in query and "14}" in query + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_embedding_blob_parameter( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search passes embedding as BLOB parameter.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify ft.search was called with search options containing BLOB param + call_args = mock_ft_search.call_args + # The 4th positional arg is the FtSearchOptions + search_options = call_args[0][3] + # The options object should have params with BLOB + assert search_options is not None + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_results_sorted_by_score( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, 
mock_glide_client: AsyncMock + ) -> None: + """Test search results are sorted by score (descending) automatically.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify ft.search was called (results are auto-sorted by vector search) + mock_ft_search.assert_called_once() + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_return_fields( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search includes RETURN clause with all record fields.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify ft.search was called with search options containing return fields + call_args = mock_ft_search.call_args + search_options = call_args[0][3] + assert search_options is not None + # The FtSearchOptions should have return_fields set + assert search_options.return_fields is not None + assert len(search_options.return_fields) == 11 # All fields including score + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw") + @patch("crewai.memory.storage.valkey_storage.ft.create") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_handles_valkey_search_not_available( + 
self, mock_ft_list: AsyncMock, mock_ft_create: AsyncMock, + mock_vector_attrs: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search raises error when Valkey Search module is not available.""" + # Mock FT.INFO to fail (index doesn't exist) + mock_ft_list.return_value = [] + # Mock FT.CREATE to fail (Search module not available) + mock_ft_create.side_effect = Exception("ERR unknown command 'ft.create'") + + query_embedding = [0.1, 0.2, 0.3, 0.4] + + with pytest.raises(RuntimeError, match="Valkey Search module is not available"): + await valkey_storage.asearch(query_embedding, limit=10) + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_handles_ft_search_error( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search handles FT.SEARCH errors gracefully.""" + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.side_effect = Exception("ERR unknown command 'FT.SEARCH'") + + query_embedding = [0.1, 0.2, 0.3, 0.4] + + with pytest.raises(RuntimeError, match="Valkey Search module is not available"): + await valkey_storage.asearch(query_embedding, limit=10) + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_handles_malformed_ft_search_response( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search handles malformed FT.SEARCH response gracefully.""" + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = None # Malformed response + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify empty results are returned 
(graceful handling) + assert len(results) == 0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_handles_missing_score_field( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search handles missing score field in results.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Create mock response without score field (dict format) + docs = { + f"record:{record1.id}": { + "id": record1.id, + "content": record1.content, + "scope": record1.scope, + "categories": str(record1.categories), + "metadata": str(record1.metadata), + "importance": str(record1.importance), + "created_at": record1.created_at.isoformat(), + "last_accessed": record1.last_accessed.isoformat(), + "source": record1.source or "", + "private": "false", + # No score field + } + } + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = [1, docs] + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify record is returned with default score of 0.0 + assert len(results) == 1 + assert results[0][0].id == "record-1" + assert results[0][1] == 0.0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_filters_out_records_with_deserialization_errors( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search filters out records that fail deserialization.""" + valid_record = MemoryRecord( + id="valid-record", + content="Valid record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Create mock response with one valid and one 
invalid record (dict format) + docs = { + f"record:{valid_record.id}": { + "id": valid_record.id, + "content": valid_record.content, + "scope": valid_record.scope, + "categories": str(valid_record.categories), + "metadata": str(valid_record.metadata), + "importance": str(valid_record.importance), + "created_at": valid_record.created_at.isoformat(), + "last_accessed": valid_record.last_accessed.isoformat(), + "source": valid_record.source or "", + "private": "false", + "score": "0.1", + }, + "record:invalid-record": { + "id": "invalid-record", + # Missing content, scope, and other required fields + "score": "0.2", + }, + } + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = [2, docs] + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify only valid record is returned + assert len(results) == 1 + assert results[0][0].id == "valid-record" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_converts_cosine_distance_to_similarity( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search converts Valkey Search cosine distance to similarity score.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Create mock response with distance score (dict format) + docs = { + f"record:{record1.id}": { + "id": record1.id, + "content": record1.content, + "scope": record1.scope, + "categories": str(record1.categories), + "metadata": str(record1.metadata), + "importance": str(record1.importance), + "created_at": record1.created_at.isoformat(), + "last_accessed": record1.last_accessed.isoformat(), + "source": record1.source or "", + "private": "false", + "score": "0.1", # Distance = 0.1 + } + } + + 
mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = [1, docs] + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=10) + + # Verify similarity score is correctly converted + assert len(results) == 1 + assert results[0][0].id == "record-1" + # Distance 0.1 -> Similarity = 1 - (0.1 / 2) = 0.95 + assert abs(results[0][1] - 0.95) < 0.01 + + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + def test_search_sync_wrapper( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test that sync search wrapper calls async implementation.""" + record1 = MemoryRecord( + id="record-1", + content="Test record", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = valkey_storage.search(query_embedding, limit=10) + + # Verify ft.search was called + assert mock_ft_search.call_count >= 1 + + # Verify results + assert len(results) == 1 + assert results[0][0].id == "record-1" + assert results[0][1] == 0.9 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_multiple_categories_uses_or_logic( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with multiple categories uses OR logic.""" + record1 = MemoryRecord( + id="record-1", + content="Record with one matching category", + scope="/test", + categories=["planning"], + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = 
create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + categories=["planning", "execution", "review"], + limit=10 + ) + + # Verify query contains OR logic for categories + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@categories:{planning|execution|review}" in query + + # Verify record with only one matching category is returned + assert len(results) == 1 + assert results[0][0].id == "record-1" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_multiple_metadata_filters_uses_and_logic( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with multiple metadata filters uses AND logic.""" + record1 = MemoryRecord( + id="record-1", + content="Record matching all metadata", + scope="/test", + metadata={"agent_id": "agent-1", "priority": "high", "status": "active"}, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + metadata_filter={"agent_id": "agent-1", "priority": "high", "status": "active"}, + limit=10 + ) + + # Verify query contains AND logic for metadata + call_args = mock_ft_search.call_args + query = call_args[0][2] + assert "@agent_id:" in query + assert "@priority:" in query + assert "@status:" in query + + # Verify record matching all metadata is returned + assert len(results) == 1 + assert results[0][0].id == "record-1" + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def 
test_search_with_zero_limit_returns_empty( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with limit=0 returns empty results.""" + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = [0] + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch(query_embedding, limit=0) + + # Verify empty results + assert len(results) == 0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_min_score_one_filters_all( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with min_score=1.0 filters out all non-perfect matches.""" + record1 = MemoryRecord( + id="record-1", + content="High score but not perfect", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.99)]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + limit=10, + min_score=1.0 + ) + + # Verify all results are filtered out + assert len(results) == 0 + + @pytest.mark.asyncio + @patch("crewai.memory.storage.valkey_storage.ft.search") + @patch("crewai.memory.storage.valkey_storage.ft.list") + async def test_search_with_min_score_zero_returns_all( + self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock, + valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock + ) -> None: + """Test search with min_score=0.0 returns all results.""" + record1 = MemoryRecord( + id="record-1", + content="High score", + scope="/test", + embedding=[0.1, 0.2, 0.3, 0.4], + ) + record2 = MemoryRecord( + id="record-2", + content="Low score", + scope="/test", + embedding=[0.2, 0.3, 0.4, 0.5], + 
) + + mock_ft_list.return_value = [b"memory_index"] + mock_ft_search.return_value = create_mock_ft_search_response([ + (record1, 0.95), + (record2, 0.05), + ]) + + query_embedding = [0.1, 0.2, 0.3, 0.4] + results = await valkey_storage.asearch( + query_embedding, + limit=10, + min_score=0.0 + ) + + # Verify all results are returned + assert len(results) == 2 + assert results[0][0].id == "record-1" + assert results[1][0].id == "record-2" \ No newline at end of file diff --git a/lib/crewai/tests/memory/test_embedding_safety.py b/lib/crewai/tests/memory/test_embedding_safety.py new file mode 100644 index 0000000000..ac5288e40c --- /dev/null +++ b/lib/crewai/tests/memory/test_embedding_safety.py @@ -0,0 +1,115 @@ +"""Tests for embedding safety: bytes→float validators and async-safe embed_texts.""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from crewai.memory.types import MemoryRecord, embed_text, embed_texts + + +class TestMemoryRecordEmbeddingValidator: + """Tests for MemoryRecord.validate_embedding (bytes→list[float]).""" + + def test_none_embedding_stays_none(self) -> None: + r = MemoryRecord(content="test", embedding=None) + assert r.embedding is None + + def test_list_of_floats_passes_through(self) -> None: + r = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3]) + assert r.embedding == [0.1, 0.2, 0.3] + + def test_bytes_converted_to_list_float(self) -> None: + arr = np.array([0.1, 0.2, 0.3], dtype=np.float32) + raw_bytes = arr.tobytes() + r = MemoryRecord(content="test", embedding=raw_bytes) + assert r.embedding is not None + assert len(r.embedding) == 3 + assert all(isinstance(x, float) for x in r.embedding) + np.testing.assert_allclose(r.embedding, [0.1, 0.2, 0.3], atol=1e-6) + + def test_empty_bytes_becomes_none(self) -> None: + r = MemoryRecord(content="test", embedding=b"") + assert r.embedding is None + + def 
test_list_of_ints_converted_to_floats(self) -> None: + r = MemoryRecord(content="test", embedding=[1, 2, 3]) + assert r.embedding == [1.0, 2.0, 3.0] + assert all(isinstance(x, float) for x in r.embedding) + + def test_numpy_array_converted_to_list(self) -> None: + arr = np.array([0.5, 0.6], dtype=np.float32) + r = MemoryRecord(content="test", embedding=arr) + assert r.embedding is not None + assert isinstance(r.embedding, list) + assert len(r.embedding) == 2 + + +class TestEmbedTextsAsyncSafety: + """Tests for embed_texts running safely in async context.""" + + def test_embed_texts_sync_context(self) -> None: + """embed_texts works in a normal sync context.""" + embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + result = embed_texts(embedder, ["hello", "world"]) + assert len(result) == 2 + assert result[0] == [0.1, 0.2] + embedder.assert_called_once() + + def test_embed_texts_empty_input(self) -> None: + embedder = MagicMock() + assert embed_texts(embedder, []) == [] + embedder.assert_not_called() + + def test_embed_texts_all_empty_strings(self) -> None: + embedder = MagicMock() + result = embed_texts(embedder, ["", " ", ""]) + assert result == [[], [], []] + embedder.assert_not_called() + + def test_embed_texts_skips_empty_preserves_positions(self) -> None: + embedder = MagicMock(return_value=[[0.1, 0.2]]) + result = embed_texts(embedder, ["", "hello", ""]) + assert result == [[], [0.1, 0.2], []] + embedder.assert_called_once_with(["hello"]) + + def test_embed_texts_in_async_context(self) -> None: + """embed_texts uses thread pool when called from async context.""" + embedder = MagicMock(return_value=[[0.1, 0.2]]) + + async def run() -> list[list[float]]: + return embed_texts(embedder, ["hello"]) + + result = asyncio.run(run()) + assert result == [[0.1, 0.2]] + embedder.assert_called_once() + + +class TestEmbedText: + """Tests for embed_text (single text).""" + + def test_empty_string_returns_empty(self) -> None: + embedder = MagicMock() + assert 
embed_text(embedder, "") == [] + embedder.assert_not_called() + + def test_whitespace_only_returns_empty(self) -> None: + embedder = MagicMock() + assert embed_text(embedder, " ") == [] + embedder.assert_not_called() + + def test_normal_text_returns_embedding(self) -> None: + embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]]) + result = embed_text(embedder, "hello") + assert result == [0.1, 0.2, 0.3] + + def test_numpy_array_result_converted(self) -> None: + arr = np.array([0.1, 0.2], dtype=np.float32) + embedder = MagicMock(return_value=[arr]) + result = embed_text(embedder, "hello") + assert isinstance(result, list) + assert len(result) == 2 diff --git a/lib/crewai/tests/utilities/test_cache_config.py b/lib/crewai/tests/utilities/test_cache_config.py new file mode 100644 index 0000000000..de28b52f32 --- /dev/null +++ b/lib/crewai/tests/utilities/test_cache_config.py @@ -0,0 +1,117 @@ +"""Tests for shared cache configuration helpers.""" + +from __future__ import annotations + +import os +from unittest.mock import patch + +import pytest + +from crewai.utilities.cache_config import ( + get_aiocache_config, + parse_cache_url, + use_valkey_cache, +) + + +class TestParseCacheUrl: + """Tests for parse_cache_url().""" + + def test_returns_none_when_no_env_vars(self) -> None: + with patch.dict(os.environ, {}, clear=True): + assert parse_cache_url() is None + + def test_parses_valkey_url(self) -> None: + with patch.dict( + os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True + ): + result = parse_cache_url() + assert result is not None + assert result["host"] == "myhost" + assert result["port"] == 6380 + assert result["db"] == 2 + assert result["password"] is None + + def test_parses_redis_url(self) -> None: + with patch.dict( + os.environ, {"REDIS_URL": "redis://localhost:6379/0"}, clear=True + ): + result = parse_cache_url() + assert result is not None + assert result["host"] == "localhost" + assert result["port"] == 6379 + assert result["db"] == 0 + + def 
test_valkey_url_takes_priority_over_redis_url(self) -> None: + with patch.dict( + os.environ, + { + "VALKEY_URL": "redis://valkey-host:6380/1", + "REDIS_URL": "redis://redis-host:6379/0", + }, + clear=True, + ): + result = parse_cache_url() + assert result is not None + assert result["host"] == "valkey-host" + assert result["port"] == 6380 + + def test_parses_password(self) -> None: + with patch.dict( + os.environ, + {"VALKEY_URL": "redis://:s3cret@myhost:6379/0"}, + clear=True, + ): + result = parse_cache_url() + assert result is not None + assert result["password"] == "s3cret" + + def test_defaults_for_minimal_url(self) -> None: + with patch.dict( + os.environ, {"VALKEY_URL": "redis://myhost"}, clear=True + ): + result = parse_cache_url() + assert result is not None + assert result["host"] == "myhost" + assert result["port"] == 6379 + assert result["db"] == 0 + assert result["password"] is None + + +class TestGetAiocacheConfig: + """Tests for get_aiocache_config().""" + + def test_returns_memory_cache_when_no_url(self) -> None: + with patch.dict(os.environ, {}, clear=True): + config = get_aiocache_config() + assert config["default"]["cache"] == "aiocache.SimpleMemoryCache" + + def test_returns_redis_cache_when_url_set(self) -> None: + with patch.dict( + os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True + ): + config = get_aiocache_config() + assert config["default"]["cache"] == "aiocache.RedisCache" + assert config["default"]["endpoint"] == "myhost" + assert config["default"]["port"] == 6380 + assert config["default"]["db"] == 2 + + +class TestUseValkeyCache: + """Tests for use_valkey_cache().""" + + def test_returns_false_when_not_set(self) -> None: + with patch.dict(os.environ, {}, clear=True): + assert use_valkey_cache() is False + + def test_returns_true_when_set(self) -> None: + with patch.dict( + os.environ, {"VALKEY_URL": "redis://localhost:6379"}, clear=True + ): + assert use_valkey_cache() is True + + def 
test_returns_false_when_only_redis_url_set(self) -> None: + with patch.dict( + os.environ, {"REDIS_URL": "redis://localhost:6379"}, clear=True + ): + assert use_valkey_cache() is False diff --git a/pyproject.toml b/pyproject.toml index b15b4ac14c..382e085cdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,6 +196,8 @@ override-dependencies = [ "python-multipart>=0.0.26,<1", "langsmith>=0.7.31,<0.8", "authlib>=1.6.11", + # scrapegraph-py 2.x removed Client class; pin until upstream fixes type ignores + "scrapegraph-py>=1.46.0,<2", ] [tool.uv.workspace] diff --git a/uv.lock b/uv.lock index 5101cea490..924b75cafd 100644 --- a/uv.lock +++ b/uv.lock @@ -13,7 +13,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-27T16:00:00Z" +exclude-newer = "2026-04-28T04:00:00Z" [manifest] members = [ @@ -34,6 +34,7 @@ overrides = [ { name = "pypdf", specifier = ">=6.10.2,<7" }, { name = "python-multipart", specifier = ">=0.0.26,<1" }, { name = "rich", specifier = ">=13.7.1" }, + { name = "scrapegraph-py", specifier = ">=1.46.0,<2" }, { name = "transformers", marker = "python_full_version >= '3.10'", specifier = ">=5.4.0" }, { name = "urllib3", specifier = ">=2.6.3" }, { name = "uv", specifier = ">=0.11.6,<1" }, @@ -1360,6 +1361,9 @@ qdrant-edge = [ tools = [ { name = "crewai-tools" }, ] +valkey = [ + { name = "valkey-glide" }, +] voyageai = [ { name = "voyageai" }, ] @@ -1421,9 +1425,10 @@ requires-dist = [ { name = "tomli", specifier = "~=2.0.2" }, { name = "tomli-w", specifier = "~=1.1.0" }, { name = "uv", specifier = "~=0.11.6" }, + { name = "valkey-glide", marker = "extra == 'valkey'", specifier = ">=1.3.0" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" }, ] -provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"] +provides-extras = ["a2a", 
"anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "valkey", "voyageai", "watson"] [[package]] name = "crewai-devtools" @@ -9392,6 +9397,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" }, ] +[[package]] +name = "valkey-glide" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "protobuf" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/35/fb0401c4bc7be748d937e95213786d21d9e56767b3ad816db5bad6f92c01/valkey_glide-2.0.1.tar.gz", hash = "sha256:4f9c62a88aedffd725cced7d28a9488b27e3f675d1a5294b4962624e97d346c4", size = 1026255, upload-time = "2025-06-20T01:08:15.861Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/a3/bf5ff3841538d0bb337371e073dc2c0e93f748f7f8b10a44806f36ab5fa1/valkey_glide-2.0.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:b3307934b76557b18ac559f327592cc09fc895fc653ba46010dd6d70fb6239dc", size = 5074638, upload-time = "2025-06-20T01:07:30.16Z" }, + { url = "https://files.pythonhosted.org/packages/0f/c4/20b66dced96bdca81aa294b39bc03018ed22628c52076752e8d1d3540a7d/valkey_glide-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6b83d34e2e723e97c41682479b0dce5882069066e808316292b363855992b449", size = 4750261, upload-time = "2025-06-20T01:07:32.452Z" }, + { url = "https://files.pythonhosted.org/packages/53/58/6440e66bde8963d86bc3c44d88f993059f2a9d7ebdb3256a695d035cff50/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:1baaf14d09d464ae645be5bdb5dc6b8a38b7eacf22f9dcb2907200c74fbdcdd3", size = 4767755, upload-time = "2025-06-20T01:07:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/3b/69/dd5c350ce4d2cadde0d83beb601f05e1e62622895f268135e252e8bfc307/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4427e7b4d54c9de289a35032c19d5956f94376f5d4335206c5ac4524cbd1c64a", size = 5094507, upload-time = "2025-06-20T01:07:35.349Z" }, + { url = "https://files.pythonhosted.org/packages/b5/dd/0dd6614e09123a5bd7273bf1159c958d1ea65e7decc2190b225d212e0cb9/valkey_glide-2.0.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:6379582d6fbd817697fb119274e37d397db450103cd15d4bd71e555e6d88fb6b", size = 5072939, upload-time = "2025-06-20T01:07:36.948Z" }, + { url = "https://files.pythonhosted.org/packages/c6/04/986188e407231a5f0bfaf31f31b68e3605ab66f4f4c656adfbb0345669d9/valkey_glide-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f1c0fe003026d8ae172369e0eb2337cbff16f41d4c085332487d6ca2e5282e6", size = 4750491, upload-time = "2025-06-20T01:07:38.659Z" }, + { url = "https://files.pythonhosted.org/packages/ac/fb/2f5cec71ae51c464502a892b6825426cd74a2c325827981726e557926c94/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82c5f33598e50bcfec6fc924864931f3c6e30cd327a9c9562e1c7ac4e17e79fd", size = 4767597, upload-time = "2025-06-20T01:07:40.091Z" }, + { url = "https://files.pythonhosted.org/packages/3a/31/851a1a734fe5da5d520106fcfd824e4da09c3be8a0a2123bb4b1980db1ea/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79039a9dc23bb074680f171c12b36b3322357a0af85125534993e81a619dce21", size = 5094383, upload-time = "2025-06-20T01:07:41.329Z" }, + { url = "https://files.pythonhosted.org/packages/fc/6d/1e7b432cbc02fe63e7496b984b7fc830fb7de388c877b237e0579a6300fc/valkey_glide-2.0.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = 
"sha256:f55ec8968b0fde364a5b3399be34b89dcb9068994b5cd384e20db0773ad12723", size = 5075024, upload-time = "2025-06-20T01:07:42.917Z" }, + { url = "https://files.pythonhosted.org/packages/ca/39/6e9f83970590d17d19f596e1b3a366d39077624888e3dd709309efc67690/valkey_glide-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21598f49313912ad27dc700d7b13a3b4bfed7ed9dffad207235cac7d218f4966", size = 4748418, upload-time = "2025-06-20T01:07:44.64Z" }, + { url = "https://files.pythonhosted.org/packages/98/0e/91335c13dc8e7ceb95063234c16010b46e2dd874a2edef62dea155081647/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f662285146529328e2b5a0a7047f699339b4e0d250eb1f252b15c9befa0dea05", size = 4767264, upload-time = "2025-06-20T01:07:46.185Z" }, + { url = "https://files.pythonhosted.org/packages/5f/94/ee4d9d441f83fec1464d9f4e52f7940bdd2aeb917589e6abd57498880876/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3939aaa8411fcbba00cb1ff7c7ba73f388bb1deca919972f65cba7eda1d5fa95", size = 5093543, upload-time = "2025-06-20T01:07:47.345Z" }, + { url = "https://files.pythonhosted.org/packages/ed/7e/257a2e4b61ac29d5923f89bad5fe62be7b4a19e7bec78d191af3ce77aa39/valkey_glide-2.0.1-cp313-cp313-macosx_10_7_x86_64.whl", hash = "sha256:c49b53011a05b5820d0c660ee5c76574183b413a54faa33cf5c01ce77164d9c8", size = 5073114, upload-time = "2025-06-20T01:07:48.885Z" }, + { url = "https://files.pythonhosted.org/packages/20/14/a8a470679953980af7eac3ccb09638f2a76d4547116d48cbc69ae6f25080/valkey_glide-2.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a23572b83877537916ba36ad0a6b2fd96581534f0bc67ef8f8498bf4dbb2b40", size = 4747717, upload-time = "2025-06-20T01:07:50.092Z" }, + { url = "https://files.pythonhosted.org/packages/9f/49/f168dd0c778d9f6ff1be70d5d3bad7a86928fee563de7de5f4f575eddfd8/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:943a2c4a5c38b8a6b53281201d5a4997ec454a6fdda72d27050eeb6aaef12afb", size = 4767128, upload-time = "2025-06-20T01:07:51.306Z" }, + { url = "https://files.pythonhosted.org/packages/43/be/68961b14ea133d1792ce50f6df1753848b5377c3e06a8dbe4e39188a549a/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d770ec581acc59d5597e7ccaac37aee7e3b5e716a77a7fa44e2967db3a715f53", size = 5093522, upload-time = "2025-06-20T01:07:52.546Z" }, + { url = "https://files.pythonhosted.org/packages/51/2e/ad8595ffe84317385d52ceab8de1e9ef06a4da6b81ca8cd61b7961923de4/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d4a9ccfe2b190c90622849dab62f9468acf76a282719a1245d272b649e7c12d1", size = 5074539, upload-time = "2025-06-20T01:07:59.87Z" }, + { url = "https://files.pythonhosted.org/packages/db/e5/2122541c7a64706f3631655209bb0b13723fb99db3c190d9a792b4e7d494/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9aa004077b82f64b23ea0d38d948b5116c23f7228dae3a5b4fcfa1799f8ff7de", size = 4753222, upload-time = "2025-06-20T01:08:01.376Z" }, + { url = "https://files.pythonhosted.org/packages/6c/13/cd9a20988a820ff61b127d3f850887b28bb734daf2c26d512d8e4c2e8e9e/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:631a7a0e2045f7e5e3706e1903beeddf381a6529e318c27230798f4382579e4f", size = 4771530, upload-time = "2025-06-20T01:08:02.6Z" }, + { url = "https://files.pythonhosted.org/packages/c7/fc/047e89cc01b4cc71db1b6b8160d3b5d050097b408028022c002351238641/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ed905fb62368c9bc6aef9df8d66269ef51f968dc527da4d7c956927382c1d", size = 5091242, upload-time = "2025-06-20T01:08:04.111Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9e/68790c1a263f3a0094d67d0109be34631f6f79c2fbce5ced7e33a65ad363/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_10_7_x86_64.whl", hash 
= "sha256:53da3cc47c8d946ac76ecc4b468a469d3486778833a59162ea69aa7ce70cbb27", size = 5072793, upload-time = "2025-06-20T01:08:05.562Z" }, + { url = "https://files.pythonhosted.org/packages/1f/ae/a935af65ae4069d76c69f28f6bfb4533da8b89f7fc418beb7a1482cdd9ee/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e526a7d718cdd299d6b03091c12dcc15cd02ff22fe420f253341a4891c50824d", size = 4753435, upload-time = "2025-06-20T01:08:07.149Z" }, + { url = "https://files.pythonhosted.org/packages/3b/c2/c91d753a89dd87dce2fc8932cfbe174c7a1226c657b3cd64c063f21d4fe6/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d3345ea2adf6f745733fa5157d8709bcf5ffbb2674391aeebd8f166a37cbc96", size = 4771401, upload-time = "2025-06-20T01:08:08.359Z" }, + { url = "https://files.pythonhosted.org/packages/00/fe/ad83cfc2ac87bf6bad2b75fa64fca5a6dd54568c1de551d36d369e07f948/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c5fff0f12d2aa4277ddc335035b2c8e12bb11243c1a0f3c35071f4a8b11064", size = 5091360, upload-time = "2025-06-20T01:08:09.622Z" }, +] + [[package]] name = "vcrpy" version = "7.0.0"