Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion lib/crewai/src/crewai/memory/encoding_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from crewai.flow.flow import Flow, listen, start
from crewai.memory.analyze import (
Expand Down Expand Up @@ -68,6 +68,31 @@ class ItemState(BaseModel):
plan: ConsolidationPlan | None = None
result_record: MemoryRecord | None = None

@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
    """Coerce MemoryRecord embeddings from raw bytes to ``list[float]``.

    Storage backends can hand back embeddings as packed float32 bytes;
    downstream code expects ``list[float]``. Runs in ``mode="before"`` so
    both a single record (``result_record``) and a list of records
    (``similar_records``) are normalized in place before field validation.

    Args:
        v: ``None``, a ``MemoryRecord``, a list of ``MemoryRecord``s, or
            raw data that pydantic will validate afterwards.

    Returns:
        The same value, with any bytes embeddings decoded to list[float].
    """
    if v is None:
        return None

    def _decode_in_place(record: Any) -> None:
        # Only touch MemoryRecords whose embedding is still packed bytes;
        # anything else is left for normal field validation.
        if isinstance(record, MemoryRecord) and isinstance(record.embedding, bytes):
            # Local import keeps numpy optional at module import time.
            import numpy as np

            # assumes float32 little-endian packing as written by the store
            # -- TODO confirm against the storage layer.
            record.embedding = np.frombuffer(
                record.embedding, dtype=np.float32
            ).tolist()

    if isinstance(v, list):
        for record in v:
            _decode_in_place(record)
    else:
        _decode_in_place(v)
    return v


class EncodingState(BaseModel):
"""Batch-level state for the encoding flow."""
Expand Down
52 changes: 50 additions & 2 deletions lib/crewai/src/crewai/memory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from __future__ import annotations

import concurrent.futures
from datetime import datetime
import logging
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator


_logger = logging.getLogger(__name__)

# When searching the vector store, we ask for more results than the caller
# requested so that post-search steps (composite scoring, deduplication,
# category filtering) have enough candidates to fill the final result set.
Expand Down Expand Up @@ -57,6 +61,23 @@ class MemoryRecord(BaseModel):
repr=False,
description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
)

@field_validator("embedding", mode="before")
@classmethod
def validate_embedding(cls, v: Any) -> list[float] | None:
    """Normalize the embedding to ``list[float]`` or ``None``, never bytes."""
    if v is None:
        return None
    if not isinstance(v, bytes):
        # Already a sequence of numbers (list, tuple, numpy array, ...).
        return [float(component) for component in v]
    if not v:
        # A zero-length byte payload means "no embedding stored".
        return None
    # Packed float32 bytes from the vector store: decode lazily so numpy
    # stays an import-time-optional dependency.
    import numpy as np

    decoded = np.frombuffer(v, dtype=np.float32)
    return [float(component) for component in decoded]

