-
Notifications
You must be signed in to change notification settings - Fork 3
Pipeline semantic search #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
6312bf8
integrate Docling parsing and paragraph chunking into ingestion pipeline
AymanL 0af60af
integrate Docling chunking pipeline and add chunk embedding column
AymanL 7c1b36f
add_embeddings with multilingual-e5-base, batching, and bulk_update.
AymanL f24fe7d
test adaptation
AymanL 2a8f555
update readme
AymanL ec1352c
Add semantic search over DocumentChunk + fix docker-compose port
AymanL 5d26cea
implement tests for semantic search
AymanL 082356f
reuse same consts as the actual pipeline
AymanL 68160c5
Merge branch 'main' into pipeline_semantic_search
AymanL 1169663
Merge branch 'main' into pipeline_semantic_search
AymanL 0caf77c
factorize strip calls
AymanL 2721d29
k now raises ValueError if <= 0
AymanL 9e84893
move tests to dedicated class
AymanL 706df69
remove important since project requires python 3.12
AymanL File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| """Semantic search over ingested document chunks using pgvector.""" | ||
|
|
||
| from eu_fact_force.ingestion.embedding import embed_query | ||
| from eu_fact_force.ingestion.models import DocumentChunk | ||
| from pgvector.django import CosineDistance | ||
|
|
||
|
|
||
def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]:
    """Return the top-k document chunks most similar to *query*, nearest first.

    The query text is embedded with the same model used at ingestion time
    (E5, with its query prefix), and chunks are ranked by cosine distance —
    smaller means more similar. Chunks without a stored embedding are
    skipped entirely.

    Returns a list of (chunk, distance) pairs; each chunk carries its
    source_file relation preloaded for display (e.g. source_file.doi).

    Raises:
        ValueError: if k is not a positive integer.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    query_vector = embed_query(query)
    # NOTE(review): this ranks by scanning every embedded row; for large
    # tables consider a pgvector ANN index (hnsw / ivfflat) — to confirm.
    candidates = DocumentChunk.objects.filter(embedding__isnull=False)
    ranked = (
        candidates.select_related("source_file")
        .annotate(distance=CosineDistance("embedding", query_vector))
        .order_by("distance")
    )
    return [(chunk, float(chunk.distance)) for chunk in ranked[:k]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| """Tests for semantic search over document chunks.""" | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| from eu_fact_force.ingestion import parsing as parsing_module | ||
| from eu_fact_force.ingestion import search as search_module | ||
| from eu_fact_force.ingestion import services as services_module | ||
| from eu_fact_force.ingestion.chunking import MAX_CHUNK_CHARS | ||
| from eu_fact_force.ingestion.models import DocumentChunk, EMBEDDING_DIMENSIONS, SourceFile | ||
| from eu_fact_force.ingestion.services import run_pipeline | ||
|
|
||
# Repository root (two directories above this test file); used to locate
# real fixture files such as README.md for the end-to-end pipeline test.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Tolerance for float distance comparison.
DISTANCE_TOLERANCE = 1e-5
# Paragraph length so chunking yields 3 separate chunks (each under MAX_CHUNK_CHARS).
PARAGRAPH_LEN = (MAX_CHUNK_CHARS // 2) + 1
# Components for the second-closest vector (distance between near and far).
SECOND_CLOSEST_VEC_ALIGNED = 0.99
SECOND_CLOSEST_VEC_OFF = 0.01
|
|
||
|
|
||
def _constant_vector(value: float, dim: int = EMBEDDING_DIMENSIONS) -> list[float]:
    """Build a vector whose every one of *dim* components equals *value*."""
    return [value for _ in range(dim)]
|
|
||
|
|
||
def _one_hot_vector(index: int, dim: int = EMBEDDING_DIMENSIONS) -> list[float]:
    """Return a basis vector: 1.0 at *index*, 0.0 elsewhere.

    Distinct one-hot vectors are mutually orthogonal, which gives the tests
    well-separated cosine distances.
    """
    vec = [0.0] * dim
    vec[index] = 1.0
    return vec
|
|
||
|
|
||
class TestSearchChunks:
    """Tests for search_chunks: ordering, filtering, k handling, and an
    end-to-end pipeline-then-search run with mocked parsing/embedding."""

    @pytest.mark.django_db
    def test_returns_chunks_ordered_by_distance(self, monkeypatch):
        """search_chunks returns (chunk, distance) tuples ordered by cosine distance."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        # Query [1,0,...,0]: closest to chunk_near (same), then chunk_far ([0,1,0,...,0]).
        # Created in far-then-near order so the ordering cannot come from insertion order.
        chunk_far = DocumentChunk.objects.create(
            source_file=source, content="far", order=1, embedding=_one_hot_vector(1)
        )
        chunk_near = DocumentChunk.objects.create(
            source_file=source, content="near", order=2, embedding=_one_hot_vector(0)
        )

        # Avoid loading the real embedding model: the query embeds to [1,0,...,0].
        monkeypatch.setattr(search_module, "embed_query", lambda _: _one_hot_vector(0))

        results = search_module.search_chunks("dummy", k=5)

        assert len(results) == 2
        (first, d1), (second, d2) = results
        assert first.content == "near"
        assert first.pk == chunk_near.pk
        assert second.content == "far"
        assert second.pk == chunk_far.pk
        # Identical vectors => cosine distance ~0; orthogonal vector is strictly farther.
        assert d1 == pytest.approx(0.0, abs=DISTANCE_TOLERANCE)
        assert d2 > 0

    @pytest.mark.django_db
    def test_excludes_chunks_without_embedding(self, monkeypatch):
        """Only chunks with a stored embedding are returned."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        with_emb = DocumentChunk.objects.create(
            source_file=source, content="with", order=1, embedding=_constant_vector(0.1)
        )
        # This chunk has no embedding and must be filtered out by the search.
        DocumentChunk.objects.create(
            source_file=source, content="without", order=2, embedding=None
        )

        monkeypatch.setattr(
            search_module, "embed_query", lambda _: _constant_vector(0.1)
        )

        results = search_module.search_chunks("q", k=5)

        assert len(results) == 1
        assert results[0][0].content == "with"
        assert results[0][0].pk == with_emb.pk

    @pytest.mark.django_db
    def test_respects_k(self, monkeypatch):
        """At most k results are returned."""
        source = SourceFile.objects.create(
            doi="d", s3_key="k", status=SourceFile.Status.STORED
        )
        # Five identical embeddings: only the count matters here, not the order.
        for i in range(5):
            DocumentChunk.objects.create(
                source_file=source,
                content=f"c{i}",
                order=i,
                embedding=_constant_vector(0.5),
            )

        monkeypatch.setattr(
            search_module, "embed_query", lambda _: _constant_vector(0.5)
        )

        # k smaller than the row count truncates; k larger returns everything.
        assert len(search_module.search_chunks("q", k=2)) == 2
        assert len(search_module.search_chunks("q", k=10)) == 5

    @pytest.mark.django_db
    def test_k_zero_raises_value_error(self):
        """k<=0 raises ValueError to signal incorrect usage."""
        with pytest.raises(ValueError):
            search_module.search_chunks("q", k=0)

    @pytest.mark.django_db
    def test_pipeline_then_search_returns_chunks_ordered_by_similarity(
        self, tmp_storage, monkeypatch
    ):
        """Run pipeline with mocked parse and add_embeddings, then search; order matches."""
        readme_fn = PROJECT_ROOT / "README.md"
        assert readme_fn.exists(), f"Test file must exist: {readme_fn}"

        # Long enough so paragraph chunking produces three separate chunks (max_chunk_chars=1200).
        p1, p2, p3 = "A" * PARAGRAPH_LEN, "B" * PARAGRAPH_LEN, "C" * PARAGRAPH_LEN
        parsed_text = f"{p1}\n\n{p2}\n\n{p3}"
        # Bypass real document parsing (Docling) with a fixed three-paragraph text.
        monkeypatch.setattr(
            parsing_module,
            "_extract_text_from_source_file",
            lambda _: parsed_text,
        )

        # Vectors so cosine distance order is well-defined: p1 closest, then p2, then p3.
        def _add_known_embeddings(chunks):
            near = _one_hot_vector(0)
            # Mostly aligned with `near` but slightly off-axis: strictly between
            # near (distance 0) and far (orthogonal) in cosine distance.
            mid = [SECOND_CLOSEST_VEC_ALIGNED] + [SECOND_CLOSEST_VEC_OFF] + [0.0] * (
                EMBEDDING_DIMENSIONS - 2
            )
            far = _one_hot_vector(1)
            vecs = [near, mid, far]
            # Mirror the real add_embeddings contract: skip unsaved/empty chunks,
            # then persist via bulk_update.
            for i, ch in enumerate(chunks):
                if ch.pk and ch.content.strip() and i < len(vecs):
                    ch.embedding = vecs[i]
            DocumentChunk.objects.bulk_update(chunks, ["embedding"])

        monkeypatch.setattr(services_module, "add_embeddings", _add_known_embeddings)
        monkeypatch.setattr(search_module, "embed_query", lambda _: _one_hot_vector(0))

        source_file, _ = run_pipeline("README.md")
        results = search_module.search_chunks("query", k=5)

        # Ranking must follow the injected vectors, and every hit must belong
        # to the file the pipeline just ingested.
        contents = [r[0].content for r in results]
        assert contents == [p1, p2, p3]
        assert all(r[0].source_file_id == source_file.pk for r in results)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Plutôt un message d'erreur ?