diff --git a/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py b/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py index 9ac1b66e82..721268e7fe 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py +++ b/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py @@ -51,6 +51,7 @@ class EmbeddingService: - roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.) - voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.) - watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.) + - fastembed: FastEmbed embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.) - custom: Custom embeddings (embedding_callable, etc.) - sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.) - text2vec: Text2Vec embeddings (text2vec-base-en, etc.) @@ -230,6 +231,11 @@ def _build_provider_config(self) -> dict[str, Any]: "model_name": self.config.model, **self.config.extra_config, } + elif self.config.provider == "fastembed": + base_config["config"] = { + "model_name": self.config.model, + **self.config.extra_config, + } elif self.config.provider == "custom": # Custom provider requires embedding_callable in extra_config base_config["config"] = { @@ -365,6 +371,7 @@ def list_supported_providers(cls) -> list[str]: "amazon-bedrock", "cohere", "custom", + "fastembed", "google-generativeai", "google-vertex", "huggingface", @@ -497,6 +504,15 @@ def create_watsonx_service( """Create a Watson X embedding service.""" return cls(provider="watsonx", model=model, api_key=api_key, **kwargs) + @classmethod + def create_fastembed_service( + cls, + model: str = "sentence-transformers/all-MiniLM-L6-v2", + **kwargs: Any, + ) -> EmbeddingService: + """Create a FastEmbed embedding service (local, fast inference).""" + return cls(provider="fastembed", model=model, **kwargs) + @classmethod def create_custom_service( cls, diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py index 8027793200..63eac660db 100644 --- a/lib/crewai/src/crewai/rag/embeddings/factory.py +++ b/lib/crewai/src/crewai/rag/embeddings/factory.py @@ -51,6 +51,10 @@ from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec + from crewai.rag.embeddings.providers.fastembed.embedding_callable import ( + FastEmbedEmbeddingFunction, + ) + from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderSpec from crewai.rag.embeddings.providers.google.genai_vertex_embedding import ( GoogleGenAIVertexEmbeddingFunction, ) @@ -92,6 +96,7 @@ "amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider", "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider", "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider", + "fastembed": "crewai.rag.embeddings.providers.fastembed.fastembed_provider.FastEmbedProvider", "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider", "google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider", "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider", @@ -142,6 +147,12 @@ def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunctio def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction[Any]: ... +@overload +def build_embedder_from_dict( + spec: FastEmbedProviderSpec, +) -> FastEmbedEmbeddingFunction: ... + + @overload def build_embedder_from_dict( spec: GenerativeAiProviderSpec, @@ -283,6 +294,10 @@ def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ... def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction[Any]: ... +@overload +def build_embedder(spec: FastEmbedProviderSpec) -> FastEmbedEmbeddingFunction: ... + + @overload def build_embedder( spec: GenerativeAiProviderSpec, diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/__init__.py new file mode 100644 index 0000000000..4aa1e66e5f --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/__init__.py @@ -0,0 +1,20 @@ +"""FastEmbed embedding providers.""" + +from crewai.rag.embeddings.providers.fastembed.embedding_callable import ( + FastEmbedEmbeddingFunction, +) +from crewai.rag.embeddings.providers.fastembed.fastembed_provider import ( + FastEmbedProvider, +) +from crewai.rag.embeddings.providers.fastembed.types import ( + FastEmbedProviderConfig, + FastEmbedProviderSpec, +) + + +__all__ = [ + "FastEmbedEmbeddingFunction", + "FastEmbedProvider", + "FastEmbedProviderConfig", + "FastEmbedProviderSpec", +] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/embedding_callable.py new file mode 100644 index 0000000000..61ca9cd7a6 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/embedding_callable.py @@ -0,0 +1,69 @@ +"""FastEmbed embedding function implementation.""" + +from typing import Any, cast + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from typing_extensions import Unpack + +from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderConfig + + +class FastEmbedEmbeddingFunction(EmbeddingFunction[Documents]): + """Embedding function for FastEmbed text embedding models.""" + + def __init__(self, **kwargs: Unpack[FastEmbedProviderConfig]) -> None: + """Initialize FastEmbed embedding function. + + Args: + **kwargs: Configuration parameters for FastEmbed. + """ + try: + from fastembed import TextEmbedding + except ImportError as e: + raise ImportError( + "fastembed is required for fastembed embeddings. " + "Install it with: uv add fastembed" + ) from e + + model_kwargs: dict[str, Any] = { + "model_name": kwargs.get( + "model_name", "sentence-transformers/all-MiniLM-L6-v2" + ) + } + for key in ( + "cache_dir", + "threads", + "providers", + "cuda", + "device_ids", + "lazy_load", + ): + if key in kwargs and kwargs[key] is not None: + model_kwargs[key] = kwargs[key] + + self._model = TextEmbedding(**model_kwargs) + self._batch_size = kwargs.get("batch_size", 256) + self._parallel = kwargs.get("parallel") + + @staticmethod + def name() -> str: + """Return the name of the embedding function for ChromaDB compatibility.""" + return "fastembed" + + def __call__(self, input: Documents) -> Embeddings: + """Generate embeddings for input documents. + + Args: + input: List of documents to embed. + + Returns: + List of embedding vectors. + """ + if isinstance(input, str): + input = [input] + + embed_kwargs: dict[str, Any] = {"batch_size": self._batch_size} + if self._parallel is not None: + embed_kwargs["parallel"] = self._parallel + + return cast(Embeddings, list(self._model.embed(input, **embed_kwargs))) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/fastembed_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/fastembed_provider.py new file mode 100644 index 0000000000..28726aeb52 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/fastembed_provider.py @@ -0,0 +1,78 @@ +"""FastEmbed embeddings provider.""" + +from pydantic import AliasChoices, Field + +from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider +from crewai.rag.embeddings.providers.fastembed.embedding_callable import ( + FastEmbedEmbeddingFunction, +) + + +class FastEmbedProvider(BaseEmbeddingsProvider[FastEmbedEmbeddingFunction]): + """FastEmbed embeddings provider.""" + + embedding_callable: type[FastEmbedEmbeddingFunction] = Field( + default=FastEmbedEmbeddingFunction, + description="FastEmbed embedding function class", + ) + model_name: str = Field( + default="sentence-transformers/all-MiniLM-L6-v2", + description="Model name to use", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_MODEL_NAME", + "FASTEMBED_MODEL_NAME", + "model", + ), + ) + cache_dir: str | None = Field( + default=None, + description="Directory to cache downloaded FastEmbed models", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_CACHE_DIR", "FASTEMBED_CACHE_DIR" + ), + ) + threads: int | None = Field( + default=None, + description="Number of threads to use for inference", + validation_alias=AliasChoices("EMBEDDINGS_FASTEMBED_THREADS", "FASTEMBED_THREADS"), + ) + providers: list[str] | None = Field( + default=None, + description="ONNX Runtime execution providers", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_PROVIDERS", "FASTEMBED_PROVIDERS" + ), + ) + cuda: bool = Field( + default=False, + description="Whether to use CUDA execution", + validation_alias=AliasChoices("EMBEDDINGS_FASTEMBED_CUDA", "FASTEMBED_CUDA"), + ) + device_ids: list[int] | None = Field( + default=None, + description="CUDA device IDs to use", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_DEVICE_IDS", "FASTEMBED_DEVICE_IDS" + ), + ) + lazy_load: bool = Field( + default=False, + description="Whether to defer model loading until first embedding call", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_LAZY_LOAD", "FASTEMBED_LAZY_LOAD" + ), + ) + batch_size: int = Field( + default=256, + description="Batch size to use when embedding documents", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_BATCH_SIZE", "FASTEMBED_BATCH_SIZE" + ), + ) + parallel: int | None = Field( + default=None, + description="Number of parallel workers to use when embedding documents", + validation_alias=AliasChoices( + "EMBEDDINGS_FASTEMBED_PARALLEL", "FASTEMBED_PARALLEL" + ), + ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/types.py new file mode 100644 index 0000000000..be4fc9410e --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/fastembed/types.py @@ -0,0 +1,26 @@ +"""Type definitions for FastEmbed embedding providers.""" + +from typing import Annotated, Literal + +from typing_extensions import Required, TypedDict + + +class FastEmbedProviderConfig(TypedDict, total=False): + """Configuration for FastEmbed provider.""" + + model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"] + cache_dir: str | None + threads: int | None + providers: list[str] | None + cuda: Annotated[bool, False] + device_ids: list[int] | None + lazy_load: Annotated[bool, False] + batch_size: Annotated[int, 256] + parallel: int | None + + +class FastEmbedProviderSpec(TypedDict, total=False): + """FastEmbed provider specification.""" + + provider: Required[Literal["fastembed"]] + config: FastEmbedProviderConfig diff --git a/lib/crewai/src/crewai/rag/embeddings/types.py b/lib/crewai/src/crewai/rag/embeddings/types.py index 794f4c6f9a..26ba0ac46f 100644 --- a/lib/crewai/src/crewai/rag/embeddings/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/types.py @@ -6,6 +6,7 @@ from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec +from crewai.rag.embeddings.providers.fastembed.types import FastEmbedProviderSpec from crewai.rag.embeddings.providers.google.types import ( GenerativeAiProviderSpec, VertexAIProviderSpec, @@ -34,6 +35,7 @@ | BedrockProviderSpec | CohereProviderSpec | CustomProviderSpec + | FastEmbedProviderSpec | GenerativeAiProviderSpec | HuggingFaceProviderSpec | InstructorProviderSpec @@ -55,6 +57,7 @@ "amazon-bedrock", "cohere", "custom", + "fastembed", "google-generativeai", "google-vertex", "huggingface", diff --git a/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py b/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py index d10a75cdec..ac33b73ec8 100644 --- a/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py +++ b/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py @@ -9,6 +9,7 @@ from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider +from crewai.rag.embeddings.providers.fastembed.fastembed_provider import FastEmbedProvider from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import ( SentenceTransformerProvider, @@ -101,6 +102,13 @@ def test_text2vec_provider_accepts_model_key(self): ) assert provider.model_name == "shibing624/text2vec-base-multilingual" + def test_fastembed_provider_accepts_model_key(self): + """Test FastEmbed provider accepts 'model' as alias for 'model_name'.""" + provider = FastEmbedProvider( + model="sentence-transformers/all-MiniLM-L6-v2", + ) + assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2" + def test_sentence_transformer_provider_accepts_model_key(self): """Test SentenceTransformer provider accepts 'model' as alias.""" provider = SentenceTransformerProvider( @@ -361,4 +369,4 @@ def test_legacy_azure_with_model_key(self): deployment_id="test-deployment", model="text-embedding-3-large", ) - assert provider.model_name == "text-embedding-3-large" \ No newline at end of file + assert provider.model_name == "text-embedding-3-large" diff --git a/lib/crewai/tests/rag/embeddings/test_embedding_factory.py b/lib/crewai/tests/rag/embeddings/test_embedding_factory.py index 7e553d0a75..858e2f052b 100644 --- a/lib/crewai/tests/rag/embeddings/test_embedding_factory.py +++ b/lib/crewai/tests/rag/embeddings/test_embedding_factory.py @@ -154,6 +154,35 @@ def test_build_embedder_cohere(self, mock_import): "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider" ) + @patch("crewai.rag.embeddings.factory.import_and_validate_definition") + def test_build_embedder_fastembed(self, mock_import): + """Test building FastEmbed embedder.""" + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_embedding_function = MagicMock() + + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + mock_provider_instance.embedding_callable.return_value = mock_embedding_function + + config = { + "provider": "fastembed", + "config": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "cache_dir": ".fastembed_cache", + }, + } + + build_embedder(config) + + mock_import.assert_called_once_with( + "crewai.rag.embeddings.providers.fastembed.fastembed_provider.FastEmbedProvider" + ) + + call_kwargs = mock_provider_class.call_args.kwargs + assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2" + assert call_kwargs["cache_dir"] == ".fastembed_cache" + @patch("crewai.rag.embeddings.factory.import_and_validate_definition") def test_build_embedder_voyageai(self, mock_import): """Test building VoyageAI embedder."""