Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ services:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-eu_fact_force}
POSTGRES_DB: ${POSTGRES_DB:-eu_fact_force}
ports:
- 5432
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql
- ./docker/postgres/init:/docker-entrypoint-initdb.d:ro
Expand Down
20 changes: 20 additions & 0 deletions eu_fact_force/ingestion/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Sentence-transformers checkpoint used for both document ingestion and query embedding.
MODEL_ID = "intfloat/multilingual-e5-base"
# E5 models expect "passage: " for documents to index and "query: " for search queries (asymmetric retrieval).
PASSAGE_PREFIX = "passage: "
QUERY_PREFIX = "query: "
# Number of chunks encoded per model call during ingestion.
EMBED_BATCH_SIZE = 32
# Lazily-initialized singleton model instance; populated by _get_model() on first use.
_MODEL = None

Expand All @@ -17,6 +18,25 @@ def _get_model():
return _MODEL


def embed_query(query: str) -> list[float]:
    """
    Embed a search query with the same model used at ingestion time,
    applying the E5 "query: " prefix (asymmetric retrieval).

    Returns:
        A 768-d normalized vector for use with pgvector similarity search.

    Raises:
        ValueError: if the query is None, empty, or whitespace-only.
    """
    cleaned = "" if query is None else query.strip()
    if not cleaned:
        raise ValueError("query must be non-empty")
    # Single-item batch: the encoder returns one vector per input text.
    embeddings = _get_model().encode(
        [QUERY_PREFIX + cleaned],
        show_progress_bar=False,
        normalize_embeddings=True,
    )
    first = embeddings[0]
    # numpy arrays expose tolist(); fall back to list() for plain sequences.
    if hasattr(first, "tolist"):
        return first.tolist()
    return list(first)


def _iter_batches(items: list[DocumentChunk], batch_size: int) -> Iterator[list[DocumentChunk]]:
    """Yield consecutive slices of *items* of length *batch_size* (the last may be shorter)."""
    # range() with a non-positive step raises ValueError, matching stdlib semantics.
    for offset in range(0, len(items), batch_size):
        batch = items[offset : offset + batch_size]
        yield batch
Expand Down
28 changes: 28 additions & 0 deletions eu_fact_force/ingestion/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Semantic search over ingested document chunks using pgvector."""

from eu_fact_force.ingestion.embedding import embed_query
from eu_fact_force.ingestion.models import DocumentChunk
from pgvector.django import CosineDistance


def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]:
    """
    Retrieve the k document chunks most semantically similar to *query*.

    The query is embedded with the same model as ingestion (E5, query prefix),
    and chunks are ranked by cosine distance computed in Postgres via pgvector
    (lower distance means more similar). Only chunks with a stored embedding
    are considered, and each result has its source_file relation preloaded
    for display (e.g. source_file.doi).

    Args:
        query: natural-language search text (must be non-empty).
        k: maximum number of results; must be a positive integer.

    Returns:
        A list of (chunk, distance) tuples ordered by ascending distance.

    Raises:
        ValueError: if k <= 0.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    # Embed once, then let the database do the ranking.
    vector = embed_query(query)
    candidates = DocumentChunk.objects.filter(embedding__isnull=False)
    ranked = (
        candidates.select_related("source_file")
        .annotate(distance=CosineDistance("embedding", vector))
        .order_by("distance")
    )
    results = []
    for chunk in ranked[:k]:
        results.append((chunk, float(chunk.distance)))
    return results
