Skip to content

Commit 8dfc007

Browse files
authored
Improve Layout (#24)
* Enforce keyword arguments for page_data * Add sort func for layout, suggested by #23
1 parent e035fc8 commit 8dfc007

File tree

2 files changed

+129
-46
lines changed

2 files changed

+129
-46
lines changed

src/layoutparser/elements.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,13 @@ def is_in(self, other, soft_margin={}, center=False):
274274

275275
@abstractmethod
276276
def intersect(self, other: "BaseCoordElement", strict: bool = True):
277-
"""Intersect the current shape with the other object, with operations defined in
277+
"""Intersect the current shape with the other object, with operations defined in
278278
:doc:`../notes/shape_operations`.
279279
"""
280280

281281
@abstractmethod
282282
def union(self, other: "BaseCoordElement", strict: bool = True):
283-
"""Union the current shape with the other object, with operations defined in
283+
"""Union the current shape with the other object, with operations defined in
284284
:doc:`../notes/shape_operations`.
285285
"""
286286

@@ -1727,18 +1727,19 @@ class Layout(MutableSequence):
17271727
A list of layout element blocks
17281728
page_data (Dict, optional):
17291729
A dictionary storing the page (canvas) related information
1730-
like `height`, `width`, etc.
1730+
like `height`, `width`, etc. It should be passed in as a
1731+
keyword argument to avoid any confusion.
17311732
Defaults to None.
17321733
"""
17331734

1734-
def __init__(self, blocks: Optional[List] = None, page_data: Dict = None):
1735+
def __init__(self, blocks: Optional[List] = None, *, page_data: Dict = None):
17351736
self._blocks = blocks if blocks is not None else []
17361737
self.page_data = page_data or {}
17371738

17381739
def __getitem__(self, key):
17391740
blocks = self._blocks[key]
17401741
if isinstance(key, slice):
1741-
return self.__class__(self._blocks[key], self.page_data)
1742+
return self.__class__(self._blocks[key], page_data=self.page_data)
17421743
else:
17431744
return blocks
17441745

@@ -1771,17 +1772,20 @@ def __eq__(self, other):
17711772
def __add__(self, other):
17721773
if isinstance(other, Layout):
17731774
if self.page_data == other.page_data:
1774-
return self.__class__(self._blocks + other._blocks, self.page_data)
1775+
return self.__class__(
1776+
self._blocks + other._blocks, page_data=self.page_data
1777+
)
17751778
elif self.page_data == {} or other.page_data == {}:
17761779
return self.__class__(
1777-
self._blocks + other._blocks, self.page_data or other.page_data
1780+
self._blocks + other._blocks,
1781+
page_data=self.page_data or other.page_data,
17781782
)
17791783
else:
17801784
raise ValueError(
17811785
f"Incompatible page_data for two innputs: {self.page_data} vs {other.page_data}."
17821786
)
17831787
elif isinstance(other, list):
1784-
return self.__class__(self._blocks + other, self.page_data)
1788+
return self.__class__(self._blocks + other, page_data=self.page_data)
17851789
else:
17861790
raise ValueError(
17871791
f"Invalid input type for other {other.__class__.__name__}."
@@ -1791,19 +1795,50 @@ def insert(self, key, value):
17911795
self._blocks.insert(key, value)
17921796

17931797
def copy(self):
1794-
return self.__class__(copy(self._blocks), self.page_data)
1798+
return self.__class__(copy(self._blocks), page_data=self.page_data)
17951799

17961800
def relative_to(self, other):
1797-
return self.__class__([ele.relative_to(other) for ele in self], self.page_data)
1801+
return self.__class__(
1802+
[ele.relative_to(other) for ele in self], page_data=self.page_data
1803+
)
17981804

17991805
def condition_on(self, other):
1800-
return self.__class__([ele.condition_on(other) for ele in self], self.page_data)
1806+
return self.__class__(
1807+
[ele.condition_on(other) for ele in self], page_data=self.page_data
1808+
)
18011809

18021810
def is_in(self, other, soft_margin={}, center=False):
18031811
return self.__class__(
1804-
[ele.is_in(other, soft_margin, center) for ele in self], self.page_data
1812+
[ele.is_in(other, soft_margin, center) for ele in self],
1813+
page_data=self.page_data,
18051814
)
18061815

1816+
def sort(self, key=None, reverse=False, inplace=False) -> Optional["Layout"]:
1817+
"""Sort the list of blocks based on the given key.
1818+
1819+
Args:
1820+
key (Callable, optional): key specifies a function of one argument that
1821+
is used to extract a comparison key from each list element.
1822+
Defaults to None.
1823+
reverse (bool, optional): reverse is a boolean value. If set to True,
1824+
then the list elements are sorted as if each comparison were reversed.
1825+
Defaults to False.
1826+
inplace (bool, optional): whether to sort in place (returning None). If set
1827+
to False, it returns a new Layout instance whose blocks are sorted in
1828+
the requested order. Defaults to False.
1829+
1830+
Examples::
1831+
>>> import layoutparser as lp
1832+
>>> i = lp.Interval(4, 5, axis="y")
1833+
>>> l = lp.Layout([i, i.shift(2)])
1834+
>>> l.sort(key=lambda x: x.coordinates[1], reverse=True)
1835+
1836+
"""
1837+
if not inplace:
1838+
return self.__class__(sorted(self._blocks, key=key, reverse=reverse), page_data=self.page_data)
1839+
else:
1840+
self._blocks.sort(key=key, reverse=reverse)
1841+
18071842
def filter_by(self, other, soft_margin={}, center=False):
18081843
"""
18091844
Return a `Layout` object containing the elements that are in the `other` object.
@@ -1818,7 +1853,7 @@ def filter_by(self, other, soft_margin={}, center=False):
18181853
"""
18191854
return self.__class__(
18201855
[ele for ele in self if ele.is_in(other, soft_margin, center)],
1821-
self.page_data,
1856+
page_data=self.page_data,
18221857
)
18231858

18241859
def shift(self, shift_distance):
@@ -1835,7 +1870,7 @@ def shift(self, shift_distance):
18351870
A new layout object with all the elements shifted in the specified values.
18361871
"""
18371872
return self.__class__(
1838-
[ele.shift(shift_distance) for ele in self], self.page_data
1873+
[ele.shift(shift_distance) for ele in self], page_data=self.page_data
18391874
)
18401875

18411876
def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True):
@@ -1856,7 +1891,7 @@ def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True):
18561891
"""
18571892
return self.__class__(
18581893
[ele.pad(left, right, top, bottom, safe_mode) for ele in self],
1859-
self.page_data,
1894+
page_data=self.page_data,
18601895
)
18611896

18621897
def scale(self, scale_factor):
@@ -1871,7 +1906,9 @@ def scale(self, scale_factor):
18711906
:obj:`Layout`:
18721907
A new layout object with all the elements scaled in the specified values.
18731908
"""
1874-
return self.__class__([ele.scale(scale_factor) for ele in self], self.page_data)
1909+
return self.__class__(
1910+
[ele.scale(scale_factor) for ele in self], page_data=self.page_data
1911+
)
18751912

