Skip to content

Commit cc70548

Browse files
author
Saumya Saksena
committed
Add YOLO detector
1 parent 8fdabde commit cc70548

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

tello_vision/detectors/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Detector module for various object detection/segmentation backends."""
2+
3+
from .base_detector import BaseDetector, Detection, DetectionResult
4+
5+
__all__ = ["BaseDetector", "Detection", "DetectionResult"]
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""YOLOv8 detector implementation using Ultralytics.
2+
3+
Fast, real-time capable, and easy to use.
4+
"""
5+
6+
import time
7+
8+
import cv2
9+
import numpy as np
10+
11+
from .base_detector import BaseDetector, Detection, DetectionResult
12+
13+
14+
class YOLODetector(BaseDetector):
15+
"""YOLOv8 instance segmentation detector."""
16+
17+
def __init__(self, config: dict):
18+
super().__init__(config)
19+
self.model = None
20+
21+
def load_model(self) -> None:
22+
"""Load YOLOv8 model."""
23+
try:
24+
from ultralytics import YOLO
25+
except ImportError:
26+
raise ImportError(
27+
"ultralytics not installed. Install with: pip install ultralytics"
28+
)
29+
30+
model_name = self.config.get("model", "yolov8n-seg.pt")
31+
device = self.config.get("device", "cuda")
32+
33+
print(f"Loading YOLOv8 model: {model_name} on {device}")
34+
self.model = YOLO(model_name)
35+
36+
# Move to device
37+
self.model.to(device)
38+
39+
# Get class names
40+
self.class_names = list(self.model.names.values())
41+
42+
self._initialized = True
43+
print(f"YOLOv8 model loaded. Classes: {len(self.class_names)}")
44+
45+
def detect(self, frame: np.ndarray) -> DetectionResult:
46+
"""Run YOLOv8 detection on frame.
47+
48+
Args:
49+
frame: Input image (H, W, C) in BGR format
50+
51+
Returns:
52+
DetectionResult with all detections
53+
"""
54+
if not self._initialized:
55+
raise RuntimeError("Model not loaded. Call load_model() first.")
56+
57+
start_time = time.time()
58+
59+
# Run inference
60+
results = self.model(
61+
frame,
62+
conf=self.config.get("confidence", 0.5),
63+
iou=self.config.get("iou_threshold", 0.45),
64+
verbose=False,
65+
)[0]
66+
67+
inference_time = time.time() - start_time
68+
69+
# Parse results
70+
detections = []
71+
72+
if results.boxes is not None and len(results.boxes) > 0:
73+
boxes = results.boxes.xyxy.cpu().numpy() # (x1, y1, x2, y2)
74+
confidences = results.boxes.conf.cpu().numpy()
75+
class_ids = results.boxes.cls.cpu().numpy().astype(int)
76+
77+
# Get masks if available
78+
masks = None
79+
if hasattr(results, "masks") and results.masks is not None:
80+
masks = results.masks.data.cpu().numpy()
81+
82+
for idx in range(len(boxes)):
83+
class_id = class_ids[idx]
84+
bbox = boxes[idx].astype(int)
85+
86+
# Get mask if available
87+
mask = None
88+
if masks is not None and idx < len(masks):
89+
# Resize mask to original frame size
90+
mask_resized = cv2.resize(
91+
masks[idx],
92+
(frame.shape[1], frame.shape[0]),
93+
interpolation=cv2.INTER_LINEAR,
94+
)
95+
mask = (mask_resized > 0.5).astype(np.uint8)
96+
97+
detection = Detection(
98+
class_id=class_id,
99+
class_name=self.get_class_name(class_id),
100+
confidence=float(confidences[idx]),
101+
bbox=tuple(bbox),
102+
mask=mask,
103+
)
104+
detections.append(detection)
105+
106+
return DetectionResult(
107+
detections=detections,
108+
inference_time=inference_time,
109+
frame_shape=frame.shape,
110+
)
111+
112+
def get_class_name(self, class_id: int) -> str:
113+
"""Get class name from ID."""
114+
if 0 <= class_id < len(self.class_names):
115+
return self.class_names[class_id]
116+
return f"class_{class_id}"
117+
118+
def get_model_info(self) -> dict:
119+
"""Get model information."""
120+
return {
121+
"backend": "yolov8",
122+
"model": self.config.get("model", "unknown"),
123+
"device": self.device,
124+
"num_classes": len(self.class_names),
125+
"classes": self.class_names,
126+
}

0 commit comments

Comments
 (0)