Skip to content

Commit 9d996b0

Browse files
author
Saumya Saksena
committed
Add tello base detector
1 parent ddfe13e commit 9d996b0

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

tello_vision/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Tello Vision - Modern instance segmentation for DJI Tello drones."""
2+
3+
__version__ = "2.0.0"
4+
5+
from .app import TelloVisionApp
6+
from .detectors import BaseDetector, Detection, DetectionResult
7+
from .tello_controller import TelloController
8+
from .visualizer import Visualizer
9+
10+
__all__ = [
11+
"TelloVisionApp",
12+
"TelloController",
13+
"Visualizer",
14+
"BaseDetector",
15+
"Detection",
16+
"DetectionResult",
17+
]
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Abstract base class for object detection/segmentation models.
2+
3+
Allows easy swapping between different backends (YOLOv8, Detectron2, custom).
4+
"""
5+
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass
8+
from typing import List, Optional, Tuple
9+
10+
import numpy as np
11+
12+
13+
@dataclass
14+
class Detection:
15+
"""Single detection result."""
16+
17+
class_id: int
18+
class_name: str
19+
confidence: float
20+
bbox: Tuple[int, int, int, int] # (x1, y1, x2, y2)
21+
mask: Optional[np.ndarray] = None # Binary mask if available
22+
23+
@property
24+
def center(self) -> Tuple[int, int]:
25+
"""Get center point of bounding box."""
26+
x1, y1, x2, y2 = self.bbox
27+
return ((x1 + x2) // 2, (y1 + y2) // 2)
28+
29+
@property
30+
def area(self) -> int:
31+
"""Get area of bounding box."""
32+
x1, y1, x2, y2 = self.bbox
33+
return (x2 - x1) * (y2 - y1)
34+
35+
36+
@dataclass
37+
class DetectionResult:
38+
"""Complete detection result for a frame."""
39+
40+
detections: List[Detection]
41+
inference_time: float # seconds
42+
frame_shape: Tuple[int, int, int] # (H, W, C)
43+
44+
def filter_by_class(self, class_names: List[str]) -> "DetectionResult":
45+
"""Filter detections by class names."""
46+
filtered = [d for d in self.detections if d.class_name in class_names]
47+
return DetectionResult(filtered, self.inference_time, self.frame_shape)
48+
49+
def filter_by_confidence(self, min_confidence: float) -> "DetectionResult":
50+
"""Filter detections by minimum confidence."""
51+
filtered = [d for d in self.detections if d.confidence >= min_confidence]
52+
return DetectionResult(filtered, self.inference_time, self.frame_shape)
53+
54+
@property
55+
def count(self) -> int:
56+
"""Number of detections."""
57+
return len(self.detections)
58+
59+
60+
class BaseDetector(ABC):
61+
"""Abstract base class for all detectors."""
62+
63+
def __init__(self, config: dict):
64+
"""Initialize detector with configuration.
65+
66+
Args:
67+
config: Dictionary containing detector configuration
68+
"""
69+
self.config = config
70+
self.class_names: List[str] = []
71+
self._initialized = False
72+
73+
@abstractmethod
74+
def load_model(self) -> None:
75+
"""Load the detection model."""
76+
pass
77+
78+
@abstractmethod
79+
def detect(self, frame: np.ndarray) -> DetectionResult:
80+
"""Run detection on a frame.
81+
82+
Args:
83+
frame: Input image as numpy array (H, W, C) in BGR format
84+
85+
Returns:
86+
DetectionResult containing all detections
87+
"""
88+
pass
89+
90+
@abstractmethod
91+
def get_class_name(self, class_id: int) -> str:
92+
"""Get class name from class ID."""
93+
pass
94+
95+
def warmup(self, num_iterations: int = 3) -> None:
96+
"""Warmup the model with dummy input. Useful for GPU initialization.
97+
98+
Args:
99+
num_iterations: Number of warmup iterations
100+
"""
101+
if not self._initialized:
102+
raise RuntimeError("Model not loaded. Call load_model() first.")
103+
104+
dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8)
105+
for _ in range(num_iterations):
106+
self.detect(dummy_frame)
107+
108+
def is_initialized(self) -> bool:
109+
"""Check if model is loaded and ready."""
110+
return self._initialized
111+
112+
@property
113+
def device(self) -> str:
114+
"""Get the device the model is running on."""
115+
return self.config.get("device", "cpu")
116+
117+
@staticmethod
118+
def create_detector(backend: str, config: dict) -> "BaseDetector":
119+
"""Factory method to create detector instance.
120+
121+
Args:
122+
backend: Detector backend name ('yolov8', 'detectron2', etc.)
123+
config: Configuration dictionary
124+
125+
Returns:
126+
Detector instance
127+
128+
Raises:
129+
ValueError: If backend is not supported
130+
"""
131+
if backend == "yolov8":
132+
from .yolo_detector import YOLODetector
133+
134+
return YOLODetector(config)
135+
elif backend == "detectron2":
136+
from .detectron2_detector import Detectron2Detector
137+
138+
return Detectron2Detector(config)
139+
else:
140+
raise ValueError(f"Unsupported detector backend: {backend}")

0 commit comments

Comments
 (0)