source: str | None = Field(
default=None,
description=(
Expand Down Expand Up @@ -304,7 +325,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
"""
if not text or not text.strip():
return []

# Just call the embedder directly - the blocking issue needs to be fixed
# at a higher level (making Memory.recall() async)
result = embedder([text])

if not result:
return []
first = result[0]
Expand All @@ -315,6 +340,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
return list(first)


# Reusable thread pool for running embedder calls from sync context
# when an async event loop is already running.
_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=1)


def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
"""Embed multiple texts in a single API call.

Expand All @@ -328,6 +358,8 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
Returns:
List of embeddings, one per input text. Empty texts produce empty lists.
"""
import asyncio

if not texts:
return []
# Filter out empty texts, remembering their positions
Expand All @@ -337,7 +369,23 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
if not valid:
return [[] for _ in texts]

result = embedder([t for _, t in valid])
# Check if we're in an async context
result: Any
try:
asyncio.get_running_loop()
# We're in an async context, but this is a sync function
# Run embedder in thread pool to avoid blocking the event loop
try:
result = _EMBED_POOL.submit(embedder, [t for _, t in valid]).result(
timeout=30
)
except concurrent.futures.TimeoutError:
_logger.warning("Embedder timed out after 30s, returning empty embeddings")
return [[] for _ in texts]
except RuntimeError:
# Not in async context, run directly
result = embedder([t for _, t in valid])

embeddings: list[list[float]] = [[] for _ in texts]
for (orig_idx, _), emb in zip(valid, result, strict=False):
if hasattr(emb, "tolist"):
Expand Down
54 changes: 51 additions & 3 deletions lib/crewai/src/crewai/memory/unified_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import logging
import threading
import time
from typing import TYPE_CHECKING, Annotated, Any, Literal
Expand Down Expand Up @@ -36,6 +37,9 @@
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec


_logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
Expand Down Expand Up @@ -316,16 +320,60 @@ def _on_save_done(self, future: Future[Any]) -> None:
except Exception: # noqa: S110
pass # swallow everything during shutdown

def drain_writes(self) -> None:
def drain_writes(self, timeout_per_save: float = 60.0) -> None:
    """Block until all pending background saves have completed.

    Called automatically by ``recall()`` and should be called by the
    crew at shutdown to ensure no saves are lost. Failures and timeouts
    are logged but never raised, so shutdown cannot hang or crash on a
    misbehaving save.

    Args:
        timeout_per_save: Maximum seconds to wait per save operation.
            Default 60s. If a save times out, logs a warning but
            continues to avoid blocking crew completion.
    """
    # Before Python 3.11, Future.result(timeout=...) raises
    # concurrent.futures.TimeoutError, which is NOT the builtin
    # TimeoutError there (they were unified in 3.11). Catch both so a
    # timeout is classified as a timeout instead of a generic failure.
    from concurrent.futures import TimeoutError as _FuturesTimeout

    # Snapshot under the lock; saves scheduled after this point are not
    # waited for.
    with self._pending_lock:
        pending = list(self._pending_saves)

    if pending:
        _logger.debug(
            "[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
        )

    failed_saves = 0
    for i, future in enumerate(pending):
        try:
            _logger.debug(
                "[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
            )
            future.result(timeout=timeout_per_save)
            _logger.debug(
                "[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
            )
        except (TimeoutError, _FuturesTimeout):  # noqa: PERF203
            failed_saves += 1
            _logger.warning(
                "[DRAIN_WRITES] Save %d/%d timed out after %ss. "
                "This save will be abandoned. Consider increasing timeout or checking "
                "LLM/embedder performance.",
                i + 1,
                len(pending),
                timeout_per_save,
            )
            # Don't raise - just log and continue to avoid blocking crew completion
        except Exception as e:
            failed_saves += 1
            _logger.error(
                "[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
            )
            # Don't raise - just log and continue

    if failed_saves > 0:
        _logger.warning(
            "[DRAIN_WRITES] %d/%d saves failed or timed out. "
            "Some memories may not have been persisted.",
            failed_saves,
            len(pending),
        )

def close(self) -> None:
"""Drain pending saves, flush storage, and shut down the background thread pool."""
Expand Down
115 changes: 115 additions & 0 deletions lib/crewai/tests/memory/test_embedding_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""

from __future__ import annotations

import asyncio
import concurrent.futures
from unittest.mock import MagicMock

import numpy as np
import pytest

from crewai.memory.types import MemoryRecord, embed_text, embed_texts


class TestMemoryRecordEmbeddingValidator:
    """Tests for MemoryRecord.validate_embedding (bytes→list[float])."""

    def test_none_embedding_stays_none(self) -> None:
        record = MemoryRecord(content="test", embedding=None)
        assert record.embedding is None

    def test_list_of_floats_passes_through(self) -> None:
        record = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
        assert record.embedding == [0.1, 0.2, 0.3]

    def test_bytes_converted_to_list_float(self) -> None:
        packed = np.array([0.1, 0.2, 0.3], dtype=np.float32).tobytes()
        record = MemoryRecord(content="test", embedding=packed)
        assert record.embedding is not None
        assert len(record.embedding) == 3
        assert all(isinstance(value, float) for value in record.embedding)
        np.testing.assert_allclose(record.embedding, [0.1, 0.2, 0.3], atol=1e-6)

    def test_empty_bytes_becomes_none(self) -> None:
        record = MemoryRecord(content="test", embedding=b"")
        assert record.embedding is None

    def test_list_of_ints_converted_to_floats(self) -> None:
        record = MemoryRecord(content="test", embedding=[1, 2, 3])
        assert record.embedding == [1.0, 2.0, 3.0]
        assert all(isinstance(value, float) for value in record.embedding)

    def test_numpy_array_converted_to_list(self) -> None:
        source = np.array([0.5, 0.6], dtype=np.float32)
        record = MemoryRecord(content="test", embedding=source)
        assert record.embedding is not None
        assert isinstance(record.embedding, list)
        assert len(record.embedding) == 2


class TestEmbedTextsAsyncSafety:
    """Tests for embed_texts running safely in async context."""

    def test_embed_texts_sync_context(self) -> None:
        """embed_texts works in a normal sync context."""
        fake_embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
        vectors = embed_texts(fake_embedder, ["hello", "world"])
        assert len(vectors) == 2
        assert vectors[0] == [0.1, 0.2]
        fake_embedder.assert_called_once()

    def test_embed_texts_empty_input(self) -> None:
        fake_embedder = MagicMock()
        assert embed_texts(fake_embedder, []) == []
        fake_embedder.assert_not_called()

    def test_embed_texts_all_empty_strings(self) -> None:
        fake_embedder = MagicMock()
        vectors = embed_texts(fake_embedder, ["", " ", ""])
        assert vectors == [[], [], []]
        fake_embedder.assert_not_called()

    def test_embed_texts_skips_empty_preserves_positions(self) -> None:
        fake_embedder = MagicMock(return_value=[[0.1, 0.2]])
        vectors = embed_texts(fake_embedder, ["", "hello", ""])
        assert vectors == [[], [0.1, 0.2], []]
        fake_embedder.assert_called_once_with(["hello"])

    def test_embed_texts_in_async_context(self) -> None:
        """embed_texts uses thread pool when called from async context."""
        fake_embedder = MagicMock(return_value=[[0.1, 0.2]])

        async def invoke() -> list[list[float]]:
            return embed_texts(fake_embedder, ["hello"])

        assert asyncio.run(invoke()) == [[0.1, 0.2]]
        fake_embedder.assert_called_once()


class TestEmbedText:
    """Tests for embed_text (single text)."""

    def test_empty_string_returns_empty(self) -> None:
        fake_embedder = MagicMock()
        assert embed_text(fake_embedder, "") == []
        fake_embedder.assert_not_called()

    def test_whitespace_only_returns_empty(self) -> None:
        fake_embedder = MagicMock()
        assert embed_text(fake_embedder, " ") == []
        fake_embedder.assert_not_called()

    def test_normal_text_returns_embedding(self) -> None:
        fake_embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
        assert embed_text(fake_embedder, "hello") == [0.1, 0.2, 0.3]

    def test_numpy_array_result_converted(self) -> None:
        fake_embedder = MagicMock(
            return_value=[np.array([0.1, 0.2], dtype=np.float32)]
        )
        vector = embed_text(fake_embedder, "hello")
        assert isinstance(vector, list)
        assert len(vector) == 2
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading