diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml index fbd073b6725b1a..e647bba9b59741 100644 --- a/.github/workflows/vdb-tests-full.yml +++ b/.github/workflows/vdb-tests-full.yml @@ -56,6 +56,7 @@ jobs: # tidb # tiflash + # - name: Check VDB Ready (TiDB) # run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py @@ -63,6 +64,6 @@ jobs: run: | uv run --project api pytest \ --start-vdb \ - --vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase" \ + --vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase,valkey" \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/providers/vdb/*/tests/integration_tests diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 972ab881727f4d..76da88ef27428b 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -53,6 +53,7 @@ jobs: # tidb # tiflash + # - name: Check VDB Ready (TiDB) # run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py @@ -64,4 +65,5 @@ jobs: api/providers/vdb/vdb-chroma/tests/integration_tests \ api/providers/vdb/vdb-pgvector/tests/integration_tests \ api/providers/vdb/vdb-qdrant/tests/integration_tests \ + api/providers/vdb/vdb-valkey/tests/integration_tests \ api/providers/vdb/vdb-weaviate/tests/integration_tests diff --git a/api/.env.example b/api/.env.example index f645ba7bf02d36..09b66786d00630 100644 --- a/api/.env.example +++ b/api/.env.example @@ -201,7 +201,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* COOKIE_DOMAIN= # Vector database configuration -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`. +# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`, `valkey`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -231,6 +231,14 @@ QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 QDRANT_REPLICATION_FACTOR=1 +# Valkey configuration, only available when VECTOR_STORE is `valkey` +VALKEY_HOST=localhost +VALKEY_PORT=6379 +VALKEY_PASSWORD= +VALKEY_DB=0 +VALKEY_USE_SSL=false +VALKEY_DISTANCE_METRIC=COSINE + #Couchbase configuration COUCHBASE_CONNECTION_STRING=127.0.0.1 COUCHBASE_USER=Administrator diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 865bb48c676452..8d30ce74949443 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -46,6 +46,7 @@ from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig from .vdb.tidb_vector_config import TiDBVectorConfig from .vdb.upstash_config import UpstashConfig +from .vdb.valkey_config import ValkeyConfig from .vdb.vastbase_vector_config import VastbaseVectorConfig from .vdb.vikingdb_config import VikingDBConfig from .vdb.weaviate_config import WeaviateConfig @@ -407,5 +408,6 @@ class MiddlewareConfig( TableStoreConfig, DatasetQueueMonitorConfig, MatrixoneConfig, + ValkeyConfig, ): pass diff --git a/api/configs/middleware/vdb/valkey_config.py b/api/configs/middleware/vdb/valkey_config.py new file mode 100644 index 00000000000000..47dc6c0d5ed900 --- /dev/null +++ b/api/configs/middleware/vdb/valkey_config.py @@ -0,0 +1,38 @@ +from typing import Literal + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class ValkeyConfig(BaseSettings): + """Configuration settings for Valkey vector database (valkey-search module).""" + + VALKEY_HOST: str = Field( + description="Hostname or IP address of the Valkey server.", + default="localhost", + ) + + VALKEY_PORT: PositiveInt = Field( + description="Port number for the Valkey server (default is 6379).", + default=6379, + ) + + VALKEY_PASSWORD: str = Field( + description="Password for authenticating with the Valkey server.", + default="", + ) + + VALKEY_DB: NonNegativeInt = Field( + description="Valkey database number to use (default is 0).", + default=0, + ) + + VALKEY_USE_SSL: bool = Field( + description="Whether to use SSL/TLS for the Valkey connection.", + default=False, + ) + + VALKEY_DISTANCE_METRIC: Literal["COSINE", "L2", "IP"] = Field( + description="Distance metric for vector similarity search. Options: COSINE, L2, IP.", + default="COSINE", + ) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 9cce8e4c32208d..3288f7e7abb0ba 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -33,5 +33,6 @@ class VectorType(StrEnum): HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" + VALKEY = "valkey" IRIS = "iris" HOLOGRES = "hologres" diff --git a/api/providers/vdb/vdb-valkey/pyproject.toml b/api/providers/vdb/vdb-valkey/pyproject.toml new file mode 100644 index 00000000000000..0157d6db0efd26 --- /dev/null +++ b/api/providers/vdb/vdb-valkey/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-valkey" +version = "0.0.1" + +dependencies = [ + "valkey-glide>=1.3.0", +] +description = "Dify vector store backend (dify-vdb-valkey)." + +[project.entry-points."dify.vector_backends"] +valkey = "dify_vdb_valkey.valkey_vector:ValkeyVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/__init__.py b/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/__init__.py new file mode 100644 index 00000000000000..8b137891791fe9 --- /dev/null +++ b/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/__init__.py @@ -0,0 +1 @@ + diff --git a/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/valkey_vector.py b/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/valkey_vector.py new file mode 100644 index 00000000000000..f544153953360e --- /dev/null +++ b/api/providers/vdb/vdb-valkey/src/dify_vdb_valkey/valkey_vector.py @@ -0,0 +1,621 @@ +"""Valkey vector store backend using valkey-glide and the valkey-search module. + +This module implements the Dify vector store interface for Valkey, using the +``valkey-search`` module's ``FT.CREATE`` / ``FT.SEARCH`` / ``FT.DROPINDEX`` commands +for vector similarity search and full-text search. + +Data is stored as Valkey HASH keys. Each document gets a hash key containing: +- ``vector``: the embedding as raw FLOAT32 bytes +- ``page_content``: the document text +- ``metadata``: JSON-serialised metadata dict +- ``group_id``: the dataset/group identifier + +An FT index is created per collection with HNSW vector indexing, TAG fields +for ``group_id`` / ``doc_id`` / ``document_id``, and a TEXT field for +``page_content`` to support full-text search. + +Dimensions are auto-detected from the first embedding on index creation. +The distance metric (COSINE, L2, IP) is configurable via ``VALKEY_DISTANCE_METRIC``. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import struct +import uuid +from typing import Any + +from pydantic import BaseModel + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document + +# ``redis_client`` is Dify's internal Redis instance (used for caching and +# Celery). It is **not** the Valkey vector store — it is only used here for +# distributed locking during index creation so that concurrent workers don't +# race to create the same FT index. +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + +# Distance metrics supported by valkey-search. +_VALID_DISTANCE_METRICS = frozenset({"COSINE", "L2", "IP"}) + + +# --------------------------------------------------------------------------- +# Pure helpers — no external dependencies, safe to unit-test directly. +# --------------------------------------------------------------------------- + + +def _float_vector_to_bytes(vector: list[float]) -> bytes: + """Pack a list of floats into little-endian FLOAT32 bytes for Valkey.""" + return struct.pack(f"<{len(vector)}f", *vector) + + +def _bytes_to_float_vector(data: bytes) -> list[float]: + """Unpack little-endian FLOAT32 bytes back into a list of floats.""" + count = len(data) // 4 + return list(struct.unpack(f"<{count}f", data)) + + +def _to_str(value: Any) -> str: + """Convert bytes or str to str.""" + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value) if value is not None else "" + + +def _escape_tag(value: str) -> str: + """Escape special characters in a TAG value for FT.SEARCH queries.""" + special = r"\.+*?[{()|^$!<>~@&\"-]" + result: list[str] = [] + for ch in value: + if ch in special: + result.append(f"\\{ch}") + else: + result.append(ch) + return "".join(result) + + +def _escape_text(value: str) -> str: + """Escape special characters in a TEXT query for FT.SEARCH.""" + special = r"@!{}()|-=>~*\'\"" + result: list[str] = [] + for ch in value: + if ch in special: + result.append(f"\\{ch}") + else: + result.append(ch) + return "".join(result) + + +def _parse_dict_keys(result: Any) -> list[str]: + """Parse key names from a glide FT.SEARCH ``RETURN 0`` response. + + Glide returns ``[total_count, {key: {}, ...}]``. + """ + if not result or not isinstance(result, (list, tuple)) or len(result) < 2: + return [] + entries = result[1] + if not isinstance(entries, dict): + return [] + return [_to_str(k) for k in entries] + + +def _distance_to_similarity(distance: float, metric: str) -> float: + """Convert a valkey-search distance value to a [0, 1] similarity score. + + Valkey-search distance definitions (per the FT.CREATE docs): + - COSINE: ``1 - cos(θ)``, range [0, 2]. Similarity = ``1 - d/2``. + - L2: Euclidean distance, range [0, ∞). Similarity = ``1 / (1 + d)``. + - IP: ``1 - dot(X, Y)``. Similarity = ``1 - d`` (already normalised + when vectors are unit-length). + """ + metric_upper = metric.upper() + if metric_upper == "COSINE": + return 1.0 - distance / 2.0 + if metric_upper == "L2": + return 1.0 / (1.0 + distance) + if metric_upper == "IP": + return 1.0 - distance + # Fallback — treat as raw distance inversion. + return 1.0 - distance + + +def _parse_vector_search_results( + result: Any, + score_threshold: float, + distance_metric: str, +) -> list[Document]: + """Parse FT.SEARCH results from a vector KNN query. + + The glide client returns ``[total_count, {key: {field: value, ...}, ...}]``. + The ``__vector_score`` field contains the raw distance. + """ + if not result or not isinstance(result, (list, tuple)) or len(result) < 2: + return [] + + entries = result[1] + if not isinstance(entries, dict): + return [] + + docs: list[Document] = [] + for _key, fields in entries.items(): + if not isinstance(fields, dict): + continue + + score_raw = fields.get(b"__vector_score") or fields.get("__vector_score") + if score_raw is None: + continue + distance = float(_to_str(score_raw)) + score = _distance_to_similarity(distance, distance_metric) + + if score < score_threshold: + continue + + metadata_raw = fields.get(b"metadata") or fields.get(Field.METADATA_KEY, b"{}") + metadata = json.loads(_to_str(metadata_raw)) if metadata_raw else {} + metadata["score"] = score + + page_content = _to_str(fields.get(b"page_content") or fields.get(Field.CONTENT_KEY, b"")) + docs.append(Document(page_content=page_content, metadata=metadata)) + + docs.sort(key=lambda d: d.metadata.get("score", 0) if d.metadata else 0, reverse=True) + return docs + + +def _parse_full_text_results(result: Any) -> list[tuple[str, Document]]: + """Parse FT.SEARCH results from a full-text query. + + Returns a list of ``(key_name, Document)`` tuples. + """ + if not result or not isinstance(result, (list, tuple)) or len(result) < 2: + return [] + + entries = result[1] + if not isinstance(entries, dict): + return [] + + pairs: list[tuple[str, Document]] = [] + for key_raw, fields in entries.items(): + if not isinstance(fields, dict): + continue + + key = _to_str(key_raw) + metadata_raw = fields.get(b"metadata") or fields.get(Field.METADATA_KEY, b"{}") + metadata = json.loads(_to_str(metadata_raw)) if metadata_raw else {} + page_content = _to_str(fields.get(b"page_content") or fields.get(Field.CONTENT_KEY, b"")) + pairs.append((key, Document(page_content=page_content, metadata=metadata))) + + return pairs + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +class ValkeyVectorConfig(BaseModel): + """Connection parameters for the Valkey server.""" + + host: str = "localhost" + port: int = 6379 + password: str = "" + db: int = 0 + use_ssl: bool = False + distance_metric: str = "COSINE" + + +# --------------------------------------------------------------------------- +# Async client helper +# --------------------------------------------------------------------------- + +# The glide client is bound to the event loop it was created on. We keep a +# dedicated loop per ValkeyVector instance so every ``_run`` call dispatches +# on the correct loop. + + +def _create_glide_client(config: ValkeyVectorConfig) -> tuple[Any, asyncio.AbstractEventLoop]: + """Create a valkey-glide ``GlideClient`` and the event loop it lives on. + + Returns ``(client, loop)``. Callers must use ``loop.run_until_complete`` + for all subsequent async operations on *client*. + """ + from glide import GlideClient, GlideClientConfiguration, NodeAddress + + addresses = [NodeAddress(config.host, config.port)] + kwargs: dict[str, Any] = { + "addresses": addresses, + "use_tls": config.use_ssl, + "request_timeout": 30_000, + "client_name": "dify_vector_store", + } + if config.password: + from glide import ServerCredentials + + kwargs["credentials"] = ServerCredentials(password=config.password) + + glide_config = GlideClientConfiguration(**kwargs) + if config.db: + glide_config.database_id = config.db + + loop = asyncio.new_event_loop() + client = loop.run_until_complete(GlideClient.create(glide_config)) + return client, loop + + +# --------------------------------------------------------------------------- +# ValkeyVector +# --------------------------------------------------------------------------- + + +class ValkeyVector(BaseVector): + """Valkey vector store implementation using the valkey-search module.""" + + _config: ValkeyVectorConfig + _client: Any + _loop: asyncio.AbstractEventLoop + _group_id: str + _prefix: str + + def __init__( + self, + collection_name: str, + group_id: str, + config: ValkeyVectorConfig, + *, + client: Any | None = None, + loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + super().__init__(collection_name) + self._config = config + self._group_id = group_id + self._prefix = f"doc:{collection_name}:" + if client is not None: + self._client = client + self._loop = loop or asyncio.new_event_loop() + else: + self._client, self._loop = _create_glide_client(config) + + def _run(self, coro: Any) -> Any: + """Run an async coroutine on this instance's event loop.""" + return self._loop.run_until_complete(coro) + + def close(self) -> None: + """Shut down the glide client and close the event loop.""" + try: + self._run(self._client.close()) + except Exception: + logger.debug("Error closing glide client", exc_info=True) + finally: + if not self._loop.is_closed(): + self._loop.close() + + def __enter__(self) -> ValkeyVector: + return self + + def __exit__(self, *exc: Any) -> None: + self.close() + + def __del__(self) -> None: + # Best-effort cleanup if close() was never called explicitly. + try: + if not self._loop.is_closed(): + self.close() + except Exception: + logger.debug("Error during __del__ cleanup", exc_info=True) + finally: + # Ensure the loop is closed even if close() failed. + try: + if not self._loop.is_closed(): + self._loop.close() + except Exception: + logger.debug("Error closing event loop in __del__", exc_info=True) + + def get_type(self) -> str: + return VectorType.VALKEY + + def to_index_struct(self) -> VectorIndexStructDict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + + # ------------------------------------------------------------------ + # Index management + # ------------------------------------------------------------------ + + def _index_name(self) -> str: + return f"idx:{self._collection_name}" + + def _index_exists(self) -> bool: + """Check whether the FT index already exists.""" + from glide.async_commands.server_modules import ft + + try: + self._run(ft.info(self._client, self._index_name())) + return True + except Exception: + return False + + def _create_index(self, vector_size: int) -> None: + """Create the FT index with HNSW vector, TAG, and TEXT fields. + + The distance metric is read from ``self._config.distance_metric``. + Dimensions are determined by *vector_size* (auto-detected from the + first embedding passed to ``create``). + """ + from glide.async_commands.server_modules import ft + from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DistanceMetricType, + FtCreateOptions, + TagField, + TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributesHnsw, + ) + from glide.async_commands.server_modules.ft_options.ft_create_options import ( + VectorType as GlideVectorType, + ) + + metric_map = { + "COSINE": DistanceMetricType.COSINE, + "L2": DistanceMetricType.L2, + "IP": DistanceMetricType.IP, + } + metric = metric_map.get( + self._config.distance_metric.upper(), + DistanceMetricType.COSINE, + ) + + schema = [ + VectorField( + name="vector", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dimensions=vector_size, + distance_metric=metric, + type=GlideVectorType.FLOAT32, + ), + ), + TagField(name="group_id"), + TagField(name="doc_id"), + TagField(name="document_id"), + TextField(name=Field.CONTENT_KEY), + ] + options = FtCreateOptions(prefixes=[self._prefix]) + self._run(ft.create(self._client, self._index_name(), schema, options)) + + # ------------------------------------------------------------------ + # CRUD operations + # ------------------------------------------------------------------ + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs: Any) -> None: + if not texts or not embeddings: + return + vector_size = len(embeddings[0]) + if vector_size == 0: + raise ValueError("First embedding is empty — cannot determine vector dimensions") + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + cache_key = f"vector_indexing_{self._collection_name}" + if not redis_client.get(cache_key): + if not self._index_exists(): + self._create_index(vector_size) + redis_client.set(cache_key, 1, ex=3600) + self.add_texts(texts, embeddings, **kwargs) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs: Any) -> list[str]: + if not documents: + return [] + + # Validate all embeddings have consistent dimensions. + expected_dim = len(embeddings[0]) + for i, emb in enumerate(embeddings): + if len(emb) != expected_dim: + raise ValueError(f"Embedding dimension mismatch at index {i}: expected {expected_dim}, got {len(emb)}") + + added_ids: list[str] = [] + for doc, embedding in zip(documents, embeddings): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4()) + key = f"{self._prefix}{doc_id}" + metadata = doc.metadata or {} + fields: dict[str, str | bytes] = { + "vector": _float_vector_to_bytes(embedding), + Field.CONTENT_KEY: doc.page_content, + Field.METADATA_KEY: json.dumps(metadata), + "group_id": self._group_id, + "doc_id": metadata.get("doc_id", ""), + "document_id": metadata.get("document_id", ""), + } + try: + self._run(self._client.hset(key, fields)) + except Exception: + logger.exception( + "Failed to add document %s to collection %s", + doc_id, + self._collection_name, + ) + raise + added_ids.append(doc_id) + + return added_ids + + def text_exists(self, id: str) -> bool: + key = f"{self._prefix}{id}" + result = self._run(self._client.exists([key])) + return result > 0 + + def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return + keys = [f"{self._prefix}{doc_id}" for doc_id in ids] + self._run(self._client.delete(keys)) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents matching a metadata TAG field value.""" + if key == "document_id": + query = f"@document_id:{{{_escape_tag(value)}}}" + else: + query = f"@doc_id:{{{_escape_tag(value)}}}" + self._delete_by_query(query) + + def delete(self) -> None: + """Delete all documents belonging to this group.""" + query = f"@group_id:{{{_escape_tag(self._group_id)}}}" + self._delete_by_query(query) + + def _delete_by_query(self, query: str) -> None: + """Search for keys matching *query* and delete them.""" + from glide.async_commands.server_modules import ft + from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchLimit, + FtSearchOptions, + ) + + batch_size = 100 + while True: + # Always search from offset 0 because deletions shift the result set. + options = FtSearchOptions( + return_fields=[], + limit=FtSearchLimit(offset=0, count=batch_size), + ) + result = self._run( + ft.search(self._client, self._index_name(), query, options), + ) + keys = _parse_dict_keys(result) + if not keys: + break + self._run(self._client.delete(keys)) + + # ------------------------------------------------------------------ + # Search + # ------------------------------------------------------------------ + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from glide.async_commands.server_modules import ft + from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchLimit, + FtSearchOptions, + ) + + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + + query = f"(@group_id:{{{_escape_tag(self._group_id)}}})=>[KNN {top_k} @vector $query_vector]" + + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + tag_values = "|".join(_escape_tag(did) for did in document_ids_filter) + query = ( + f"(@group_id:{{{_escape_tag(self._group_id)}}} " + f"@document_id:{{{tag_values}}})" + f"=>[KNN {top_k} @vector $query_vector]" + ) + + vector_bytes = _float_vector_to_bytes(query_vector) + options = FtSearchOptions( + params={"query_vector": vector_bytes}, + limit=FtSearchLimit(offset=0, count=top_k), + ) + result = self._run( + ft.search(self._client, self._index_name(), query, options), + ) + return _parse_vector_search_results( + result, + score_threshold, + self._config.distance_metric, + ) + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Full-text search on ``page_content``. + + Splits the query into keywords and searches each separately (OR logic), + deduplicating results by key name. + """ + from glide.async_commands.server_modules import ft + from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchLimit, + FtSearchOptions, + ) + + top_k = kwargs.get("top_k", 2) + keywords = list( + dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()), + )[:10] + if not keywords: + return [] + + document_ids_filter = kwargs.get("document_ids_filter") + seen_keys: set[str] = set() + documents: list[Document] = [] + + for keyword in keywords: + escaped_kw = _escape_text(keyword) + filter_parts = [f"@group_id:{{{_escape_tag(self._group_id)}}}"] + if document_ids_filter: + tag_values = "|".join(_escape_tag(did) for did in document_ids_filter) + filter_parts.append(f"@document_id:{{{tag_values}}}") + filter_parts.append(f"@{Field.CONTENT_KEY}:{escaped_kw}") + ft_query = " ".join(filter_parts) + + options = FtSearchOptions(limit=FtSearchLimit(offset=0, count=top_k)) + result = self._run( + ft.search(self._client, self._index_name(), ft_query, options), + ) + for key, doc in _parse_full_text_results(result): + if key not in seen_keys: + seen_keys.add(key) + documents.append(doc) + if len(documents) >= top_k: + return documents + return documents + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +class ValkeyVectorFactory(AbstractVectorFactory): + """Factory for creating ValkeyVector instances from dataset configuration.""" + + def init_vector( + self, + dataset: Dataset, + attributes: list, + embeddings: Embeddings, + ) -> ValkeyVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + + if not dataset.index_struct_dict: + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.VALKEY, collection_name), + ) + + return ValkeyVector( + collection_name=collection_name, + group_id=dataset.id, + config=ValkeyVectorConfig( + host=dify_config.VALKEY_HOST, + port=dify_config.VALKEY_PORT, + password=dify_config.VALKEY_PASSWORD, + db=dify_config.VALKEY_DB, + use_ssl=dify_config.VALKEY_USE_SSL, + distance_metric=dify_config.VALKEY_DISTANCE_METRIC, + ), + ) diff --git a/api/providers/vdb/vdb-valkey/tests/integration_tests/test_valkey.py b/api/providers/vdb/vdb-valkey/tests/integration_tests/test_valkey.py new file mode 100644 index 00000000000000..286b3b1625628f --- /dev/null +++ b/api/providers/vdb/vdb-valkey/tests/integration_tests/test_valkey.py @@ -0,0 +1,420 @@ +"""Integration tests for the Valkey vector store backend. + +Requires a running Valkey instance with the valkey-search module loaded +on localhost:6379 (standard port). Start one with: + + docker run -d --name valkey-search -p 6379:6379 valkey/valkey-bundle:latest +""" + +from __future__ import annotations + +import os +import uuid + +import pytest +from dify_vdb_valkey.valkey_vector import ValkeyVector, ValkeyVectorConfig, VectorType + +from core.rag.models.document import Document + +EMBEDDING_DIM = 128 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _cfg() -> ValkeyVectorConfig: + return ValkeyVectorConfig( + host=os.environ.get("VALKEY_HOST", "localhost"), + port=int(os.environ.get("VALKEY_PORT", "6379")), + password=os.environ.get("VALKEY_PASSWORD", ""), + db=int(os.environ.get("VALKEY_DB", "0")), + use_ssl=False, + ) + + +def _embedding(seed: float = 1.001) -> list[float]: + """Deterministic 128-d embedding. Avoids zero vectors by adding 1 to the index.""" + return [seed * (i + 1) for i in range(EMBEDDING_DIM)] + + +def _doc(doc_id: str, content: str, dataset_id: str) -> Document: + return Document( + page_content=content, + metadata={ + "doc_id": doc_id, + "doc_hash": doc_id, + "document_id": doc_id, + "dataset_id": dataset_id, + }, + ) + + +@pytest.fixture +def vv(setup_mock_redis): + """Yield a fresh ValkeyVector with a unique collection, clean up after.""" + group_id = str(uuid.uuid4()) + collection = f"test_{uuid.uuid4().hex[:12]}" + v = ValkeyVector(collection_name=collection, group_id=group_id, config=_cfg()) + yield v + # Teardown: delete all docs and drop the index + try: + v.delete() + except Exception: # noqa: S110 + pass + try: + from glide.async_commands.server_modules import ft + + v._run(ft.dropindex(v._client, v._index_name())) + except Exception: # noqa: S110 + pass + + +# --------------------------------------------------------------------------- +# get_type / to_index_struct +# --------------------------------------------------------------------------- + + +class TestBasicProperties: + def test_get_type(self, vv: ValkeyVector): + assert vv.get_type() == VectorType.VALKEY + + def test_to_index_struct(self, vv: ValkeyVector): + s = vv.to_index_struct() + assert s["type"] == VectorType.VALKEY + assert s["vector_store"]["class_prefix"] == vv.collection_name + + +# --------------------------------------------------------------------------- +# create +# --------------------------------------------------------------------------- + + +class TestCreate: + def test_create_inserts_documents(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + docs = [_doc(doc_id, "hello world", vv._group_id)] + vv.create(docs, [_embedding()]) + + assert vv.text_exists(doc_id) + + def test_create_with_empty_list_is_noop(self, vv: ValkeyVector): + """create([]) must not create an index or raise.""" + vv.create([], []) + assert not vv._index_exists() + + def test_create_is_idempotent(self, vv: ValkeyVector): + """Calling create twice must not fail (index already exists).""" + doc1 = _doc(str(uuid.uuid4()), "first", vv._group_id) + doc2 = _doc(str(uuid.uuid4()), "second", vv._group_id) + vv.create([doc1], [_embedding(1.0)]) + vv.create([doc2], [_embedding(2.0)]) + + assert vv.text_exists(doc1.metadata["doc_id"]) + assert vv.text_exists(doc2.metadata["doc_id"]) + + +# --------------------------------------------------------------------------- +# add_texts +# --------------------------------------------------------------------------- + + +class TestAddTexts: + def test_returns_doc_ids(self, vv: ValkeyVector): + d1, d2 = str(uuid.uuid4()), str(uuid.uuid4()) + docs = [_doc(d1, "a", vv._group_id), _doc(d2, "b", vv._group_id)] + vv.create([docs[0]], [_embedding()]) # ensure index exists + ids = vv.add_texts(docs, [_embedding(), _embedding(2.0)]) + assert ids == [d1, d2] + + def test_batch_add_100_documents(self, vv: ValkeyVector): + """Verify bulk insert works.""" + doc_ids = [str(uuid.uuid4()) for _ in range(100)] + docs = [_doc(did, f"content-{i}", vv._group_id) for i, did in enumerate(doc_ids)] + vv.create([docs[0]], [_embedding()]) + vv.add_texts(docs, [_embedding(float(i)) for i in range(100)]) + # Spot-check a few + assert vv.text_exists(doc_ids[0]) + assert vv.text_exists(doc_ids[50]) + assert vv.text_exists(doc_ids[99]) + + +# --------------------------------------------------------------------------- +# text_exists +# --------------------------------------------------------------------------- + + +class TestTextExists: + def test_exists_true(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "exists", vv._group_id)], [_embedding()]) + assert vv.text_exists(doc_id) is True + + def test_exists_false_for_missing(self, vv: ValkeyVector): + assert vv.text_exists("nonexistent-id") is False + + +# --------------------------------------------------------------------------- +# delete_by_ids +# --------------------------------------------------------------------------- + + +class TestDeleteByIds: + def test_deletes_specific_documents(self, vv: ValkeyVector): + d1, d2 = str(uuid.uuid4()), str(uuid.uuid4()) + vv.create( + [_doc(d1, "keep", vv._group_id), _doc(d2, "remove", vv._group_id)], + [_embedding(), _embedding(2.0)], + ) + vv.delete_by_ids([d2]) + assert vv.text_exists(d1) is True + assert vv.text_exists(d2) is False + + def test_delete_nonexistent_id_is_noop(self, vv: ValkeyVector): + """Deleting an ID that doesn't exist must not raise.""" + vv.delete_by_ids(["does-not-exist"]) + + def test_delete_empty_list_is_noop(self, vv: ValkeyVector): + vv.delete_by_ids([]) + + +# --------------------------------------------------------------------------- +# delete_by_metadata_field +# --------------------------------------------------------------------------- + + +class TestDeleteByMetadataField: + def test_delete_by_document_id(self, vv: ValkeyVector): + d1, d2 = str(uuid.uuid4()), str(uuid.uuid4()) + vv.create( + [_doc(d1, "doc one", vv._group_id), _doc(d2, "doc two", vv._group_id)], + [_embedding(), _embedding(2.0)], + ) + vv.delete_by_metadata_field("document_id", d1) + assert vv.text_exists(d1) is False + assert vv.text_exists(d2) is True + + def test_delete_by_doc_id(self, vv: ValkeyVector): + d1 = str(uuid.uuid4()) + vv.create([_doc(d1, "target", vv._group_id)], [_embedding()]) + vv.delete_by_metadata_field("doc_id", d1) + assert vv.text_exists(d1) is False + + +# --------------------------------------------------------------------------- +# delete (group-level) +# --------------------------------------------------------------------------- + + +class TestDelete: + def test_deletes_all_group_documents(self, vv: ValkeyVector): + ids = [str(uuid.uuid4()) for _ in range(5)] + docs = [_doc(did, f"text-{i}", vv._group_id) for i, did in enumerate(ids)] + vv.create(docs, [_embedding(float(i)) for i in range(5)]) + for did in ids: + assert vv.text_exists(did) + + vv.delete() + for did in ids: + assert vv.text_exists(did) is False + + def test_delete_does_not_affect_other_groups(self, vv: ValkeyVector): + """Documents from a different group_id must survive.""" + other_group = str(uuid.uuid4()) + d_own = str(uuid.uuid4()) + d_other = str(uuid.uuid4()) + + # Insert doc for vv's group + vv.create([_doc(d_own, "own group", vv._group_id)], [_embedding()]) + + # Insert doc for a different group into the same collection + other_doc = Document( + page_content="other group", + metadata={"doc_id": d_other, "doc_hash": d_other, "document_id": d_other, "dataset_id": other_group}, + ) + field_pairs: list[str | bytes] = [ + "vector", + b"", # placeholder, replaced below + "page_content", + "other group", + "metadata", + '{"doc_id":"' + d_other + '"}', + "group_id", + other_group, + "doc_id", + d_other, + "document_id", + d_other, + ] + import struct + + field_pairs[1] = struct.pack(f"<{EMBEDDING_DIM}f", *_embedding(3.0)) + vv._run(vv._client.custom_command(["HSET", f"{vv._prefix}{d_other}", *field_pairs])) + + vv.delete() # only deletes vv._group_id + + assert vv.text_exists(d_own) is False + assert vv.text_exists(d_other) is True # other group survives + + +# --------------------------------------------------------------------------- +# search_by_vector +# --------------------------------------------------------------------------- + + +class TestSearchByVector: + def test_returns_matching_document(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + emb = _embedding() + vv.create([_doc(doc_id, "vector search target", vv._group_id)], [emb]) + + hits = vv.search_by_vector(emb, top_k=1) + assert len(hits) == 1 + assert hits[0].metadata["doc_id"] == doc_id + assert "score" in hits[0].metadata + + def test_respects_top_k(self, vv: ValkeyVector): + ids = [str(uuid.uuid4()) for _ in range(5)] + docs = [_doc(did, f"doc-{i}", vv._group_id) for i, did in enumerate(ids)] + vv.create(docs, [_embedding(float(i + 1)) for i in range(5)]) + + hits = vv.search_by_vector(_embedding(1.0), top_k=2) + assert len(hits) <= 2 + + def test_score_threshold_filters_low_similarity(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "threshold test", vv._group_id)], [_embedding()]) + + # Threshold of 1.0 means only exact match (distance=0) passes + hits = vv.search_by_vector(_embedding(), score_threshold=1.0, top_k=4) + # The same vector should have distance ~0, score ~1.0 + assert len(hits) >= 1 + + def test_score_threshold_excludes_all(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + # Create a vector pointing in a very different direction + opposite = [(-1.0) ** i * (i + 1) for i in range(EMBEDDING_DIM)] + vv.create([_doc(doc_id, "far away", vv._group_id)], [opposite]) + + # Query with a uniform-direction vector; high threshold should exclude + query = [float(i) for i in range(EMBEDDING_DIM)] + hits = vv.search_by_vector(query, score_threshold=0.99, top_k=4) + assert len(hits) == 0 + + def test_empty_collection_returns_empty(self, vv: ValkeyVector): + """Searching before any documents are inserted.""" + # Create index with a dummy doc then delete it + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "temp", vv._group_id)], [_embedding()]) + vv.delete_by_ids([doc_id]) + + hits = vv.search_by_vector(_embedding(), top_k=4) + assert len(hits) == 0 + + def test_results_contain_page_content_and_metadata(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "content check", vv._group_id)], [_embedding()]) + + hits = vv.search_by_vector(_embedding(), top_k=1) + assert hits[0].page_content == "content check" + assert hits[0].metadata["doc_id"] == doc_id + assert isinstance(hits[0].metadata["score"], float) + + def test_results_sorted_by_score_descending(self, vv: ValkeyVector): + """Closer vectors should rank higher.""" + d_close = str(uuid.uuid4()) + d_far = str(uuid.uuid4()) + query = _embedding(1.0) + vv.create( + [_doc(d_close, "close", vv._group_id), _doc(d_far, "far", vv._group_id)], + [_embedding(1.0), _embedding(50.0)], + ) + + hits = vv.search_by_vector(query, top_k=2) + assert len(hits) == 2 + assert hits[0].metadata["score"] >= hits[1].metadata["score"] + + +# --------------------------------------------------------------------------- +# search_by_full_text +# --------------------------------------------------------------------------- + + +class TestSearchByFullText: + def test_single_keyword_match(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "the quick brown fox", vv._group_id)], [_embedding()]) + + hits = vv.search_by_full_text("fox", top_k=10) + assert len(hits) >= 1 + assert any(h.metadata["doc_id"] == doc_id for h in hits) + + def test_no_match_returns_empty(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "hello world", vv._group_id)], [_embedding()]) + + hits = vv.search_by_full_text("zzzznonexistent", top_k=10) + assert len(hits) == 0 + + def test_empty_query_returns_empty(self, vv: ValkeyVector): + assert vv.search_by_full_text("", top_k=10) == [] + assert vv.search_by_full_text(" ", top_k=10) == [] + + def test_multi_keyword_or_logic(self, vv: ValkeyVector): + """Multi-word query should match documents containing ANY keyword.""" + d_apple = str(uuid.uuid4()) + d_banana = str(uuid.uuid4()) + d_both = str(uuid.uuid4()) + + vv.create( + [ + _doc(d_apple, "apple pie recipe", vv._group_id), + _doc(d_banana, "banana smoothie recipe", vv._group_id), + _doc(d_both, "apple banana fruit salad", vv._group_id), + ], + [_embedding(1.0), _embedding(2.0), _embedding(3.0)], + ) + + hits = vv.search_by_full_text("apple banana", top_k=10) + found_ids = {h.metadata["doc_id"] for h in hits} + assert d_apple in found_ids + assert d_banana in found_ids + assert d_both in found_ids + + def test_deduplication(self, vv: ValkeyVector): + """A document matching multiple keywords must appear only once.""" + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "apple banana cherry", vv._group_id)], [_embedding()]) + + hits = vv.search_by_full_text("apple banana", top_k=10) + doc_ids = [h.metadata["doc_id"] for h in hits] + assert doc_ids.count(doc_id) == 1 + + def test_respects_top_k(self, vv: ValkeyVector): + ids = [str(uuid.uuid4()) for _ in range(5)] + docs = [_doc(did, f"common keyword text-{i}", vv._group_id) for i, did in enumerate(ids)] + vv.create(docs, [_embedding(float(i)) for i in range(5)]) + + hits = vv.search_by_full_text("common", top_k=2) + assert len(hits) <= 2 + + def test_metadata_preserved(self, vv: ValkeyVector): + doc_id = str(uuid.uuid4()) + vv.create([_doc(doc_id, "metadata preservation test", vv._group_id)], [_embedding()]) + + hits = vv.search_by_full_text("preservation", top_k=1) + assert len(hits) == 1 + assert hits[0].metadata["doc_id"] == doc_id + assert hits[0].page_content == "metadata preservation test" + + +# --------------------------------------------------------------------------- +# get_ids_by_metadata_field (not implemented — should raise) +# --------------------------------------------------------------------------- + + +class TestGetIdsByMetadataField: + def test_raises_not_implemented(self, vv: ValkeyVector): + with pytest.raises(NotImplementedError): + vv.get_ids_by_metadata_field("key", "value") diff --git a/api/providers/vdb/vdb-valkey/tests/unit_tests/test_valkey_vector.py b/api/providers/vdb/vdb-valkey/tests/unit_tests/test_valkey_vector.py new file mode 100644 index 00000000000000..6abb6a7213811c --- /dev/null +++ b/api/providers/vdb/vdb-valkey/tests/unit_tests/test_valkey_vector.py @@ -0,0 +1,241 @@ +"""Unit tests for the Valkey vector store backend. + +Only pure functions are tested here — serialisation, parsing, escaping, +distance conversion, and config. No mocks, no monkeypatching. +""" + +from __future__ import annotations + +import json +import struct + +import pytest + +# Pure helpers are importable without the glide C extension because they +# have no dependency on glide at module level. +from dify_vdb_valkey.valkey_vector import ( + ValkeyVectorConfig, + _bytes_to_float_vector, + _distance_to_similarity, + _escape_tag, + _escape_text, + _float_vector_to_bytes, + _parse_dict_keys, + _parse_full_text_results, + _parse_vector_search_results, + _to_str, +) + +# =================================================================== +# Float vector serialisation +# =================================================================== + + +class TestFloatVectorSerialization: + def test_roundtrip_single(self): + assert _bytes_to_float_vector(_float_vector_to_bytes([1.0])) == pytest.approx([1.0]) + + def test_roundtrip_multiple(self): + orig = [0.0, 1.5, -3.14, 100.0, 0.001] + assert _bytes_to_float_vector(_float_vector_to_bytes(orig)) == pytest.approx(orig, rel=1e-5) + + def test_empty(self): + assert _float_vector_to_bytes([]) == b"" + assert _bytes_to_float_vector(b"") == [] + + def test_little_endian_float32(self): + assert _float_vector_to_bytes([1.0]) == struct.pack("