Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions lib/crewai/src/crewai/memory/storage/lancedb_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,37 @@ def __init__(
with store_lock(self._lock_name):
self._table = self._create_table(vector_dim)

@staticmethod
def _escape_sql_str(value: Any) -> str:
    """Return ``value`` as text that is safe inside a single-quoted SQL literal.

    The filter passed to LanceDB's ``where()`` is a raw Apache DataFusion
    SQL expression; parameterized queries are not available. Every
    caller-supplied value (scope path, record id, ...) that gets
    interpolated between single quotes must therefore be escaped here
    first. Without this, an apostrophe in the value -- whether malicious
    or merely part of a legitimate id -- would close the literal early
    and let the remainder be parsed as SQL.

    Args:
        value: Any object; it is converted with ``str()`` before escaping.

    Returns:
        The string form of ``value`` with each ``'`` replaced by ``''``.
        The surrounding quotes are NOT added; the caller supplies them.
    """
    # Standard SQL escaping: a quote inside a literal is written twice.
    quote_doubling = {ord("'"): "''"}
    return str(value).translate(quote_doubling)

@staticmethod
def _escape_like(value: Any) -> str:
    """Return ``value`` escaped for use as a literal prefix in a ``LIKE`` pattern.

    Two things are neutralised:

    * single quotes are doubled, so the pattern cannot terminate the
      SQL string literal that encloses it;
    * the ``LIKE`` wildcards (percent sign and underscore) and the
      backslash itself are each prefixed with a backslash, so a scope
      path containing them matches only itself instead of accidentally
      (or maliciously) widening the match.

    The result is a pattern fragment, not a full clause: the caller
    appends its own trailing wildcard and must pair the pattern with
    ``ESCAPE '\\'`` so the backslash is honoured as the escape character.

    Args:
        value: Any object; it is converted with ``str()`` before escaping.

    Returns:
        The escaped pattern fragment, without surrounding quotes.
    """
    # Each source character is rewritten exactly once, so a backslash
    # introduced while escaping a wildcard is never escaped a second time.
    substitutions = str.maketrans(
        {
            "\\": "\\\\",
            "%": "\\%",
            "_": "\\_",
            "'": "''",
        }
    )
    return str(value).translate(substitutions)

@staticmethod
def _infer_dim_from_table(table: Any) -> int:
"""Read vector dimension from an existing table's schema."""
Expand Down Expand Up @@ -317,7 +348,7 @@ def update(self, record: MemoryRecord) -> None:
"""Update a record by ID. Preserves created_at, updates last_accessed."""
with store_lock(self._lock_name):
self._ensure_table()
safe_id = str(record.id).replace("'", "''")
safe_id = self._escape_sql_str(record.id)
self._do_write("delete", f"id = '{safe_id}'")
row = self._record_to_row(record)
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
Expand All @@ -338,7 +369,7 @@ def touch_records(self, record_ids: list[str]) -> None:
return
with store_lock(self._lock_name):
now = datetime.utcnow().isoformat()
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
safe_ids = [self._escape_sql_str(rid) for rid in record_ids]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
self._do_write(
"update",
Expand All @@ -350,7 +381,7 @@ def get_record(self, record_id: str) -> MemoryRecord | None:
"""Return a single record by ID, or None if not found."""
if self._table is None:
return None
safe_id = str(record_id).replace("'", "''")
safe_id = self._escape_sql_str(record_id)
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
return None
Expand All @@ -370,8 +401,8 @@ def search(
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
like_val = self._escape_like(prefix) + "%"
query = query.where(f"scope LIKE '{like_val}' ESCAPE '\\'")
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
Expand Down Expand Up @@ -405,7 +436,8 @@ def delete(
with store_lock(self._lock_name):
if record_ids and not (categories or metadata_filter):
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
safe_ids = [self._escape_sql_str(rid) for rid in record_ids]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
self._do_write("delete", f"id IN ({ids_expr})")
return before - int(self._table.count_rows())
if categories or metadata_filter:
Expand All @@ -427,17 +459,22 @@ def delete(
if not to_delete:
return 0
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
safe_ids = [self._escape_sql_str(rid) for rid in to_delete]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
self._do_write("delete", f"id IN ({ids_expr})")
return before - int(self._table.count_rows())
conditions = []
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
if not prefix.startswith("/"):
prefix = "/" + prefix
conditions.append(f"scope LIKE '{prefix}%' OR scope = '/'")
like_val = self._escape_like(prefix) + "%"
conditions.append(
f"(scope LIKE '{like_val}' ESCAPE '\\' OR scope = '/')"
)
if older_than is not None:
conditions.append(f"created_at < '{older_than.isoformat()}'")
safe_ts = self._escape_sql_str(older_than.isoformat())
conditions.append(f"created_at < '{safe_ts}'")
if not conditions:
before = int(self._table.count_rows())
self._do_write("delete", "id != ''")
Expand Down Expand Up @@ -469,7 +506,8 @@ def _scan_rows(
return []
q = self._table.search()
if scope_prefix is not None and scope_prefix.strip("/"):
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
like_val = self._escape_like(scope_prefix.rstrip("/")) + "%"
q = q.where(f"scope LIKE '{like_val}' ESCAPE '\\'")
if columns is not None:
q = q.select(columns)
result: list[dict[str, Any]] = q.limit(limit).to_list()
Expand Down Expand Up @@ -595,8 +633,10 @@ def reset(self, scope_prefix: str | None = None) -> None:
return
prefix = scope_prefix.rstrip("/")
if prefix:
safe_prefix = self._escape_sql_str(prefix)
self._do_write(
"delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'"
"delete",
f"scope >= '{safe_prefix}' AND scope < '{safe_prefix}/\uffff'",
)

def optimize(self) -> None:
Expand Down
Loading
Loading