18761913
def crop_image(self, image):
18771914
return [ele.crop_image(image) for ele in self]

tests/test_elements.py

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
import numpy as np
33
import pandas as pd
44

5-
from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout, InvalidShapeError, NotSupportedShapeError
5+
from layoutparser.elements import (
6+
Interval,
7+
Rectangle,
8+
Quadrilateral,
9+
TextBlock,
10+
Layout,
11+
InvalidShapeError,
12+
NotSupportedShapeError,
13+
)
614

715

816
def test_interval():
@@ -46,7 +54,7 @@ def test_quadrilateral():
4654

4755
points = np.array([[2, 2], [6, 2], [6, 7], [2, 6]])
4856
q = Quadrilateral(points)
49-
q.to_interval(axis='x')
57+
q.to_interval(axis="x")
5058
q.to_rectangle()
5159
assert q.shift(1) == Quadrilateral(points + 1)
5260
assert q.shift([1, 2]) == Quadrilateral(points + np.array([1, 2]))
@@ -70,7 +78,7 @@ def test_quadrilateral():
7078

7179
q = Quadrilateral([1, 2, 3, 4, 5, 6, 7, 8])
7280
assert (q.points == np.array([[1, 2], [3, 4], [5, 6], [7, 8]])).all()
73-
81+
7482
q = Quadrilateral([[1, 2], [3, 4], [5, 6], [7, 8]])
7583
assert (q.points == np.array([[1, 2], [3, 4], [5, 6], [7, 8]])).all()
7684

@@ -173,25 +181,48 @@ def test_textblock():
173181

174182
t = TextBlock(q, score=0.2)
175183

