Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class EmbeddingService:
- roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
- watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
- fastembed: FastEmbed embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
- custom: Custom embeddings (embedding_callable, etc.)
- sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
- text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
Expand Down Expand Up @@ -230,6 +231,11 @@ def _build_provider_config(self) -> dict[str, Any]:
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "fastembed":
base_config["config"] = {
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "custom":
# Custom provider requires embedding_callable in extra_config
base_config["config"] = {
Expand Down Expand Up @@ -365,6 +371,7 @@ def list_supported_providers(cls) -> list[str]:
"amazon-bedrock",
"cohere",
"custom",
"fastembed",
"google-generativeai",
"google-vertex",
"huggingface",
Expand Down Expand Up @@ -497,6 +504,15 @@ def create_watsonx_service(
"""Create a Watson X embedding service."""
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)

@classmethod
def create_fastembed_service(
cls,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
**kwargs: Any,
) -> EmbeddingService:
"""Create a FastEmbed embedding service (local, fast inference)."""
return cls(provider="fastembed", model=model, **kwargs)

@classmethod
def create_custom_service(
cls,
Expand Down
15 changes: 15 additions & 0 deletions lib/crewai/src/crewai/rag/embeddings/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.fastembed.embedding_callable import (
FastEmbedEmbeddingFunction,
)
from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderSpec
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
Expand Down Expand Up @@ -92,6 +96,7 @@
"amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"fastembed": "crewai.rag.embeddings.providers.fastembed.fastembed_provider.FastEmbedProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
Expand Down Expand Up @@ -142,6 +147,12 @@ def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunctio
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction[Any]: ...


@overload
def build_embedder_from_dict(
spec: FastEmbedProviderSpec,
) -> FastEmbedEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
spec: GenerativeAiProviderSpec,
Expand Down Expand Up @@ -283,6 +294,10 @@ def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction[Any]: ...


@overload
def build_embedder(spec: FastEmbedProviderSpec) -> FastEmbedEmbeddingFunction: ...


@overload
def build_embedder(
spec: GenerativeAiProviderSpec,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""FastEmbed embedding providers."""

from crewai.rag.embeddings.providers.fastembed.embedding_callable import (
FastEmbedEmbeddingFunction,
)
from crewai.rag.embeddings.providers.fastembed.fastembed_provider import (
FastEmbedProvider,
)
from crewai.rag.embeddings.providers.fastembed.types import (
FastEmbedProviderConfig,
FastEmbedProviderSpec,
)


__all__ = [
"FastEmbedEmbeddingFunction",
"FastEmbedProvider",
"FastEmbedProviderConfig",
"FastEmbedProviderSpec",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""FastEmbed embedding function implementation."""

from typing import Any, cast

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderConfig


class FastEmbedEmbeddingFunction(EmbeddingFunction[Documents]):
    """ChromaDB-compatible embedding function backed by the FastEmbed library."""

    def __init__(self, **kwargs: Unpack[FastEmbedProviderConfig]) -> None:
        """Initialize FastEmbed embedding function.

        Args:
            **kwargs: Configuration parameters for FastEmbed.

        Raises:
            ImportError: If the optional ``fastembed`` package is not installed.
        """
        try:
            from fastembed import TextEmbedding
        except ImportError as e:
            raise ImportError(
                "fastembed is required for fastembed embeddings. "
                "Install it with: uv add fastembed"
            ) from e

        # Start from the (possibly defaulted) model name, then forward only
        # the optional constructor arguments the caller explicitly provided.
        init_args: dict[str, Any] = {
            "model_name": kwargs.get(
                "model_name", "sentence-transformers/all-MiniLM-L6-v2"
            )
        }
        optional_keys = (
            "cache_dir",
            "threads",
            "providers",
            "cuda",
            "device_ids",
            "lazy_load",
        )
        init_args.update(
            {k: kwargs[k] for k in optional_keys if k in kwargs and kwargs[k] is not None}
        )

        self._model = TextEmbedding(**init_args)
        # Embedding-time options are kept on the instance, not passed to the model.
        self._batch_size = kwargs.get("batch_size", 256)
        self._parallel = kwargs.get("parallel")

    @staticmethod
    def name() -> str:
        """Return the name of the embedding function for ChromaDB compatibility."""
        return "fastembed"

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed (a bare string is wrapped in a list).

        Returns:
            List of embedding vectors.
        """
        documents = [input] if isinstance(input, str) else input

        embed_options: dict[str, Any] = {"batch_size": self._batch_size}
        if self._parallel is not None:
            embed_options["parallel"] = self._parallel

        # TextEmbedding.embed yields lazily; materialize before returning.
        vectors = self._model.embed(documents, **embed_options)
        return cast(Embeddings, list(vectors))
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""FastEmbed embeddings provider."""

from pydantic import AliasChoices, Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.fastembed.embedding_callable import (
FastEmbedEmbeddingFunction,
)


class FastEmbedProvider(BaseEmbeddingsProvider[FastEmbedEmbeddingFunction]):
    """FastEmbed embeddings provider.

    Declares the configuration surface for the local FastEmbed (ONNX) backend.
    Each field can also be populated from environment variables through its
    validation aliases (e.g. ``EMBEDDINGS_FASTEMBED_MODEL_NAME``).
    """

    # Concrete embedding-function class instantiated by the factory.
    embedding_callable: type[FastEmbedEmbeddingFunction] = Field(
        default=FastEmbedEmbeddingFunction,
        description="FastEmbed embedding function class",
    )
    # "model" alias preserves backward compatibility with older configs.
    model_name: str = Field(
        default="sentence-transformers/all-MiniLM-L6-v2",
        description="Model name to use",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_MODEL_NAME",
            "FASTEMBED_MODEL_NAME",
            "model",
        ),
    )
    cache_dir: str | None = Field(
        default=None,
        description="Directory to cache downloaded FastEmbed models",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_CACHE_DIR", "FASTEMBED_CACHE_DIR"
        ),
    )
    threads: int | None = Field(
        default=None,
        description="Number of threads to use for inference",
        validation_alias=AliasChoices("EMBEDDINGS_FASTEMBED_THREADS", "FASTEMBED_THREADS"),
    )
    # NOTE(review): these are ONNX Runtime execution providers (e.g.
    # "CUDAExecutionProvider"), distinct from crewai embedding providers.
    providers: list[str] | None = Field(
        default=None,
        description="ONNX Runtime execution providers",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_PROVIDERS", "FASTEMBED_PROVIDERS"
        ),
    )
    cuda: bool = Field(
        default=False,
        description="Whether to use CUDA execution",
        validation_alias=AliasChoices("EMBEDDINGS_FASTEMBED_CUDA", "FASTEMBED_CUDA"),
    )
    device_ids: list[int] | None = Field(
        default=None,
        description="CUDA device IDs to use",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_DEVICE_IDS", "FASTEMBED_DEVICE_IDS"
        ),
    )
    lazy_load: bool = Field(
        default=False,
        description="Whether to defer model loading until first embedding call",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_LAZY_LOAD", "FASTEMBED_LAZY_LOAD"
        ),
    )
    batch_size: int = Field(
        default=256,
        description="Batch size to use when embedding documents",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_BATCH_SIZE", "FASTEMBED_BATCH_SIZE"
        ),
    )
    parallel: int | None = Field(
        default=None,
        description="Number of parallel workers to use when embedding documents",
        validation_alias=AliasChoices(
            "EMBEDDINGS_FASTEMBED_PARALLEL", "FASTEMBED_PARALLEL"
        ),
    )
