diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py
index acd025d553..ac753d26d0 100644
--- a/lib/crewai/src/crewai/memory/encoding_flow.py
+++ b/lib/crewai/src/crewai/memory/encoding_flow.py
@@ -18,7 +18,7 @@
 from typing import Any
 from uuid import uuid4
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 
 from crewai.flow.flow import Flow, listen, start
 from crewai.memory.analyze import (
@@ -68,6 +68,31 @@ class ItemState(BaseModel):
     plan: ConsolidationPlan | None = None
     result_record: MemoryRecord | None = None
 
+    @field_validator("similar_records", "result_record", mode="before")
+    @classmethod
+    def ensure_embedding_is_list(cls, v: Any) -> Any:
+        """Decode raw float32 byte embeddings on incoming ``MemoryRecord`` objects.
+
+        Accepts a single record, a list of records, or ``None``. Records whose
+        ``embedding`` is ``bytes``/``bytearray`` are updated in place to hold
+        ``list[float]`` (``None`` for an empty buffer, matching
+        ``MemoryRecord.validate_embedding``). Anything else is left untouched
+        and ``v`` itself is returned.
+        """
+        candidates = v if isinstance(v, list) else [v]
+        for record in candidates:
+            if not isinstance(record, MemoryRecord):
+                continue
+            raw = record.embedding
+            if isinstance(raw, (bytes, bytearray)):
+                # Imported lazily so numpy is only loaded when a raw buffer shows up.
+                import numpy as np
+
+                record.embedding = (
+                    np.frombuffer(raw, dtype=np.float32).tolist() if raw else None
+                )
+        return v
+
 
 class EncodingState(BaseModel):
     """Batch-level state for the encoding flow."""
diff --git a/lib/crewai/src/crewai/memory/types.py b/lib/crewai/src/crewai/memory/types.py
index e787b569d0..fc37027cfd 100644
--- a/lib/crewai/src/crewai/memory/types.py
+++ b/lib/crewai/src/crewai/memory/types.py
@@ -2,13 +2,17 @@
 
 from __future__ import annotations
 
+import concurrent.futures
 from datetime import datetime
+import logging
 from typing import Any
 from uuid import uuid4
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 
 
+_logger = logging.getLogger(__name__)
+
 # When searching the vector store, we ask for more results than the caller
 # requested so that post-search steps (composite scoring, deduplication,
 # category filtering) have enough candidates to fill the final result set.
@@ -57,6 +61,27 @@ class MemoryRecord(BaseModel):
         repr=False,
         description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
     )