176-
# Additional test for shape conversion
177-
assert TextBlock(i, id=1, type=2, text="12").to_interval() == TextBlock(i, id=1, type=2, text="12")
178-
assert TextBlock(i, id=1, type=2, text="12").to_rectangle() == TextBlock(i.to_rectangle(), id=1, type=2, text="12")
179-
assert TextBlock(i, id=1, type=2, text="12").to_quadrilateral() == TextBlock(i.to_quadrilateral(), id=1, type=2, text="12")
180-
181-
assert TextBlock(r, id=1, type=2, parent="a").to_interval(axis="x") == TextBlock(r.to_interval(axis="x"), id=1, type=2, parent="a")
182-
assert TextBlock(r, id=1, type=2, parent="a").to_interval(axis="y") == TextBlock(r.to_interval(axis="y"), id=1, type=2, parent="a")
183-
assert TextBlock(r, id=1, type=2, parent="a").to_rectangle() == TextBlock(r, id=1, type=2, parent="a")
184-
assert TextBlock(r, id=1, type=2, parent="a").to_quadrilateral() == TextBlock(r.to_quadrilateral(), id=1, type=2, parent="a")
185-
186-
assert TextBlock(q, id=1, type=2, parent="a").to_interval(axis="x") == TextBlock(q.to_interval(axis="x"), id=1, type=2, parent="a")
187-
assert TextBlock(q, id=1, type=2, parent="a").to_interval(axis="y") == TextBlock(q.to_interval(axis="y"), id=1, type=2, parent="a")
188-
assert TextBlock(q, id=1, type=2, parent="a").to_rectangle() == TextBlock(q.to_rectangle(), id=1, type=2, parent="a")
189-
assert TextBlock(q, id=1, type=2, parent="a").to_quadrilateral() == TextBlock(q, id=1, type=2, parent="a")
184+
# Additional test for shape conversion
185+
assert TextBlock(i, id=1, type=2, text="12").to_interval() == TextBlock(
186+
i, id=1, type=2, text="12"
187+
)
188+
assert TextBlock(i, id=1, type=2, text="12").to_rectangle() == TextBlock(
189+
i.to_rectangle(), id=1, type=2, text="12"
190+
)
191+
assert TextBlock(i, id=1, type=2, text="12").to_quadrilateral() == TextBlock(
192+
i.to_quadrilateral(), id=1, type=2, text="12"
193+
)
194+
195+
assert TextBlock(r, id=1, type=2, parent="a").to_interval(axis="x") == TextBlock(
196+
r.to_interval(axis="x"), id=1, type=2, parent="a"
197+
)
198+
assert TextBlock(r, id=1, type=2, parent="a").to_interval(axis="y") == TextBlock(
199+
r.to_interval(axis="y"), id=1, type=2, parent="a"
200+
)
201+
assert TextBlock(r, id=1, type=2, parent="a").to_rectangle() == TextBlock(
202+
r, id=1, type=2, parent="a"
203+
)
204+
assert TextBlock(r, id=1, type=2, parent="a").to_quadrilateral() == TextBlock(
205+
r.to_quadrilateral(), id=1, type=2, parent="a"
206+
)
207+
208+
assert TextBlock(q, id=1, type=2, parent="a").to_interval(axis="x") == TextBlock(
209+
q.to_interval(axis="x"), id=1, type=2, parent="a"
210+
)
211+
assert TextBlock(q, id=1, type=2, parent="a").to_interval(axis="y") == TextBlock(
212+
q.to_interval(axis="y"), id=1, type=2, parent="a"
213+
)
214+
assert TextBlock(q, id=1, type=2, parent="a").to_rectangle() == TextBlock(
215+
q.to_rectangle(), id=1, type=2, parent="a"
216+
)
217+
assert TextBlock(q, id=1, type=2, parent="a").to_quadrilateral() == TextBlock(
218+
q, id=1, type=2, parent="a"
219+
)
190220

191221
with pytest.raises(ValueError):
192222
TextBlock(q, id=1, type=2, parent="a").to_interval()
193223
TextBlock(r, id=1, type=2, parent="a").to_interval()
194224

225+
195226
def test_layout():
196227
i = Interval(4, 5, axis="y")
197228
q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]]))
@@ -241,6 +272,20 @@ def test_layout():
241272
l.page_data = {"width": 200, "height": 400}
242273
l + l2
243274

275+
# Test sort
276+
l = Layout([i, i.shift(2)])
277+
l.sort(key=lambda x: x.coordinates[1], reverse=True)
278+
assert l == Layout([i.shift(2), i])
279+
280+
l = Layout([q, r, i], page_data={"width": 200, "height": 400})
281+
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout(
282+
[i, q, r], page_data={"width": 200, "height": 400}
283+
)
284+
285+
l = Layout([q, t])
286+
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout([q, t])
287+
288+
244289
def test_shape_operations():
245290
i_1 = Interval(1, 2, axis="y", canvas_height=30, canvas_width=400)
246291
i_2 = TextBlock(Interval(1, 2, axis="x"))
@@ -249,19 +294,19 @@ def test_shape_operations():
249294
r_1 = Rectangle(0.5, 0.5, 2.5, 1.5)
250295
r_2 = TextBlock(Rectangle(0.5, 0.5, 2, 2.5))
251296

