diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d4e88ee --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.mypy] +warn_unused_ignores = true diff --git a/requirements.txt b/requirements.txt index 2f8e4ae..94e90b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pandas # version range set by Docling srsly # version range set by spaCy # Dev requirements pytest +pandas-stubs \ No newline at end of file diff --git a/spacy_layout/layout.py b/spacy_layout/layout.py index a699230..abf9955 100644 --- a/spacy_layout/layout.py +++ b/spacy_layout/layout.py @@ -11,10 +11,10 @@ overload, ) -import srsly +import srsly # type: ignore[import-untyped] from docling.datamodel.base_models import DocumentStream from docling.document_converter import DocumentConverter -from docling_core.types.doc.document import DoclingDocument +from docling_core.types.doc.document import DoclingDocument, TableItem, TextItem from docling_core.types.doc.labels import DocItemLabel from spacy.tokens import Doc, Span, SpanGroup @@ -130,7 +130,7 @@ def _get_source(self, source: str | Path | bytes) -> str | Path | DocumentStream return DocumentStream(name="source", stream=BytesIO(source)) def _result_to_doc(self, document: DoclingDocument) -> Doc: - inputs = [] + inputs: list[tuple[str, TextItem | TableItem]] = [] pages = { (page.page_no): PageLayout( page_no=page.page_no, @@ -144,17 +144,17 @@ def _result_to_doc(self, document: DoclingDocument) -> Doc: # We want to iterate over the tree to get different elements in order for node, _ in document.iterate_items(): if node.self_ref in text_items: - item = text_items[node.self_ref] - if item.text == "": + text_item = text_items[node.self_ref] + if text_item.text == "": continue - inputs.append((item.text, item)) + inputs.append((text_item.text, text_item)) elif node.self_ref in table_items: - item = table_items[node.self_ref] + table_item = table_items[node.self_ref] if isinstance(self.display_table, str): table_text = self.display_table else: - table_text = self.display_table(item.export_to_dataframe()) - inputs.append((table_text, item)) + table_text = self.display_table(table_item.export_to_dataframe()) + inputs.append((table_text, table_item)) doc = self._texts_to_doc(inputs, pages) doc._.set(self.attrs.doc_layout, DocLayout(pages=[p for p in pages.values()])) doc._.set(self.attrs.doc_markdown, document.export_to_markdown()) @@ -189,6 +189,7 @@ def _texts_to_doc( layout = self._get_span_layout(item, pages) span._.set(self.attrs.span_layout, layout) if item.label in TABLE_ITEM_LABELS: + item = cast(TableItem, item) span._.set(self.attrs.span_data, item.export_to_dataframe()) spans.append(span) doc.spans[self.attrs.span_group] = SpanGroup( @@ -207,12 +208,13 @@ def _get_span_layout( return SpanLayout( x=x, y=y, width=width, height=height, page_no=prov.page_no ) + return None def get_pages(self, doc: Doc) -> list[tuple[PageLayout, list[Span]]]: """Get all pages and their layout spans.""" layout = doc._.get(self.attrs.doc_layout) pages = {page.page_no: page for page in layout.pages} - page_spans = {page.page_no: [] for page in layout.pages} + page_spans: dict[int, list[Span]] = {page.page_no: [] for page in layout.pages} for span in doc.spans[self.attrs.span_group]: span_layout = span._.get(self.attrs.span_layout) page_spans[span_layout.page_no].append(span) @@ -226,6 +228,7 @@ def get_heading(self, span: Span) -> Span | None: for candidate in spans[: span.id][::-1]: if candidate.label_ in self.headings: return candidate + return None def get_tables(self, doc: Doc) -> list[Span]: """Get all tables in the document.""" diff --git a/spacy_layout/py.typed b/spacy_layout/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/spacy_layout/util.py b/spacy_layout/util.py index 6b59ddb..d4a8d0c 100644 --- a/spacy_layout/util.py +++ b/spacy_layout/util.py @@ -10,7 +10,8 @@ from docling_core.types.doc.base import BoundingBox TYPE_ATTR = "__type__" -OBJ_TYPES = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout} +Layouts = SpanLayout | DocLayout | PageLayout +OBJ_TYPES: dict[str, type[Layouts]] = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout} def encode_obj(obj: Any, chain: Callable | None = None) -> Any: