
Commit af7b842

Saumya Saksena committed
Add examples
1 parent de0a13a commit af7b842

File tree

3 files changed: +526, -0 lines changed


examples/benchmark.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
"""Benchmark different detector models.

Compares speed and optionally accuracy across different backends and model sizes.
"""

import argparse
import time
from typing import Dict

import numpy as np
import yaml

from tello_vision.detectors.base_detector import BaseDetector


def benchmark_detector(
    detector: BaseDetector, num_frames: int = 100, resolution: tuple = (960, 720)
) -> Dict:
    """Benchmark a detector.

    Args:
        detector: Detector instance
        num_frames: Number of frames to process
        resolution: Frame resolution (width, height)

    Returns:
        Dictionary with benchmark results
    """
    print(f"Benchmarking {detector.__class__.__name__}...")

    # Load model
    detector.load_model()

    # Warmup
    print(" Warming up...")
    detector.warmup(num_iterations=10)

    # Generate dummy frames
    frames = [
        np.random.randint(0, 255, (resolution[1], resolution[0], 3), dtype=np.uint8)
        for _ in range(num_frames)
    ]

    # Benchmark
    print(f" Processing {num_frames} frames...")
    inference_times = []
    total_detections = 0

    start_time = time.time()

    for frame in frames:
        result = detector.detect(frame)
        inference_times.append(result.inference_time)
        total_detections += result.count

    total_time = time.time() - start_time

    # Calculate stats
    avg_inference = np.mean(inference_times)
    std_inference = np.std(inference_times)
    fps = num_frames / total_time

    return {
        "avg_inference_ms": avg_inference * 1000,
        "std_inference_ms": std_inference * 1000,
        "min_inference_ms": min(inference_times) * 1000,
        "max_inference_ms": max(inference_times) * 1000,
        "fps": fps,
        "total_time": total_time,
        "avg_detections": total_detections / num_frames,
    }

def main():
    parser = argparse.ArgumentParser(description="Benchmark detectors")
    parser.add_argument(
        "--num-frames", type=int, default=100, help="Number of frames to process"
    )
    parser.add_argument(
        "--resolution", type=str, default="960x720", help="Frame resolution (WxH)"
    )
    args = parser.parse_args()

    # Parse resolution
    width, height = map(int, args.resolution.split("x"))
    resolution = (width, height)

    # Load base config
    with open("config.yaml", "r") as f:
        base_config = yaml.safe_load(f)

    # Define models to benchmark
    benchmarks = [
        # YOLOv8 models
        (
            "YOLOv8n-seg (Nano)",
            "yolov8",
            {"model": "yolov8n-seg.pt", "device": "cuda", "confidence": 0.5},
        ),
        (
            "YOLOv8s-seg (Small)",
            "yolov8",
            {"model": "yolov8s-seg.pt", "device": "cuda", "confidence": 0.5},
        ),
        (
            "YOLOv8m-seg (Medium)",
            "yolov8",
            {"model": "yolov8m-seg.pt", "device": "cuda", "confidence": 0.5},
        ),
        # Detectron2
        ("Detectron2 R50-FPN", "detectron2", base_config["detector"]["detectron2"]),
    ]

    results = []

    print("=" * 80)
    print(
        f"Benchmarking Detectors - {args.num_frames}"
        f" frames at {resolution[0]}x{resolution[1]}"
    )
    print("=" * 80)
    print()

    for name, backend, config in benchmarks:
        try:
            detector = BaseDetector.create_detector(backend, config)
            result = benchmark_detector(detector, args.num_frames, resolution)
            result["name"] = name
            results.append(result)
            print(" ✓ Complete\n")
        except Exception as e:
            print(f" ✗ Failed: {e}\n")
            continue

    # Guard against an empty table: max()/min() below raise if every backend failed
    if not results:
        print("No detectors benchmarked successfully.")
        return

    # Print results
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)
    print()
    print(
        f"{'Model':<30} {'FPS':>8} {'Avg(ms)':>10}"
        f" {'Std(ms)':>10} {'Min(ms)':>10} {'Max(ms)':>10}"
    )
    print("-" * 80)

    for result in sorted(results, key=lambda x: x["fps"], reverse=True):
        print(
            f"{result['name']:<30} "
            f"{result['fps']:>8.1f} "
            f"{result['avg_inference_ms']:>10.1f} "
            f"{result['std_inference_ms']:>10.1f} "
            f"{result['min_inference_ms']:>10.1f} "
            f"{result['max_inference_ms']:>10.1f}"
        )

    print()
    print("=" * 80)
    print("\nRecommendations:")

    fastest = max(results, key=lambda x: x["fps"])
    print(f" Fastest: {fastest['name']} ({fastest['fps']:.1f} FPS)")

    most_stable = min(results, key=lambda x: x["std_inference_ms"])
    print(
        f" Most Stable: {most_stable['name']}"
        f" (±{most_stable['std_inference_ms']:.1f}ms)"
    )

    print()


if __name__ == "__main__":
    main()
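
For reference, a minimal sketch of driving the benchmark helper directly instead of through the CLI sweep in main(). It assumes this script is importable as examples.benchmark and that a "yolov8" backend is registered with BaseDetector.create_detector, both as shown in the diff above; the model file, device, and confidence values are illustrative placeholders, not prescribed settings.

    # Sketch: benchmark a single detector programmatically.
    # Assumes examples.benchmark is importable and a "yolov8" backend is
    # registered; model/device/confidence values are illustrative.
    from examples.benchmark import benchmark_detector
    from tello_vision.detectors.base_detector import BaseDetector

    config = {"model": "yolov8n-seg.pt", "device": "cuda", "confidence": 0.5}
    detector = BaseDetector.create_detector("yolov8", config)

    stats = benchmark_detector(detector, num_frames=50, resolution=(640, 480))
    print(
        f"{stats['fps']:.1f} FPS, "
        f"{stats['avg_inference_ms']:.1f} ± {stats['std_inference_ms']:.1f} ms"
    )

The full sweep runs from the command line, e.g. python examples/benchmark.py --num-frames 200 --resolution 1280x720; note that main() expects a config.yaml with a detector.detectron2 section in the working directory.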