252-
q_1 = Quadrilateral([[1,1], [2.5, 1.2], [2.5, 3], [1.5, 3]])
253-
q_2 = TextBlock(Quadrilateral([[0.5, 0.5], [2,1], [1.5, 2.5], [0.5, 2]]))
297+
q_1 = Quadrilateral([[1, 1], [2.5, 1.2], [2.5, 3], [1.5, 3]])
298+
q_2 = TextBlock(Quadrilateral([[0.5, 0.5], [2, 1], [1.5, 2.5], [0.5, 2]]))
254299

255-
# I and I in different axes
300+
# I and I in different axes
256301
assert i_1.intersect(i_1) == i_1
257-
assert i_1.intersect(i_2) == Rectangle(1,1,2,2)
258-
assert i_1.intersect(i_3) == i_1 # Ensure intersect copy the canvas size
302+
assert i_1.intersect(i_2) == Rectangle(1, 1, 2, 2)
303+
assert i_1.intersect(i_3) == i_1 # Ensure intersect copy the canvas size
259304

260305
assert i_1.union(i_1) == i_1
261306
with pytest.raises(InvalidShapeError):
262-
assert i_1.union(i_2) == Rectangle(1,1,2,2)
307+
assert i_1.union(i_2) == Rectangle(1, 1, 2, 2)
263308

264-
# I and R in different axes
309+
# I and R in different axes
265310
assert i_1.intersect(r_1) == Rectangle(0.5, 1, 2.5, 1.5)
266311
assert i_2.intersect(r_1).block == Rectangle(1, 0.5, 2, 1.5)
267312
assert i_1.union(r_1) == Rectangle(0.5, 0.5, 2.5, 2)
@@ -271,12 +316,12 @@ def test_shape_operations():
271316
with pytest.raises(NotSupportedShapeError):
272317
i_1.intersect(q_1)
273318
i_1.union(q_1)
274-
319+
275320
# I and Q in different axes
276-
assert i_1.intersect(q_1, strict=False) == Rectangle(1,1,2.5,2)
277-
assert i_1.union(q_1, strict=False) == Rectangle(1,1,2.5,3)
321+
assert i_1.intersect(q_1, strict=False) == Rectangle(1, 1, 2.5, 2)
322+
assert i_1.union(q_1, strict=False) == Rectangle(1, 1, 2.5, 3)
278323
assert i_2.intersect(q_1, strict=False).block == Rectangle(1, 1, 2, 3)
279-
assert i_2.union(q_1, strict=False).block == Rectangle(1,1,2.5,3)
324+
assert i_2.union(q_1, strict=False).block == Rectangle(1, 1, 2.5, 3)
280325

281326
# R and I
282327
assert r_1.intersect(i_1) == i_1.intersect(r_1)
@@ -289,7 +334,7 @@ def test_shape_operations():
289334
with pytest.raises(NotSupportedShapeError):
290335
r_1.intersect(q_1)
291336
r_1.union(q_1)
292-
337+
293338
assert r_1.intersect(q_1, strict=False) == Rectangle(1, 1, 2.5, 1.5)
294339
assert r_1.union(q_1, strict=False) == Rectangle(0.5, 0.5, 2.5, 3)
295340
assert r_1.intersect(q_2, strict=False) == r_1.intersect(q_2.to_rectangle())
@@ -300,7 +345,7 @@ def test_shape_operations():
300345
q_1.intersect(i_1)
301346
q_1.intersect(r_1)
302347
q_1.intersect(q_2)
303-
348+
304349
# Q and I
305350
assert q_1.intersect(i_1, strict=False) == i_1.intersect(q_1, strict=False)
306351
assert q_1.union(i_1, strict=False) == i_1.union(q_1, strict=False)
@@ -315,6 +360,7 @@ def test_shape_operations():
315360
assert q_1.union(q_2, strict=False) == q_2.union(q_1, strict=False).block
316361
assert q_1.union(q_2, strict=False) == Rectangle(0.5, 0.5, 2.5, 3)
317362

363+
318364
def test_dict():
319365

320366
i = Interval(1, 2, "y", canvas_height=5)

0 commit comments

Comments
 (0)