153 changes: 153 additions & 0 deletions tests/ingestion/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Tests for semantic search over document chunks."""

from pathlib import Path

import pytest

from eu_fact_force.ingestion import parsing as parsing_module
from eu_fact_force.ingestion import search as search_module
from eu_fact_force.ingestion import services as services_module
from eu_fact_force.ingestion.chunking import MAX_CHUNK_CHARS
from eu_fact_force.ingestion.models import DocumentChunk, EMBEDDING_DIMENSIONS, SourceFile
from eu_fact_force.ingestion.services import run_pipeline

# Repository root: this file lives two directory levels below it (tests/ingestion/).
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Tolerance for float distance comparison.
DISTANCE_TOLERANCE = 1e-5
# Paragraph length so chunking yields 3 separate chunks (each under MAX_CHUNK_CHARS).
PARAGRAPH_LEN = (MAX_CHUNK_CHARS // 2) + 1
# Components for the second-closest vector (distance between near and far).
SECOND_CLOSEST_VEC_ALIGNED = 0.99
SECOND_CLOSEST_VEC_OFF = 0.01


def _constant_vector(value: float, dim: int = EMBEDDING_DIMENSIONS) -> list[float]:
    """Build a dim-length vector whose every component equals *value*."""
    return [value for _ in range(dim)]


def _one_hot_vector(index: int, dim: int = EMBEDDING_DIMENSIONS) -> list[float]:
    """Build a dim-length one-hot vector: 1.0 at *index*, 0.0 everywhere else (for distinct cosine distances)."""
    # Assignment (rather than a comprehension) keeps list-index semantics,
    # e.g. negative indices wrap and out-of-range indices raise IndexError.
    vector = [0.0 for _ in range(dim)]
    vector[index] = 1.0
    return vector


class TestSearchChunks:
    """
    Tests for search_chunks. embed_query is monkeypatched with deterministic
    vectors so no real model is loaded; distances come from pgvector in the
    test database (django_db).
    """

    @pytest.mark.django_db
    def test_returns_chunks_ordered_by_distance(self, monkeypatch):
        """search_chunks returns (chunk, distance) tuples ordered by cosine distance."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        # Query [1,0,...,0]: closest to chunk_near (same), then chunk_far ([0,1,0,...,0]).
        # chunk_far is created first to confirm ordering comes from distance, not insertion order.
        chunk_far = DocumentChunk.objects.create(
            source_file=source, content="far", order=1, embedding=_one_hot_vector(1)
        )
        chunk_near = DocumentChunk.objects.create(
            source_file=source, content="near", order=2, embedding=_one_hot_vector(0)
        )

        monkeypatch.setattr(search_module, "embed_query", lambda _: _one_hot_vector(0))

        results = search_module.search_chunks("dummy", k=5)

        assert len(results) == 2
        (first, d1), (second, d2) = results
        assert first.content == "near"
        assert first.pk == chunk_near.pk
        assert second.content == "far"
        assert second.pk == chunk_far.pk
        # Identical vectors give cosine distance ~0; orthogonal one-hots give a strictly positive distance.
        assert d1 == pytest.approx(0.0, abs=DISTANCE_TOLERANCE)
        assert d2 > 0

    @pytest.mark.django_db
    def test_excludes_chunks_without_embedding(self, monkeypatch):
        """Only chunks with a stored embedding are returned."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        with_emb = DocumentChunk.objects.create(
            source_file=source, content="with", order=1, embedding=_constant_vector(0.1)
        )
        # This chunk has no embedding and must be filtered out by the search.
        DocumentChunk.objects.create(
            source_file=source, content="without", order=2, embedding=None
        )

        monkeypatch.setattr(
            search_module, "embed_query", lambda _: _constant_vector(0.1)
        )

        results = search_module.search_chunks("q", k=5)

        assert len(results) == 1
        assert results[0][0].content == "with"
        assert results[0][0].pk == with_emb.pk

    @pytest.mark.django_db
    def test_respects_k(self, monkeypatch):
        """At most k results are returned."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        # Five identical embeddings: only the count matters here, not the ordering.
        for i in range(5):
            DocumentChunk.objects.create(
                source_file=source,
                content=f"c{i}",
                order=i,
                embedding=_constant_vector(0.5),
            )

        monkeypatch.setattr(
            search_module, "embed_query", lambda _: _constant_vector(0.5)
        )

        assert len(search_module.search_chunks("q", k=2)) == 2
        assert len(search_module.search_chunks("q", k=10)) == 5

    @pytest.mark.django_db
    def test_k_zero_raises_value_error(self):
        """k<=0 raises ValueError to signal incorrect usage."""
        with pytest.raises(ValueError):
            search_module.search_chunks("q", k=0)

    @pytest.mark.django_db
    def test_pipeline_then_search_returns_chunks_ordered_by_similarity(
        self, tmp_storage, monkeypatch
    ):
        """Run pipeline with mocked parse and add_embeddings, then search; order matches."""
        readme_fn = PROJECT_ROOT / "README.md"
        assert readme_fn.exists(), f"Test file must exist: {readme_fn}"

        # Long enough so paragraph chunking produces three separate chunks (max_chunk_chars=1200).
        p1, p2, p3 = "A" * PARAGRAPH_LEN, "B" * PARAGRAPH_LEN, "C" * PARAGRAPH_LEN
        parsed_text = f"{p1}\n\n{p2}\n\n{p3}"
        monkeypatch.setattr(
            parsing_module,
            "_extract_text_from_source_file",
            lambda _: parsed_text,
        )

        # Vectors so cosine distance order is well-defined: p1 closest, then p2, then p3.
        def _add_known_embeddings(chunks):
            near = _one_hot_vector(0)
            # Mostly aligned with the query axis, slightly off: lands between near and far.
            mid = [SECOND_CLOSEST_VEC_ALIGNED] + [SECOND_CLOSEST_VEC_OFF] + [0.0] * (
                EMBEDDING_DIMENSIONS - 2
            )
            far = _one_hot_vector(1)
            vecs = [near, mid, far]
            for i, ch in enumerate(chunks):
                if ch.pk and ch.content.strip() and i < len(vecs):
                    ch.embedding = vecs[i]
            DocumentChunk.objects.bulk_update(chunks, ["embedding"])

        monkeypatch.setattr(services_module, "add_embeddings", _add_known_embeddings)
        monkeypatch.setattr(search_module, "embed_query", lambda _: _one_hot_vector(0))

        source_file, _ = run_pipeline("README.md")
        results = search_module.search_chunks("query", k=5)

        contents = [r[0].content for r in results]
        assert contents == [p1, p2, p3]
        assert all(r[0].source_file_id == source_file.pk for r in results)
Loading