Skip to content

Commit e035fc8

Browse files
authored
Misc improvements (#21)
* Add show_element_type in draw_box * Support lazy loading of label_maps in Detectron2LayoutModels * Add the enforce_cpu flag for Detectron2LayoutModel, inspired by #16
1 parent f599861 commit e035fc8

File tree

4 files changed

+96
-31
lines changed

4 files changed

+96
-31
lines changed

src/layoutparser/models/catalog.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,83 @@
11
from iopath.common.file_io import PathHandler, PathManager, HTTPURLHandler
22
from iopath.common.file_io import PathManager as PathManagerBase
3+
34
# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3
45

56
MODEL_CATALOG = {
6-
'HJDataset': {
7-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/6icw6at8m28a2ho/model_final.pth?dl=1',
8-
'mask_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/893paxpy5suvlx9/model_final.pth?dl=1',
9-
'retinanet_R_50_FPN_3x': 'https://www.dropbox.com/s/yxsloxu3djt456i/model_final.pth?dl=1'
7+
"HJDataset": {
8+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/6icw6at8m28a2ho/model_final.pth?dl=1",
9+
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/893paxpy5suvlx9/model_final.pth?dl=1",
10+
"retinanet_R_50_FPN_3x": "https://www.dropbox.com/s/yxsloxu3djt456i/model_final.pth?dl=1",
1011
},
1112
"PubLayNet": {
1213
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/dgy9c10wykk4lq4/model_final.pth?dl=1",
1314
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/d9fc9tahfzyl6df/model_final.pth?dl=1",
14-
"mask_rcnn_X_101_32x8d_FPN_3x": "https://www.dropbox.com/s/57zjbwv6gh3srry/model_final.pth?dl=1"
15+
"mask_rcnn_X_101_32x8d_FPN_3x": "https://www.dropbox.com/s/57zjbwv6gh3srry/model_final.pth?dl=1",
1516
},
1617
"PrimaLayout": {
1718
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/h7th27jfv19rxiy/model_final.pth?dl=1"
1819
},
1920
"NewspaperNavigator": {
20-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/6ewh6g8rqt2ev3a/model_final.pth?dl=1',
21+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/6ewh6g8rqt2ev3a/model_final.pth?dl=1",
2122
},
2223
"TableBank": {
23-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/8v4uqmz1at9v72a/model_final.pth?dl=1',
24-
'faster_rcnn_R_101_FPN_3x': 'https://www.dropbox.com/s/6vzfk8lk9xvyitg/model_final.pth?dl=1',
24+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/8v4uqmz1at9v72a/model_final.pth?dl=1",
25+
"faster_rcnn_R_101_FPN_3x": "https://www.dropbox.com/s/6vzfk8lk9xvyitg/model_final.pth?dl=1",
2526
},
2627
}
2728

2829
CONFIG_CATALOG = {
29-
'HJDataset': {
30-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/j4yseny2u0hn22r/config.yml?dl=1',
31-
'mask_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/4jmr3xanmxmjcf8/config.yml?dl=1',
32-
'retinanet_R_50_FPN_3x': 'https://www.dropbox.com/s/z8a8ywozuyc5c2x/config.yml?dl=1'
30+
"HJDataset": {
31+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/j4yseny2u0hn22r/config.yml?dl=1",
32+
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/4jmr3xanmxmjcf8/config.yml?dl=1",
33+
"retinanet_R_50_FPN_3x": "https://www.dropbox.com/s/z8a8ywozuyc5c2x/config.yml?dl=1",
3334
},
3435
"PubLayNet": {
3536
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/f3b12qc4hc0yh4m/config.yml?dl=1",
3637
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/u9wbsfwz4y0ziki/config.yml?dl=1",
37-
"mask_rcnn_X_101_32x8d_FPN_3x": "https://www.dropbox.com/s/nau5ut6zgthunil/config.yaml?dl=1"
38+
"mask_rcnn_X_101_32x8d_FPN_3x": "https://www.dropbox.com/s/nau5ut6zgthunil/config.yaml?dl=1",
3839
},
3940
"PrimaLayout": {
4041
"mask_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/yc92x97k50abynt/config.yaml?dl=1"
4142
},
4243
"NewspaperNavigator": {
43-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/wnido8pk4oubyzr/config.yml?dl=1',
44+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/wnido8pk4oubyzr/config.yml?dl=1",
4445
},
4546
"TableBank": {
46-
'faster_rcnn_R_50_FPN_3x': 'https://www.dropbox.com/s/7cqle02do7ah7k4/config.yaml?dl=1',
47-
'faster_rcnn_R_101_FPN_3x': 'https://www.dropbox.com/s/h63n6nv51kfl923/config.yaml?dl=1',
47+
"faster_rcnn_R_50_FPN_3x": "https://www.dropbox.com/s/7cqle02do7ah7k4/config.yaml?dl=1",
48+
"faster_rcnn_R_101_FPN_3x": "https://www.dropbox.com/s/h63n6nv51kfl923/config.yaml?dl=1",
49+
},
50+
}
51+
52+
LABEL_MAP_CATALOG = {
53+
"HJDataset": {
54+
1: "Page Frame",
55+
2: "Row",
56+
3: "Title Region",
57+
4: "Text Region",
58+
5: "Title",
59+
6: "Subtitle",
60+
7: "Other",
61+
},
62+
"PubLayNet": {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
63+
"PrimaLayout": {
64+
1: "TextRegion",
65+
2: "ImageRegion",
66+
3: "TableRegion",
67+
4: "MathsRegion",
68+
5: "SeparatorRegion",
69+
6: "OtherRegion",
70+
},
71+
"NewspaperNavigator": {
72+
0: "Photograph",
73+
1: "Illustration",
74+
2: "Map",
75+
3: "Comics/Cartoon",
76+
4: "Editorial Cartoon",
77+
5: "Headline",
78+
6: "Advertisement",
4879
},
80+
"TableBank": {0: "Table"},
4981
}
5082

5183

@@ -72,13 +104,13 @@ def _get_supported_prefixes(self):
72104
return [self.PREFIX]
73105

74106
def _get_local_path(self, path, **kwargs):
75-
model_name = path[len(self.PREFIX):]
76-
dataset_name, *model_name, data_type = model_name.split('/')
107+
model_name = path[len(self.PREFIX) :]
108+
dataset_name, *model_name, data_type = model_name.split("/")
77109

78-
if data_type == 'weight':
79-
model_url = MODEL_CATALOG[dataset_name]['/'.join(model_name)]
80-
elif data_type == 'config':
81-
model_url = CONFIG_CATALOG[dataset_name]['/'.join(model_name)]
110+
if data_type == "weight":
111+
model_url = MODEL_CATALOG[dataset_name]["/".join(model_name)]
112+
elif data_type == "config":
113+
model_url = CONFIG_CATALOG[dataset_name]["/".join(model_name)]
82114
else:
83115
raise ValueError(f"Unknown data_type {data_type}")
84116
return PathManager.get_local_path(model_url, **kwargs)

src/layoutparser/models/layoutmodel.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import torch
88

9-
from .catalog import PathManager
9+
from .catalog import PathManager, LABEL_MAP_CATALOG
1010
from ..elements import *
1111

1212
__all__ = ["Detectron2LayoutModel"]
@@ -65,8 +65,12 @@ class Detectron2LayoutModel(BaseLayoutModel):
6565
Defaults to `None`.
6666
label_map (:obj:`dict`, optional):
6767
The map from the model prediction (ids) to real
68-
word labels (strings).
68+
word labels (strings). If the config is from one of the supported
69+
datasets, Layout Parser will automatically initialize the label_map.
6970
Defaults to `None`.
71+
enforce_cpu(:obj:`bool`, optional):
72+
When set to `True`, it will enforce using cpu even if it is on a CUDA
73+
available device.
7074
extra_config (:obj:`list`, optional):
7175
Extra configuration passed to the Detectron2 model
7276
configuration. The argument will be used in the `merge_from_list
@@ -90,7 +94,21 @@ class Detectron2LayoutModel(BaseLayoutModel):
9094
{"import_name": "_config", "module_path": "detectron2.config"},
9195
]
9296

93-
def __init__(self, config_path, model_path=None, label_map=None, extra_config=[]):
97+
def __init__(
98+
self,
99+
config_path,
100+
model_path=None,
101+
label_map=None,
102+
extra_config=[],
103+
enforce_cpu=False,
104+
):
105+
106+
if config_path.startswith("lp://") and label_map is None:
107+
dataset_name = config_path.lstrip("lp://").split("/")[0]
108+
label_map = LABEL_MAP_CATALOG[dataset_name]
109+
110+
if enforce_cpu:
111+
extra_config.extend(["MODEL.DEVICE", "cpu"])
94112

95113
cfg = self._config.get_cfg()
96114
config_path = PathManager.get_local_path(config_path)

src/layoutparser/visualization.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def draw_box(
131131
box_width=None,
132132
color_map=None,
133133
show_element_id=False,
134+
show_element_type=False,
134135
id_font_size=None,
135136
id_font_path=None,
136137
id_text_color=None,
@@ -158,6 +159,10 @@ def draw_box(
158159
Whether to display `block.id` on the top-left corner of
159160
the block.
160161
Defaults to False.
162+
show_element_id (bool, optional):
163+
Whether to display `block.type` on the top-left corner of
164+
the block.
165+
Defaults to False.
161166
id_font_size (int, optional):
162167
Set to change the font size used for drawing `block.id`.
163168
Defaults to None, when the size is set to
@@ -183,7 +188,7 @@ def draw_box(
183188
if box_width is None:
184189
box_width = _calculate_default_box_width(canvas)
185190

186-
if show_element_id:
191+
if show_element_id or show_element_type:
187192
font_obj = _create_font_object(id_font_size, id_font_path)
188193

189194
if color_map is None:
@@ -208,11 +213,16 @@ def draw_box(
208213
p = ele.points.ravel().tolist()
209214
draw.line(p + p[:2], width=box_width, fill=outline_color)
210215

211-
if show_element_id:
212-
ele_id = ele.id or idx
216+
if show_element_id or show_element_type:
217+
text = ""
218+
if show_element_id:
219+
ele_id = ele.id or idx
220+
text += str(ele_id)
221+
if show_element_type:
222+
text = str(ele.type) if not text else text + ": " + str(ele.type)
213223

214224
start_x, start_y = ele.coordinates[:2]
215-
text_w, text_h = font_obj.getsize(f"{ele_id}")
225+
text_w, text_h = font_obj.getsize(text)
216226

217227
# Add a small background for the text
218228
draw.rectangle(
@@ -223,7 +233,7 @@ def draw_box(
223233
# Draw the ids
224234
draw.text(
225235
(start_x, start_y),
226-
f"{ele_id}",
236+
text,
227237
fill=id_text_color or DEFAULT_TEXT_COLOR,
228238
font=font_obj,
229239
)

tests/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,9 @@ def test_Detectron2Model(is_large_scale=False):
2727
else:
2828
model = Detectron2LayoutModel("tests/fixtures/model/config.yml")
2929
image = cv2.imread("tests/fixtures/model/test_model_image.jpg")
30-
layout = model.detect(image)
30+
layout = model.detect(image)
31+
32+
# Test in enforce CPU mode
33+
model = Detectron2LayoutModel("tests/fixtures/model/config.yml", enforce_cpu=True)
34+
image = cv2.imread("tests/fixtures/model/test_model_image.jpg")
35+
layout = model.detect(image)

0 commit comments

Comments
 (0)