From c81bd10ae643efd2c6816cb2506feaf012ddc4b5 Mon Sep 17 00:00:00 2001
From: Ines Montani
Date: Fri, 7 Mar 2025 07:07:26 +0100
Subject: [PATCH] Add as_tuples argument to spaCyLayout.pipe

---
 README.md              |  7 ++--
 spacy_layout/layout.py | 67 ++++++++++++++++++++++++++++++++++++++----
 tests/test_general.py  | 20 +++++++++++++
 3 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 988cbe2..7a20993 100644
--- a/README.md
+++ b/README.md
@@ -186,7 +186,7 @@ doc = layout("./starcraft.pdf")
 
 #### method `spaCyLayout.pipe`
 
-Process multiple documents and create spaCy [`Doc`](https://spacy.io/api/doc) objects. You should use this method if you're processing larger volumes of documents at scale.
+Process multiple documents and create spaCy [`Doc`](https://spacy.io/api/doc) objects. You should use this method if you're processing larger volumes of documents at scale. The `as_tuples` argument works like it does in spaCy's [`Language.pipe`](https://spacy.io/api/language#pipe).
 
 ```python
 layout = spaCyLayout(nlp)
@@ -196,5 +196,6 @@ docs = layout.pipe(paths)
 
 | Argument | Type | Description |
 | --- | --- | --- |
-| `sources` | `Iterable[str \| Path \| bytes]` | Paths of documents to process or bytes. |
-| **YIELDS** | `Doc` | The processed spaCy `Doc` object. |
+| `sources` | `Iterable[str \| Path \| bytes] \| Iterable[tuple[str \| Path \| bytes, Any]]` | Paths of documents to process or bytes, or `(source, context)` tuples if `as_tuples` is set to `True`. |
+| `as_tuples` | `bool` | If set to `True`, inputs should be an iterable of `(source, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
+| **YIELDS** | `Doc \| tuple[Doc, Any]` | The processed spaCy `Doc` objects or `(doc, context)` tuples if `as_tuples` is set to `True`. |
diff --git a/spacy_layout/layout.py b/spacy_layout/layout.py
index 6e0467b..2231a73 100644
--- a/spacy_layout/layout.py
+++ b/spacy_layout/layout.py
@@ -1,6 +1,16 @@
 from io import BytesIO
+from itertools import tee
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Iterable, Iterator
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Iterable,
+    Iterator,
+    Literal,
+    TypeVar,
+    cast,
+    overload,
+)
 
 import srsly
 from docling.datamodel.base_models import DocumentStream
@@ -18,6 +28,8 @@
     from pandas import DataFrame
     from spacy.language import Language
 
+# Type variable for contexts piped with documents
+_AnyContext = TypeVar("_AnyContext")
 
 TABLE_PLACEHOLDER = "TABLE"
 
@@ -75,12 +87,53 @@ def __call__(self, source: str | Path | bytes | DoclingDocument) -> Doc:
         result = self.converter.convert(self._get_source(source)).document
         return self._result_to_doc(result)
 
-    def pipe(self, sources: Iterable[str | Path | bytes]) -> Iterator[Doc]:
-        """Process multiple documents and create spaCy Doc objects."""
-        data = (self._get_source(source) for source in sources)
-        results = self.converter.convert_all(data)
-        for result in results:
-            yield self._result_to_doc(result.document)
+    @overload
+    def pipe(
+        self,
+        sources: Iterable[str | Path | bytes],
+        as_tuples: Literal[False] = ...,
+    ) -> Iterator[Doc]: ...
+
+    @overload
+    def pipe(
+        self,
+        sources: Iterable[tuple[str | Path | bytes, _AnyContext]],
+        as_tuples: Literal[True] = ...,
+    ) -> Iterator[tuple[Doc, _AnyContext]]: ...
+
+    def pipe(
+        self,
+        sources: (
+            Iterable[str | Path | bytes]
+            | Iterable[tuple[str | Path | bytes, _AnyContext]]
+        ),
+        as_tuples: bool = False,
+    ) -> Iterator[Doc] | Iterator[tuple[Doc, _AnyContext]]:
+        """Process multiple documents and create spaCy Doc objects.
+
+        sources: Paths or bytes of the documents to process, or
+            (source, context) tuples if as_tuples is True. Can be any
+            iterable, including a one-shot iterator such as a generator.
+        as_tuples: Whether the inputs are (source, context) tuples.
+        YIELDS: The processed Doc objects, or (doc, context) tuples if
+            as_tuples is True.
+        """
+        if as_tuples:
+            sources = cast(Iterable[tuple[str | Path | bytes, _AnyContext]], sources)
+            # Duplicate the input so a one-shot iterator (e.g. a generator) is
+            # not consumed twice, which would drop sources and misalign contexts.
+            for_data, for_contexts = tee(sources)
+            data = (self._get_source(source) for source, _ in for_data)
+            contexts = (context for _, context in for_contexts)
+            results = self.converter.convert_all(data)
+            for result, context in zip(results, contexts):
+                yield (self._result_to_doc(result.document), context)
+        else:
+            sources = cast(Iterable[str | Path | bytes], sources)
+            data = (self._get_source(source) for source in sources)
+            results = self.converter.convert_all(data)
+            for result in results:
+                yield self._result_to_doc(result.document)
 
     def _get_source(self, source: str | Path | bytes) -> str | Path | DocumentStream:
         if isinstance(source, (str, Path)):
diff --git a/tests/test_general.py b/tests/test_general.py
index 2b80cfa..08708ac 100644
--- a/tests/test_general.py
+++ b/tests/test_general.py
@@ -40,6 +40,7 @@ def test_general(path, nlp, span_labels):
         assert span.label_ in span_labels
         assert isinstance(span._.get(layout.attrs.span_layout), SpanLayout)
 
+
 @pytest.mark.parametrize("path, pg_no", [(PDF_STARCRAFT, 6), (PDF_SIMPLE, 1)])
 def test_pages(path, pg_no, nlp):
     layout = spaCyLayout(nlp)
@@ -73,6 +74,25 @@ def test_simple_pipe(nlp):
         assert len(doc.spans[layout.attrs.span_group]) == 4
 
 
+def test_simple_pipe_as_tuples(nlp):
+    layout = spaCyLayout(nlp)
+    data = [(PDF_SIMPLE, "pdf"), (DOCX_SIMPLE, "docx")]
+    result = list(layout.pipe(data, as_tuples=True))
+    for doc, _ in result:
+        assert len(doc.spans[layout.attrs.span_group]) == 4
+    assert [context for _, context in result] == ["pdf", "docx"]
+
+
+def test_simple_pipe_as_tuples_iterator(nlp):
+    layout = spaCyLayout(nlp)
+    # A one-shot iterator must not lose sources or contexts
+    data = iter([(PDF_SIMPLE, "pdf"), (DOCX_SIMPLE, "docx")])
+    result = list(layout.pipe(data, as_tuples=True))
+    for doc, _ in result:
+        assert len(doc.spans[layout.attrs.span_group]) == 4
+    assert [context for _, context in result] == ["pdf", "docx"]
+
+
 def test_table(nlp):
     layout = spaCyLayout(nlp)
     doc = layout(PDF_TABLE)