+
+    @field_validator("embedding", mode="before")
+    @classmethod
+    def validate_embedding(cls, v: Any) -> list[float] | None:
+        """Normalize an incoming embedding to ``list[float]`` or ``None``.
+
+        Raw buffers (``bytes``, ``bytearray``, ``memoryview``) are decoded as
+        native-endian float32 values; an empty buffer becomes ``None``. Any
+        other iterable is converted element by element with ``float``.
+        """
+        if v is None:
+            return None
+        if isinstance(v, (bytes, bytearray, memoryview)):
+            if len(v) == 0:
+                return None
+            # Imported lazily so numpy is only loaded when a raw buffer shows up.
+            import numpy as np
+
+            return np.frombuffer(v, dtype=np.float32).tolist()
+        return [float(x) for x in v]
+
     source: str | None = Field(
         default=None,
         description=(
@@ -304,7 +329,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
     """
     if not text or not text.strip():
         return []
+
+    # Just call the embedder directly - the blocking issue needs to be fixed
+    # at a higher level (making Memory.recall() async)
     result = embedder([text])
+
     if not result:
         return []
     first = result[0]
@@ -315,6 +344,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
     return list(first)
 
 
+# Reusable thread pool for running embedder calls from sync context
+# when an async event loop is already running.
+_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+
 def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
     """Embed multiple texts in a single API call.
 
@@ -328,6 +362,8 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
     Returns:
         List of embeddings, one per input text. Empty texts produce empty lists.
     """
+    import asyncio
+
     if not texts:
         return []
     # Filter out empty texts, remembering their positions
@@ -337,7 +373,36 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
     if not valid:
         return [[] for _ in texts]
 
-    result = embedder([t for _, t in valid])
+    batch = [t for _, t in valid]
+
+    # Detect a running event loop on its own, *before* invoking the embedder:
+    # get_running_loop() signals "no loop" with RuntimeError, and wrapping the
+    # embedder call in the same try block would mistake an embedder
+    # RuntimeError for that signal and invoke the embedder a second time.
+    try:
+        asyncio.get_running_loop()
+    except RuntimeError:
+        in_event_loop = False
+    else:
+        in_event_loop = True
+
+    result: Any
+    if in_event_loop:
+        # Sync call made from inside a running loop: run the embedder on the
+        # worker pool so the wait is bounded by a timeout. Waiting on the
+        # future still blocks the calling thread until it resolves.
+        future = _EMBED_POOL.submit(embedder, batch)
+        try:
+            result = future.result(timeout=30)
+        except concurrent.futures.TimeoutError:
+            # Drop the job if it is still queued; a call that has already
+            # started keeps running and cannot be interrupted.
+            future.cancel()
+            _logger.warning("Embedder timed out after 30s, returning empty embeddings")
+            return [[] for _ in texts]
+    else:
+        result = embedder(batch)
+
     embeddings: list[list[float]] = [[] for _ in texts]
     for (orig_idx, _), emb in zip(valid, result, strict=False):
         if hasattr(emb, "tolist"):
diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py
index d879bace0c..3b9086916e 100644
--- a/lib/crewai/src/crewai/memory/unified_memory.py
+++ b/lib/crewai/src/crewai/memory/unified_memory.py
@@ -5,6 +5,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor
 import contextvars
 from datetime import datetime
+import logging
 import threading
 import time
 from typing import TYPE_CHECKING, Annotated, Any, Literal
@@ -36,6 +37,9 @@
 from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
 
 
+_logger = logging.getLogger(__name__)
+
+
 if TYPE_CHECKING:
     from chromadb.utils.embedding_functions.openai_embedding_function import (
         OpenAIEmbeddingFunction,
@@ -316,16 +320,69 @@ def _on_save_done(self, future: Future[Any]) -> None:
             except Exception:  # noqa: S110
                 pass  # swallow everything during shutdown
 
-    def drain_writes(self) -> None:
+    def drain_writes(self, timeout_per_save: float = 60.0) -> None:
         """Block until all pending background saves have completed.
 
         Called automatically by ``recall()`` and should be called by the crew
         at shutdown to ensure no saves are lost.
+
+        Failures are logged rather than raised, and the timeout applies to each
+        save in turn, so the total wait can reach
+        ``len(pending) * timeout_per_save`` seconds.
+
+        Args:
+            timeout_per_save: Maximum seconds to wait per save operation.
+                Default 60s. If a save times out, logs warning
+                but continues to avoid blocking crew completion.
         """
+        # concurrent.futures.TimeoutError only became an alias of the builtin
+        # TimeoutError in Python 3.11; import the futures class so a timed-out
+        # save is recognised as a timeout on Python 3.10 as well.
+        from concurrent.futures import TimeoutError as FutureTimeoutError
+
         with self._pending_lock:
             pending = list(self._pending_saves)
-        for future in pending:
-            future.result()  # blocks until done; re-raises exceptions
+
+        if pending:
+            _logger.debug(
+                "[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
+            )
+
+        failed_saves = 0
+        for i, future in enumerate(pending):
+            try:
+                _logger.debug(
+                    "[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
+                )
+                future.result(timeout=timeout_per_save)
+                _logger.debug(
+                    "[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
+                )
+            except FutureTimeoutError:  # noqa: PERF203
+                failed_saves += 1
+                _logger.warning(
+                    "[DRAIN_WRITES] Save %d/%d timed out after %ss. "
+                    "This save will be abandoned. Consider increasing timeout or checking "
+                    "LLM/embedder performance.",
+                    i + 1,
+                    len(pending),
+                    timeout_per_save,
+                )
+                # Don't raise - just log and continue to avoid blocking crew completion
+            except Exception as e:
+                failed_saves += 1
+                _logger.error(
+                    "[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
+                )
+                # Don't raise - just log and continue
+
+        if failed_saves > 0:
+            _logger.warning(
+                "[DRAIN_WRITES] %d/%d saves failed or timed out. "
+                "Some memories may not have been persisted.",
+                failed_saves,
+                len(pending),
+            )
 
     def close(self) -> None:
         """Drain pending saves, flush storage, and shut down the background thread pool."""
diff --git a/lib/crewai/tests/memory/test_embedding_safety.py b/lib/crewai/tests/memory/test_embedding_safety.py
new file mode 100644
index 0000000000..ac5288e40c
--- /dev/null
+++ b/lib/crewai/tests/memory/test_embedding_safety.py
@@ -0,0 +1,131 @@
+"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""
+
+from __future__ import annotations
+
+import asyncio
+import concurrent.futures
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+
+from crewai.memory.types import MemoryRecord, embed_text, embed_texts
+
+
+class TestMemoryRecordEmbeddingValidator:
+    """Tests for MemoryRecord.validate_embedding (bytes→list[float])."""
+
+    def test_none_embedding_stays_none(self) -> None:
+        r = MemoryRecord(content="test", embedding=None)
+        assert r.embedding is None
+
+    def test_list_of_floats_passes_through(self) -> None:
+        r = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
+        assert r.embedding == [0.1, 0.2, 0.3]
+
+    def test_bytes_converted_to_list_float(self) -> None:
+        arr = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+        raw_bytes = arr.tobytes()
+        r = MemoryRecord(content="test", embedding=raw_bytes)
+        assert r.embedding is not None
+        assert len(r.embedding) == 3
+        assert all(isinstance(x, float) for x in r.embedding)
+        np.testing.assert_allclose(r.embedding, [0.1, 0.2, 0.3], atol=1e-6)
+
+    def test_empty_bytes_becomes_none(self) -> None:
+        r = MemoryRecord(content="test", embedding=b"")
+        assert r.embedding is None
+
+    def test_bytearray_converted_to_list_float(self) -> None:
+        raw = bytearray(np.array([0.5, 0.25], dtype=np.float32).tobytes())
+        r = MemoryRecord(content="test", embedding=raw)
+        assert r.embedding == [0.5, 0.25]
+
+    def test_list_of_ints_converted_to_floats(self) -> None:
+        r = MemoryRecord(content="test", embedding=[1, 2, 3])
+        assert r.embedding == [1.0, 2.0, 3.0]
+        assert all(isinstance(x, float) for x in r.embedding)
+
+    def test_numpy_array_converted_to_list(self) -> None:
+        arr = np.array([0.5, 0.6], dtype=np.float32)
+        r = MemoryRecord(content="test", embedding=arr)
+        assert r.embedding is not None
+        assert isinstance(r.embedding, list)
+        assert len(r.embedding) == 2
+
+
+class TestEmbedTextsAsyncSafety:
+    """Tests for embed_texts running safely in async context."""
+
+    def test_embed_texts_sync_context(self) -> None:
+        """embed_texts works in a normal sync context."""
+        embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
+        result = embed_texts(embedder, ["hello", "world"])
+        assert len(result) == 2
+        assert result[0] == [0.1, 0.2]
+        embedder.assert_called_once()
+
+    def test_embed_texts_empty_input(self) -> None:
+        embedder = MagicMock()
+        assert embed_texts(embedder, []) == []
+        embedder.assert_not_called()
+
+    def test_embed_texts_all_empty_strings(self) -> None:
+        embedder = MagicMock()
+        result = embed_texts(embedder, ["", " ", ""])
+        assert result == [[], [], []]
+        embedder.assert_not_called()
+
+    def test_embed_texts_skips_empty_preserves_positions(self) -> None:
+        embedder = MagicMock(return_value=[[0.1, 0.2]])
+        result = embed_texts(embedder, ["", "hello", ""])
+        assert result == [[], [0.1, 0.2], []]
+        embedder.assert_called_once_with(["hello"])
+
+    def test_embed_texts_in_async_context(self) -> None:
+        """embed_texts uses thread pool when called from async context."""
+        embedder = MagicMock(return_value=[[0.1, 0.2]])
+
+        async def run() -> list[list[float]]:
+            return embed_texts(embedder, ["hello"])
+
+        result = asyncio.run(run())
+        assert result == [[0.1, 0.2]]
+        embedder.assert_called_once()
+
+    def test_embedder_runtime_error_propagates_in_async_context(self) -> None:
+        """A RuntimeError raised by the embedder surfaces without a second call."""
+        embedder = MagicMock(side_effect=RuntimeError("boom"))
+
+        async def run() -> list[list[float]]:
+            return embed_texts(embedder, ["hello"])
+
+        with pytest.raises(RuntimeError, match="boom"):
+            asyncio.run(run())
+        embedder.assert_called_once()
+
+
+class TestEmbedText:
+    """Tests for embed_text (single text)."""
+
+    def test_empty_string_returns_empty(self) -> None:
+        embedder = MagicMock()
+        assert embed_text(embedder, "") == []
+        embedder.assert_not_called()
+
+    def test_whitespace_only_returns_empty(self) -> None:
+        embedder = MagicMock()
+        assert embed_text(embedder, " ") == []
+        embedder.assert_not_called()
+
+    def test_normal_text_returns_embedding(self) -> None:
+        embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
+        result = embed_text(embedder, "hello")
+        assert result == [0.1, 0.2, 0.3]
+
+    def test_numpy_array_result_converted(self) -> None:
+        arr = np.array([0.1, 0.2], dtype=np.float32)
+        embedder = MagicMock(return_value=[arr])
+        result = embed_text(embedder, "hello")
+        assert isinstance(result, list)
+        assert len(result) == 2
diff --git a/uv.lock b/uv.lock
index 5101cea490..7a0e852087 100644
--- a/uv.lock
+++ b/uv.lock
@@ -13,7 +13,7 @@ resolution-markers = [
 ]
 
 [options]
-exclude-newer = "2026-04-27T16:00:00Z"
+exclude-newer = "2026-04-28T04:00:00Z"
 
 [manifest]
 members = [