Skip to content

Commit 688d10f

Browse files
committed
Add Layout Modeling
* Add Detectron2LayoutModel * Add tests for Detectron2LayoutModel * Update setup requirements
1 parent 0870b6a commit 688d10f

File tree

6 files changed

+395
-1
lines changed

6 files changed

+395
-1
lines changed

docs/api_doc/models.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DL Layout Model
2+
================================
3+
4+
5+
.. autoclass:: layoutparser.models.Detectron2LayoutModel
6+
:members:
7+
:undoc-members:
8+
:show-inheritance:

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Welcome to Layout Parser's documentation!
2020
api_doc/elements
2121
api_doc/ocr
2222
api_doc/visualization
23+
api_doc/models
2324

2425
Indices and tables
2526
==================

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
"numpy",
2323
"opencv-python",
2424
"pandas",
25-
"pillow"
25+
"pillow",
26+
"pyyaml>=5.1",
27+
"torch==1.4",
28+
"torchvision==0.5",
29+
"detectron2 @ git+https://github.com/facebookresearch/detectron2.git@v0.1.3#egg=detectron2"
2630
],
2731
extras_require={
2832
"GCV": ['google-cloud-vision'],

src/layoutparser/models.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from abc import ABC, abstractmethod
2+
import os
3+
import torch
4+
from detectron2.config import get_cfg
5+
from detectron2.engine import DefaultPredictor
6+
from .elements import *
7+
8+
9+
class BaseLayoutModel(ABC):
10+
11+
@abstractmethod
12+
def detect(self): pass
13+
14+
15+
class Detectron2LayoutModel(BaseLayoutModel):
16+
17+
def __init__(self, config_name,
18+
model_path = None,
19+
label_map = None,
20+
extra_config= []):
21+
22+
cfg = get_cfg()
23+
cfg.merge_from_file(config_name)
24+
cfg.merge_from_list(extra_config)
25+
26+
if model_path is not None:
27+
cfg.MODEL.WEIGHTS = model_path
28+
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
29+
self.cfg = cfg
30+
31+
self.label_map = label_map
32+
self._create_model()
33+
34+
def gather_output(self, outputs):
35+
36+
instance_pred = outputs['instances'].to("cpu")
37+
38+
layout = Layout()
39+
scores = instance_pred.scores.tolist()
40+
boxes = instance_pred.pred_boxes.tensor.tolist()
41+
labels = instance_pred.pred_classes.tolist()
42+
43+
for score, box, label in zip(scores, boxes, labels):
44+
x_1, y_1, x_2, y_2 = box
45+
46+
if self.label_map is not None:
47+
label = self.label_map[label]
48+
49+
cur_block = TextBlock(
50+
Rectangle(x_1, y_1, x_2, y_2),
51+
type=label,
52+
score=score)
53+
layout.append(cur_block)
54+
55+
return layout
56+
57+
def _create_model(self):
58+
self.model = DefaultPredictor(self.cfg)
59+
60+
def detect(self, image):
61+
62+
outputs = self.model(image)
63+
layout = self.gather_output(outputs)
64+
return layout

0 commit comments

Comments
 (0)