diff --git a/spacy_layout/layout.py b/spacy_layout/layout.py index 6e0467b..b865faa 100644 --- a/spacy_layout/layout.py +++ b/spacy_layout/layout.py @@ -20,6 +20,7 @@ TABLE_PLACEHOLDER = "TABLE" +TABLE_ITEM_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] # Register msgpack encoders and decoders for custom types srsly.msgpack_encoders.register("spacy-layout.dataclass", func=encode_obj) @@ -146,7 +147,7 @@ def _texts_to_doc( span = Span(doc, start=start, end=end, label=item.label, span_id=i) layout = self._get_span_layout(item, pages) span._.set(self.attrs.span_layout, layout) - if item.label == DocItemLabel.TABLE: + if item.label in TABLE_ITEM_LABELS: span._.set(self.attrs.span_data, item.export_to_dataframe()) spans.append(span) doc.spans[self.attrs.span_group] = SpanGroup( @@ -190,5 +191,5 @@ def get_tables(self, doc: Doc) -> list[Span]: return [ span for span in doc.spans[self.attrs.span_group] - if span.label_ == DocItemLabel.TABLE + if span.label_ in TABLE_ITEM_LABELS ] diff --git a/tests/data/table_document_index.pdf b/tests/data/table_document_index.pdf new file mode 100644 index 0000000..cdfa135 Binary files /dev/null and b/tests/data/table_document_index.pdf differ diff --git a/tests/test_general.py b/tests/test_general.py index 8a06311..e7e5e5a 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -8,6 +8,7 @@ from pandas import DataFrame from pandas.testing import assert_frame_equal from spacy.tokens import DocBin +import pandas as pd from spacy_layout import spaCyLayout from spacy_layout.layout import TABLE_PLACEHOLDER, get_bounding_box @@ -18,6 +19,7 @@ DOCX_SIMPLE = Path(__file__).parent / "data" / "simple.docx" PDF_SIMPLE_BYTES = PDF_SIMPLE.open("rb").read() PDF_TABLE = Path(__file__).parent / "data" / "table.pdf" +PDF_INDEX = Path(__file__).parent / "data" / "table_document_index.pdf" @pytest.fixture @@ -101,6 +103,23 @@ def test_table(nlp): assert markdown in doc._.get(layout.attrs.doc_markdown) +def test_table_index(nlp): + layout = spaCyLayout(nlp) + doc = layout(PDF_INDEX) + assert len(doc._.get(layout.attrs.doc_tables)) == 3 + table = doc._.get(layout.attrs.doc_tables)[0] + assert table.text == TABLE_PLACEHOLDER + assert table.label_ == DocItemLabel.DOCUMENT_INDEX.value + + # Check that each document_index table has a dataframe + document_index_tables = [span for span in doc._.get( + layout.attrs.doc_tables) if span.label_ == DocItemLabel.DOCUMENT_INDEX.value] + for table in document_index_tables: + assert table._.data is not None, "Table data not available" + assert isinstance( + table._.data, pd.DataFrame), "Table data is not a DataFrame" + + def test_table_placeholder(nlp): def display_table(df): return f"Table with columns: {', '.join(df.columns.tolist())}"