|
| 1 | +"""Abstract base class for object detection/segmentation models. |
| 2 | +
|
| 3 | +Allows easy swapping between different backends (YOLOv8, Detectron2, custom). |
| 4 | +""" |
| 5 | + |
| 6 | +from abc import ABC, abstractmethod |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import List, Optional, Tuple |
| 9 | + |
| 10 | +import numpy as np |
| 11 | + |
| 12 | + |
@dataclass
class Detection:
    """Single detection result.

    Attributes:
        class_id: Numeric class identifier reported by the detector.
        class_name: Human-readable label for ``class_id``.
        confidence: Detection score.
        bbox: Box corners as ``(x1, y1, x2, y2)``.
        mask: Optional binary mask for segmentation backends.
    """

    class_id: int
    class_name: str
    confidence: float
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    # NOTE(review): the dataclass-generated __eq__ compares this field with
    # ``==``; for two distinct multi-element arrays that comparison is
    # ambiguous in NumPy. Compare masks explicitly if equality is needed.
    mask: Optional[np.ndarray] = None  # Binary mask if available

    @property
    def center(self) -> Tuple[int, int]:
        """Get center point of bounding box (floor-divided midpoints)."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) // 2, (y1 + y2) // 2)

    @property
    def area(self) -> int:
        """Get area of bounding box.

        Each side is clamped at zero, so a degenerate or inverted box
        (``x2 < x1`` or ``y2 < y1``) reports an area of 0 instead of a
        negative value or a spurious positive one.
        """
        x1, y1, x2, y2 = self.bbox
        width = max(0, x2 - x1)
        height = max(0, y2 - y1)
        return width * height
| 34 | + |
| 35 | + |
@dataclass
class DetectionResult:
    """Complete detection result for a frame.

    Attributes:
        detections: Detections found in the frame.
        inference_time: Time spent on inference, in seconds.
        frame_shape: Shape of the source frame as ``(H, W, C)``.
    """

    detections: List["Detection"]
    inference_time: float  # seconds
    frame_shape: Tuple[int, int, int]  # (H, W, C)

    def filter_by_class(self, class_names: List[str]) -> "DetectionResult":
        """Filter detections by class names.

        Args:
            class_names: Names to keep. Any iterable of strings is accepted;
                a single bare string is treated as one class name rather
                than being matched by substring.

        Returns:
            A new DetectionResult holding only detections whose
            ``class_name`` is one of ``class_names``; timing and frame shape
            are carried over unchanged.
        """
        # Build the lookup set once: exact-match semantics, O(1) membership
        # per detection, and safe for one-shot iterators.
        if isinstance(class_names, str):
            wanted = {class_names}
        else:
            wanted = set(class_names)
        kept = [d for d in self.detections if d.class_name in wanted]
        return DetectionResult(kept, self.inference_time, self.frame_shape)

    def filter_by_confidence(self, min_confidence: float) -> "DetectionResult":
        """Filter detections by minimum confidence.

        Args:
            min_confidence: Inclusive lower bound on ``confidence``.

        Returns:
            A new DetectionResult holding only detections with
            ``confidence >= min_confidence``.
        """
        kept = [d for d in self.detections if d.confidence >= min_confidence]
        return DetectionResult(kept, self.inference_time, self.frame_shape)

    @property
    def count(self) -> int:
        """Number of detections."""
        return len(self.detections)
| 58 | + |
| 59 | + |
class BaseDetector(ABC):
    """Abstract base class for all detectors.

    Concrete backends implement ``load_model``, ``detect`` and
    ``get_class_name``. This base class only ever sets ``_initialized`` to
    False; presumably ``load_model`` implementations flip it to True once
    the model is ready (confirm against the concrete detectors).
    """

    def __init__(self, config: dict):
        """Initialize detector with configuration.

        Args:
            config: Dictionary containing detector configuration
        """
        self.config = config
        self.class_names: List[str] = []
        self._initialized = False

    @abstractmethod
    def load_model(self) -> None:
        """Load the detection model."""
        pass

    @abstractmethod
    def detect(self, frame: np.ndarray) -> "DetectionResult":
        """Run detection on a frame.

        Args:
            frame: Input image as numpy array (H, W, C) in BGR format

        Returns:
            DetectionResult containing all detections
        """
        pass

    @abstractmethod
    def get_class_name(self, class_id: int) -> str:
        """Get class name from class ID."""
        pass

    def warmup(
        self,
        num_iterations: int = 3,
        frame_shape: Tuple[int, int, int] = (480, 640, 3),
    ) -> None:
        """Warmup the model with dummy input. Useful for GPU initialization.

        Args:
            num_iterations: Number of warmup iterations
            frame_shape: Shape ``(H, W, C)`` of the all-zero uint8 frame fed
                to ``detect``; defaults to the previous fixed 480x640x3.

        Raises:
            RuntimeError: If the model has not been marked as initialized.
        """
        if not self._initialized:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        dummy_frame = np.zeros(frame_shape, dtype=np.uint8)
        for _ in range(num_iterations):
            self.detect(dummy_frame)

    def is_initialized(self) -> bool:
        """Check if model is loaded and ready."""
        return self._initialized

    @property
    def device(self) -> str:
        """Get the device the model is running on.

        Read from the ``"device"`` config key; ``"cpu"`` when absent.
        """
        return self.config.get("device", "cpu")

    @staticmethod
    def create_detector(backend: str, config: dict) -> "BaseDetector":
        """Factory method to create detector instance.

        The backend name is matched case-insensitively and ignoring
        surrounding whitespace.

        Args:
            backend: Detector backend name ('yolov8', 'detectron2', etc.)
            config: Configuration dictionary

        Returns:
            Detector instance

        Raises:
            ValueError: If backend is not supported
        """
        # Normalize only real strings so a non-string backend still ends in
        # the ValueError below rather than an AttributeError.
        key = backend.strip().lower() if isinstance(backend, str) else backend

        # Backend modules are imported lazily so that only the selected
        # backend's module is imported.
        if key == "yolov8":
            from .yolo_detector import YOLODetector

            return YOLODetector(config)
        if key == "detectron2":
            from .detectron2_detector import Detectron2Detector

            return Detectron2Detector(config)
        raise ValueError(f"Unsupported detector backend: {backend}")
0 commit comments