Skip to content

Commit 8fdabde

Browse files
author
Saumya Saksena
committed
Add detectron
1 parent 9d996b0 commit 8fdabde

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Detectron2 detector implementation.
2+
3+
Higher quality but slower than YOLO. Good for precision applications.
4+
"""
5+
6+
import time
7+
8+
import numpy as np
9+
10+
from .base_detector import BaseDetector, Detection, DetectionResult
11+
12+
13+
class Detectron2Detector(BaseDetector):
    """Detector backend built on a Detectron2 ``DefaultPredictor``.

    Configuration keys read from ``self.config``:

    * ``config_file``: model-zoo config path; defaults to
      ``COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml``.
    * ``model_weights``: checkpoint to load; see :meth:`load_model` for how
      missing values and ``detectron2://`` values are handled.
    * ``confidence``: score threshold applied at test time (default ``0.5``).
    * ``device``: device string handed to Detectron2 (default ``"cuda"``).
    """

    def __init__(self, config: dict):
        """Store the configuration; the model itself is loaded lazily.

        Args:
            config: Detector configuration, forwarded to ``BaseDetector``.
        """
        super().__init__(config)
        # Both are populated by load_model().
        self.predictor = None
        self.metadata = None

    def load_model(self) -> None:
        """Build the Detectron2 predictor and resolve the class names.

        Sets ``self.predictor``, ``self.metadata``, ``self.class_names`` and
        marks the detector as initialized.

        Raises:
            ImportError: If detectron2 cannot be imported. The original
                import failure is chained as ``__cause__``.
        """
        # Imported lazily so this module can be imported without detectron2.
        try:
            from detectron2 import model_zoo
            from detectron2.config import get_cfg
            from detectron2.data import MetadataCatalog
            from detectron2.engine import DefaultPredictor
        except ImportError as err:
            # Chain the cause: the real failure may be a broken dependency of
            # detectron2 rather than detectron2 itself being absent.
            raise ImportError(
                "detectron2 not installed. Install from: "
                "https://github.com/facebookresearch/detectron2"
            ) from err

        cfg = get_cfg()

        # Load config
        config_file = self.config.get(
            "config_file", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
        )
        cfg.merge_from_file(model_zoo.get_config_file(config_file))

        # Set model weights: fall back to the model-zoo checkpoint matching
        # config_file when no weights are configured.
        # NOTE(review): a "detectron2://" value is also replaced by the zoo
        # checkpoint for config_file, i.e. the configured URI itself is not
        # used -- confirm this is intended.
        weights = self.config.get("model_weights")
        if not weights or weights.startswith("detectron2://"):
            cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
        else:
            cfg.MODEL.WEIGHTS = weights

        # Set confidence threshold
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.config.get("confidence", 0.5)

        # Set device
        device = self.config.get("device", "cuda")
        cfg.MODEL.DEVICE = device

        print(f"Loading Detectron2 model: {config_file} on {device}")

        # Create predictor
        self.predictor = DefaultPredictor(cfg)

        # Get metadata for class names: any config whose top-level directory
        # starts with "COCO" uses the coco_2017_val metadata; everything else
        # uses the first training dataset named in the merged config.
        dataset_name = config_file.split("/")[0]
        if dataset_name.startswith("COCO"):
            self.metadata = MetadataCatalog.get("coco_2017_val")
        else:
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

        self.class_names = self.metadata.thing_classes

        self._initialized = True
        print(f"Detectron2 model loaded. Classes: {len(self.class_names)}")

    def detect(self, frame: np.ndarray) -> DetectionResult:
        """Run Detectron2 detection on frame.

        Args:
            frame: Input image (H, W, C) in BGR format

        Returns:
            DetectionResult with all detections. ``inference_time`` covers
            only the predictor call, not the transfer of results to the CPU
            or the conversion into ``Detection`` objects.

        Raises:
            RuntimeError: If ``load_model()`` has not been called.
        """
        if not self._initialized:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()

        # Run inference
        outputs = self.predictor(frame)

        inference_time = time.perf_counter() - start_time

        # Parse results
        detections = []
        instances = outputs["instances"].to("cpu")

        if len(instances) > 0:
            # .tolist() yields plain Python ints/floats, so the resulting
            # Detection objects carry no numpy scalar types.
            # Box coordinates are truncated to integers; presumably
            # (x1, y1, x2, y2) per Detectron2's convention -- confirm.
            boxes = instances.pred_boxes.tensor.numpy().astype(int).tolist()
            scores = instances.scores.numpy().tolist()
            class_ids = instances.pred_classes.numpy().tolist()

            # Get masks if available; otherwise pair every box with None.
            if instances.has("pred_masks"):
                masks = instances.pred_masks.numpy()
            else:
                masks = [None] * len(boxes)

            for box, score, class_id, mask in zip(boxes, scores, class_ids, masks):
                detections.append(
                    Detection(
                        class_id=class_id,
                        class_name=self.get_class_name(class_id),
                        confidence=score,
                        bbox=tuple(box),
                        mask=None if mask is None else mask.astype(np.uint8),
                    )
                )

        return DetectionResult(
            detections=detections,
            inference_time=inference_time,
            frame_shape=frame.shape,
        )

    def get_class_name(self, class_id: int) -> str:
        """Get class name from ID.

        Args:
            class_id: Index into ``self.class_names``.

        Returns:
            The matching class name, or ``"class_<id>"`` when the ID is
            outside the range of known classes.
        """
        if 0 <= class_id < len(self.class_names):
            return self.class_names[class_id]
        return f"class_{class_id}"

    def get_model_info(self) -> dict:
        """Get model information.

        Returns:
            A dict with the backend name, the configured ``config_file``
            (``"unknown"`` if unset), ``self.device``, and the number and
            names of the known classes.
        """
        return {
            "backend": "detectron2",
            "config": self.config.get("config_file", "unknown"),
            "device": self.device,
            "num_classes": len(self.class_names),
            "classes": self.class_names,
        }

0 commit comments

Comments
 (0)