|
| 1 | +"""Visualization utilities for rendering detection results on frames.""" |
| 2 | + |
| 3 | +import random |
| 4 | +from typing import Dict, List, Tuple |
| 5 | + |
| 6 | +import cv2 |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from .detectors.base_detector import Detection, DetectionResult |
| 10 | + |
| 11 | + |
| 12 | +class Visualizer: |
| 13 | + """Visualize detection results on frames.""" |
| 14 | + |
| 15 | + def __init__(self, config: dict): |
| 16 | + """Initialize visualizer. |
| 17 | +
|
| 18 | + Args: |
| 19 | + config: Visualization configuration |
| 20 | + """ |
| 21 | + self.config = config |
| 22 | + self.class_colors: Dict[str, Tuple[int, int, int]] = {} |
| 23 | + |
| 24 | + # Load predefined colors if available |
| 25 | + if "class_colors" in config: |
| 26 | + self.class_colors = { |
| 27 | + name: tuple(color) for name, color in config["class_colors"].items() |
| 28 | + } |
| 29 | + |
| 30 | + def get_color(self, class_name: str) -> Tuple[int, int, int]: |
| 31 | + """Get color for a class. |
| 32 | +
|
| 33 | + Args: |
| 34 | + class_name: Name of the class |
| 35 | +
|
| 36 | + Returns: |
| 37 | + RGB color tuple |
| 38 | + """ |
| 39 | + if class_name not in self.class_colors: |
| 40 | + # Generate random but consistent color |
| 41 | + random.seed(hash(class_name)) |
| 42 | + self.class_colors[class_name] = ( |
| 43 | + random.randint(0, 255), |
| 44 | + random.randint(0, 255), |
| 45 | + random.randint(0, 255), |
| 46 | + ) |
| 47 | + return self.class_colors[class_name] |
| 48 | + |
| 49 | + def draw_detection(self, frame: np.ndarray, detection: Detection) -> np.ndarray: |
| 50 | + """Draw a single detection on the frame. |
| 51 | +
|
| 52 | + Args: |
| 53 | + frame: Input frame |
| 54 | + detection: Detection to draw |
| 55 | +
|
| 56 | + Returns: |
| 57 | + Frame with detection drawn |
| 58 | + """ |
| 59 | + color = self.get_color(detection.class_name) |
| 60 | + x1, y1, x2, y2 = detection.bbox |
| 61 | + |
| 62 | + # Draw mask if available and enabled |
| 63 | + if self.config.get("show_masks", True) and detection.mask is not None: |
| 64 | + frame = self._draw_mask(frame, detection.mask, color) |
| 65 | + |
| 66 | + # Draw bounding box if enabled |
| 67 | + if self.config.get("show_boxes", True): |
| 68 | + thickness = self.config.get("box_thickness", 2) |
| 69 | + cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness) |
| 70 | + |
| 71 | + # Draw label if enabled |
| 72 | + if self.config.get("show_labels", True): |
| 73 | + label = detection.class_name |
| 74 | + |
| 75 | + if self.config.get("show_confidence", True): |
| 76 | + label = f"{label} {detection.confidence:.2f}" |
| 77 | + |
| 78 | + self._draw_label(frame, label, (x1, y1), color) |
| 79 | + |
| 80 | + return frame |
| 81 | + |
| 82 | + def draw_detections(self, frame: np.ndarray, result: DetectionResult) -> np.ndarray: |
| 83 | + """Draw all detections on the frame. |
| 84 | +
|
| 85 | + Args: |
| 86 | + frame: Input frame |
| 87 | + result: Detection results |
| 88 | +
|
| 89 | + Returns: |
| 90 | + Frame with all detections drawn |
| 91 | + """ |
| 92 | + for detection in result.detections: |
| 93 | + frame = self.draw_detection(frame, detection) |
| 94 | + |
| 95 | + return frame |
| 96 | + |
| 97 | + def _draw_mask( |
| 98 | + self, frame: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] |
| 99 | + ) -> np.ndarray: |
| 100 | + """Draw segmentation mask with transparency.""" |
| 101 | + alpha = self.config.get("mask_alpha", 0.4) |
| 102 | + |
| 103 | + # Create colored mask |
| 104 | + colored_mask = np.zeros_like(frame) |
| 105 | + colored_mask[mask > 0] = color |
| 106 | + |
| 107 | + # Blend with original frame |
| 108 | + frame = cv2.addWeighted(frame, 1.0, colored_mask, alpha, 0) |
| 109 | + |
| 110 | + return frame |
| 111 | + |
| 112 | + def _draw_label( |
| 113 | + self, |
| 114 | + frame: np.ndarray, |
| 115 | + text: str, |
| 116 | + position: Tuple[int, int], |
| 117 | + color: Tuple[int, int, int], |
| 118 | + ) -> None: |
| 119 | + """Draw label with background.""" |
| 120 | + font = cv2.FONT_HERSHEY_SIMPLEX |
| 121 | + font_scale = self.config.get("font_scale", 0.6) |
| 122 | + thickness = self.config.get("font_thickness", 2) |
| 123 | + |
| 124 | + # Get text size |
| 125 | + (text_width, text_height), baseline = cv2.getTextSize( |
| 126 | + text, font, font_scale, thickness |
| 127 | + ) |
| 128 | + |
| 129 | + x, y = position |
| 130 | + |
| 131 | + # Draw background rectangle |
| 132 | + cv2.rectangle( |
| 133 | + frame, |
| 134 | + (x, y - text_height - baseline - 5), |
| 135 | + (x + text_width + 5, y), |
| 136 | + color, |
| 137 | + -1, |
| 138 | + ) |
| 139 | + |
| 140 | + # Draw text |
| 141 | + cv2.putText( |
| 142 | + frame, |
| 143 | + text, |
| 144 | + (x + 2, y - 5), |
| 145 | + font, |
| 146 | + font_scale, |
| 147 | + (255, 255, 255), |
| 148 | + thickness, |
| 149 | + cv2.LINE_AA, |
| 150 | + ) |
| 151 | + |
| 152 | + def draw_stats( |
| 153 | + self, frame: np.ndarray, stats: List[str], position: Tuple[int, int] = (10, 30) |
| 154 | + ) -> np.ndarray: |
| 155 | + """Draw statistics text on frame. |
| 156 | +
|
| 157 | + Args: |
| 158 | + frame: Input frame |
| 159 | + stats: List of stat strings to display |
| 160 | + position: Starting position (x, y) |
| 161 | +
|
| 162 | + Returns: |
| 163 | + Frame with stats drawn |
| 164 | + """ |
| 165 | + font = cv2.FONT_HERSHEY_SIMPLEX |
| 166 | + font_scale = 0.6 |
| 167 | + thickness = 2 |
| 168 | + color = (0, 255, 0) |
| 169 | + line_height = 30 |
| 170 | + |
| 171 | + x, y = position |
| 172 | + |
| 173 | + for i, stat in enumerate(stats): |
| 174 | + cv2.putText( |
| 175 | + frame, |
| 176 | + stat, |
| 177 | + (x, y + i * line_height), |
| 178 | + font, |
| 179 | + font_scale, |
| 180 | + color, |
| 181 | + thickness, |
| 182 | + cv2.LINE_AA, |
| 183 | + ) |
| 184 | + |
| 185 | + return frame |
| 186 | + |
| 187 | + def draw_fps( |
| 188 | + self, frame: np.ndarray, fps: float, position: Tuple[int, int] = None |
| 189 | + ) -> np.ndarray: |
| 190 | + """Draw FPS counter on frame. |
| 191 | +
|
| 192 | + Args: |
| 193 | + frame: Input frame |
| 194 | + fps: Current FPS |
| 195 | + position: Position to draw (default: top-right) |
| 196 | +
|
| 197 | + Returns: |
| 198 | + Frame with FPS drawn |
| 199 | + """ |
| 200 | + if position is None: |
| 201 | + position = (frame.shape[1] - 150, 30) |
| 202 | + |
| 203 | + text = f"FPS: {fps:.1f}" |
| 204 | + |
| 205 | + cv2.putText( |
| 206 | + frame, |
| 207 | + text, |
| 208 | + position, |
| 209 | + cv2.FONT_HERSHEY_SIMPLEX, |
| 210 | + 0.7, |
| 211 | + (0, 255, 0), |
| 212 | + 2, |
| 213 | + cv2.LINE_AA, |
| 214 | + ) |
| 215 | + |
| 216 | + return frame |
| 217 | + |
| 218 | + def draw_crosshair( |
| 219 | + self, |
| 220 | + frame: np.ndarray, |
| 221 | + size: int = 20, |
| 222 | + color: Tuple[int, int, int] = (0, 255, 0), |
| 223 | + ) -> np.ndarray: |
| 224 | + """Draw crosshair at center of frame. |
| 225 | +
|
| 226 | + Args: |
| 227 | + frame: Input frame |
| 228 | + size: Size of crosshair |
| 229 | + color: Color of crosshair |
| 230 | +
|
| 231 | + Returns: |
| 232 | + Frame with crosshair |
| 233 | + """ |
| 234 | + h, w = frame.shape[:2] |
| 235 | + cx, cy = w // 2, h // 2 |
| 236 | + |
| 237 | + # Draw horizontal line |
| 238 | + cv2.line(frame, (cx - size, cy), (cx + size, cy), color, 2) |
| 239 | + |
| 240 | + # Draw vertical line |
| 241 | + cv2.line(frame, (cx, cy - size), (cx, cy + size), color, 2) |
| 242 | + |
| 243 | + # Draw center circle |
| 244 | + cv2.circle(frame, (cx, cy), 5, color, -1) |
| 245 | + |
| 246 | + return frame |
0 commit comments