Skip to content

Commit ddfe13e

Browse files
author
Saumya Saksena
committed
Add tello visualiser
1 parent 560644a commit ddfe13e

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed

tello_vision/visualizer.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""Visualization utilities for rendering detection results on frames."""
2+
3+
import random
4+
from typing import Dict, List, Tuple
5+
6+
import cv2
7+
import numpy as np
8+
9+
from .detectors.base_detector import Detection, DetectionResult
10+
11+
12+
class Visualizer:
13+
"""Visualize detection results on frames."""
14+
15+
def __init__(self, config: dict):
16+
"""Initialize visualizer.
17+
18+
Args:
19+
config: Visualization configuration
20+
"""
21+
self.config = config
22+
self.class_colors: Dict[str, Tuple[int, int, int]] = {}
23+
24+
# Load predefined colors if available
25+
if "class_colors" in config:
26+
self.class_colors = {
27+
name: tuple(color) for name, color in config["class_colors"].items()
28+
}
29+
30+
def get_color(self, class_name: str) -> Tuple[int, int, int]:
31+
"""Get color for a class.
32+
33+
Args:
34+
class_name: Name of the class
35+
36+
Returns:
37+
RGB color tuple
38+
"""
39+
if class_name not in self.class_colors:
40+
# Generate random but consistent color
41+
random.seed(hash(class_name))
42+
self.class_colors[class_name] = (
43+
random.randint(0, 255),
44+
random.randint(0, 255),
45+
random.randint(0, 255),
46+
)
47+
return self.class_colors[class_name]
48+
49+
def draw_detection(self, frame: np.ndarray, detection: Detection) -> np.ndarray:
50+
"""Draw a single detection on the frame.
51+
52+
Args:
53+
frame: Input frame
54+
detection: Detection to draw
55+
56+
Returns:
57+
Frame with detection drawn
58+
"""
59+
color = self.get_color(detection.class_name)
60+
x1, y1, x2, y2 = detection.bbox
61+
62+
# Draw mask if available and enabled
63+
if self.config.get("show_masks", True) and detection.mask is not None:
64+
frame = self._draw_mask(frame, detection.mask, color)
65+
66+
# Draw bounding box if enabled
67+
if self.config.get("show_boxes", True):
68+
thickness = self.config.get("box_thickness", 2)
69+
cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
70+
71+
# Draw label if enabled
72+
if self.config.get("show_labels", True):
73+
label = detection.class_name
74+
75+
if self.config.get("show_confidence", True):
76+
label = f"{label} {detection.confidence:.2f}"
77+
78+
self._draw_label(frame, label, (x1, y1), color)
79+
80+
return frame
81+
82+
def draw_detections(self, frame: np.ndarray, result: DetectionResult) -> np.ndarray:
83+
"""Draw all detections on the frame.
84+
85+
Args:
86+
frame: Input frame
87+
result: Detection results
88+
89+
Returns:
90+
Frame with all detections drawn
91+
"""
92+
for detection in result.detections:
93+
frame = self.draw_detection(frame, detection)
94+
95+
return frame
96+
97+
def _draw_mask(
98+
self, frame: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int]
99+
) -> np.ndarray:
100+
"""Draw segmentation mask with transparency."""
101+
alpha = self.config.get("mask_alpha", 0.4)
102+
103+
# Create colored mask
104+
colored_mask = np.zeros_like(frame)
105+
colored_mask[mask > 0] = color
106+
107+
# Blend with original frame
108+
frame = cv2.addWeighted(frame, 1.0, colored_mask, alpha, 0)
109+
110+
return frame
111+
112+
def _draw_label(
113+
self,
114+
frame: np.ndarray,
115+
text: str,
116+
position: Tuple[int, int],
117+
color: Tuple[int, int, int],
118+
) -> None:
119+
"""Draw label with background."""
120+
font = cv2.FONT_HERSHEY_SIMPLEX
121+
font_scale = self.config.get("font_scale", 0.6)
122+
thickness = self.config.get("font_thickness", 2)
123+
124+
# Get text size
125+
(text_width, text_height), baseline = cv2.getTextSize(
126+
text, font, font_scale, thickness
127+
)
128+
129+
x, y = position
130+
131+
# Draw background rectangle
132+
cv2.rectangle(
133+
frame,
134+
(x, y - text_height - baseline - 5),
135+
(x + text_width + 5, y),
136+
color,
137+
-1,
138+
)
139+
140+
# Draw text
141+
cv2.putText(
142+
frame,
143+
text,
144+
(x + 2, y - 5),
145+
font,
146+
font_scale,
147+
(255, 255, 255),
148+
thickness,
149+
cv2.LINE_AA,
150+
)
151+
152+
def draw_stats(
153+
self, frame: np.ndarray, stats: List[str], position: Tuple[int, int] = (10, 30)
154+
) -> np.ndarray:
155+
"""Draw statistics text on frame.
156+
157+
Args:
158+
frame: Input frame
159+
stats: List of stat strings to display
160+
position: Starting position (x, y)
161+
162+
Returns:
163+
Frame with stats drawn
164+
"""
165+
font = cv2.FONT_HERSHEY_SIMPLEX
166+
font_scale = 0.6
167+
thickness = 2
168+
color = (0, 255, 0)
169+
line_height = 30
170+
171+
x, y = position
172+
173+
for i, stat in enumerate(stats):
174+
cv2.putText(
175+
frame,
176+
stat,
177+
(x, y + i * line_height),
178+
font,
179+
font_scale,
180+
color,
181+
thickness,
182+
cv2.LINE_AA,
183+
)
184+
185+
return frame
186+
187+
def draw_fps(
188+
self, frame: np.ndarray, fps: float, position: Tuple[int, int] = None
189+
) -> np.ndarray:
190+
"""Draw FPS counter on frame.
191+
192+
Args:
193+
frame: Input frame
194+
fps: Current FPS
195+
position: Position to draw (default: top-right)
196+
197+
Returns:
198+
Frame with FPS drawn
199+
"""
200+
if position is None:
201+
position = (frame.shape[1] - 150, 30)
202+
203+
text = f"FPS: {fps:.1f}"
204+
205+
cv2.putText(
206+
frame,
207+
text,
208+
position,
209+
cv2.FONT_HERSHEY_SIMPLEX,
210+
0.7,
211+
(0, 255, 0),
212+
2,
213+
cv2.LINE_AA,
214+
)
215+
216+
return frame
217+
218+
def draw_crosshair(
219+
self,
220+
frame: np.ndarray,
221+
size: int = 20,
222+
color: Tuple[int, int, int] = (0, 255, 0),
223+
) -> np.ndarray:
224+
"""Draw crosshair at center of frame.
225+
226+
Args:
227+
frame: Input frame
228+
size: Size of crosshair
229+
color: Color of crosshair
230+
231+
Returns:
232+
Frame with crosshair
233+
"""
234+
h, w = frame.shape[:2]
235+
cx, cy = w // 2, h // 2
236+
237+
# Draw horizontal line
238+
cv2.line(frame, (cx - size, cy), (cx + size, cy), color, 2)
239+
240+
# Draw vertical line
241+
cv2.line(frame, (cx, cy - size), (cx, cy + size), color, 2)
242+
243+
# Draw center circle
244+
cv2.circle(frame, (cx, cy), 5, color, -1)
245+
246+
return frame

0 commit comments

Comments
 (0)