26 changes: 26 additions & 0 deletions lib/crewai/src/crewai/rag/embeddings/providers/fastembed/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Type definitions for FastEmbed embedding providers."""

from typing import Annotated, Literal

from typing_extensions import Required, TypedDict


class FastEmbedProviderConfig(TypedDict, total=False):
    """Configuration for FastEmbed provider.

    All keys are optional (``total=False``); the ``Annotated`` metadata records
    the default value applied by the embedding function when a key is absent.
    """

    model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
    cache_dir: str | None  # where downloaded model files are cached
    threads: int | None  # inference thread count
    providers: list[str] | None  # ONNX Runtime execution providers
    cuda: Annotated[bool, False]
    device_ids: list[int] | None  # CUDA device IDs when cuda is enabled
    lazy_load: Annotated[bool, False]  # defer model load to first call
    batch_size: Annotated[int, 256]  # documents per embed batch
    parallel: int | None  # parallel workers for embedding


class FastEmbedProviderSpec(TypedDict, total=False):
    """FastEmbed provider specification.

    The ``provider`` discriminator is required; ``config`` may be omitted,
    in which case the embedding function falls back to its defaults.
    """

    provider: Required[Literal["fastembed"]]
    config: FastEmbedProviderConfig
3 changes: 3 additions & 0 deletions lib/crewai/src/crewai/rag/embeddings/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
Expand Down Expand Up @@ -34,6 +35,7 @@
| BedrockProviderSpec
| CohereProviderSpec
| CustomProviderSpec
| FastEmbedProviderSpec
| GenerativeAiProviderSpec
| HuggingFaceProviderSpec
| InstructorProviderSpec
Expand All @@ -55,6 +57,7 @@
"amazon-bedrock",
"cohere",
"custom",
"fastembed",
"google-generativeai",
"google-vertex",
"huggingface",
Expand Down
10 changes: 9 additions & 1 deletion lib/crewai/tests/rag/embeddings/test_backward_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.fastembed.fastembed_provider import FastEmbedProvider
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
Expand Down Expand Up @@ -101,6 +102,13 @@ def test_text2vec_provider_accepts_model_key(self):
)
assert provider.model_name == "shibing624/text2vec-base-multilingual"

def test_fastembed_provider_accepts_model_key(self):
"""Test FastEmbed provider accepts 'model' as alias for 'model_name'."""
provider = FastEmbedProvider(
model="sentence-transformers/all-MiniLM-L6-v2",
)
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"

def test_sentence_transformer_provider_accepts_model_key(self):
"""Test SentenceTransformer provider accepts 'model' as alias."""
provider = SentenceTransformerProvider(
Expand Down Expand Up @@ -361,4 +369,4 @@ def test_legacy_azure_with_model_key(self):
deployment_id="test-deployment",
model="text-embedding-3-large",
)
assert provider.model_name == "text-embedding-3-large"
assert provider.model_name == "text-embedding-3-large"
29 changes: 29 additions & 0 deletions lib/crewai/tests/rag/embeddings/test_embedding_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,35 @@ def test_build_embedder_cohere(self, mock_import):
"crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
)

@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_fastembed(self, mock_import):
"""Test building FastEmbed embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()

mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function

config = {
"provider": "fastembed",
"config": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"cache_dir": ".fastembed_cache",
},
}

build_embedder(config)

mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.fastembed.fastembed_provider.FastEmbedProvider"
)

call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
assert call_kwargs["cache_dir"] == ".fastembed_cache"

@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_voyageai(self, mock_import):
"""Test building VoyageAI embedder."""
Expand Down