Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions api/configs/middleware/vdb/milvus_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings):
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

MILVUS_SECURE: bool = Field(
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
default=False,
)

MILVUS_SERVER_PEM_PATH: str | None = Field(
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
default=None,
)

MILVUS_SERVER_NAME: str | None = Field(
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
"Required when MILVUS_SERVER_PEM_PATH is set.",
default=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class MilvusConfig(BaseModel):
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: str | None = None # Analyzer params
secure: bool = False # Enable one-way TLS to Milvus
server_pem_path: str | None = None # Path to server certificate (PEM) for TLS verification
server_name: str | None = None # Server name to verify against the certificate (SNI / CN)

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -388,16 +391,19 @@ def _init_client(self, config: MilvusConfig) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
kwargs: dict[str, Any] = {"uri": config.uri, "db_name": config.database}
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
kwargs["token"] = config.token
else:
client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client
kwargs["user"] = config.user or ""
kwargs["password"] = config.password or ""
if config.secure:
kwargs["secure"] = True
if config.server_pem_path:
kwargs["server_pem_path"] = config.server_pem_path
if config.server_name:
kwargs["server_name"] = config.server_name
return MilvusClient(**kwargs)


class MilvusVectorFactory(AbstractVectorFactory):
Expand Down Expand Up @@ -427,5 +433,8 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
secure=dify_config.MILVUS_SECURE,
server_pem_path=dify_config.MILVUS_SERVER_PEM_PATH,
server_name=dify_config.MILVUS_SERVER_NAME,
),
)
29 changes: 29 additions & 0 deletions api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ def test_init_client_supports_token_and_user_password(milvus_module):
assert user_client.init_kwargs["password"] == "Milvus"


def test_init_client_passes_tls_kwargs_when_secure(milvus_module):
vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
client = vector._init_client(
milvus_module.MilvusConfig.model_validate(
{
"uri": "https://milvus.example.com:19530",
"token": "abc",
"database": "db",
"secure": True,
"server_pem_path": "/etc/milvus/certs/server.pem",
"server_name": "milvus.example.com",
}
)
)
assert client.init_kwargs["secure"] is True
assert client.init_kwargs["server_pem_path"] == "/etc/milvus/certs/server.pem"
assert client.init_kwargs["server_name"] == "milvus.example.com"


def test_init_client_omits_tls_kwargs_when_not_secure(milvus_module):
vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
client = vector._init_client(
milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"})
)
assert "secure" not in client.init_kwargs
assert "server_pem_path" not in client.init_kwargs
assert "server_name" not in client.init_kwargs


def test_init_loads_fields_when_collection_exists(milvus_module):
client = milvus_module.MilvusClient(uri="http://localhost:19530")
client.has_collection.return_value = True
Expand Down
Loading