From 3d73b8cbcd5c144d1dea57da699fc1952737d850 Mon Sep 17 00:00:00 2001 From: 5000user5000 Date: Thu, 6 Nov 2025 16:10:08 +0800 Subject: [PATCH] benchmark tools --- benchmark/BENCHMARK_GUIDE.md | 315 +++++++++++++++++++++++++++ benchmark/comprehensive_bench.py | 361 +++++++++++++++++++++++++++++++ benchmark/plot_tradeoff.py | 220 +++++++++++++++++++ 3 files changed, 896 insertions(+) create mode 100644 benchmark/BENCHMARK_GUIDE.md create mode 100755 benchmark/comprehensive_bench.py create mode 100755 benchmark/plot_tradeoff.py diff --git a/benchmark/BENCHMARK_GUIDE.md b/benchmark/BENCHMARK_GUIDE.md new file mode 100644 index 0000000..3071612 --- /dev/null +++ b/benchmark/BENCHMARK_GUIDE.md @@ -0,0 +1,315 @@ +# ZenANN 綜合評估指南 + +本指南說明如何使用 `comprehensive_bench.py` 完成專案要求的所有評估指標。 + +--- + +## 📋 評估指標覆蓋 + +### ✅ 所有指標均已支援 + +| 評估項目 | 支援狀態 | 工具 | +|----------|----------|------| +| **資料集** | +| SIFT1M (128D) | ✅ | comprehensive_bench.py | +| GIST1M (960D) | ✅ | comprehensive_bench.py | +| **準確率** | +| Recall@1 | ✅ | comprehensive_bench.py | +| Recall@10 | ✅ | comprehensive_bench.py | +| Recall@100 | ✅ | comprehensive_bench.py | +| **性能** | +| QPS | ✅ | comprehensive_bench.py | +| p50 latency | ✅ | comprehensive_bench.py | +| p95 latency | ✅ | comprehensive_bench.py | +| **索引成本** | +| Index build time | ✅ | comprehensive_bench.py | +| bytes/vector | ✅ | comprehensive_bench.py | +| **視覺化** | +| Recall-QPS curve | ✅ | plot_tradeoff.py | + +--- + +## 🚀 快速開始 + +### 步驟 1: 準備數據集 + +```bash +# 創建數據目錄 +mkdir -p data + +# 下載 SIFT1M +cd data +wget ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz +tar -xzvf sift.tar.gz + +# 下載 GIST1M +wget ftp://ftp.irisa.fr/local/texmex/corpus/gist.tar.gz +tar -xzvf gist.tar.gz + +cd .. +``` + +### 步驟 2: 安裝依賴 + +```bash +pip install psutil matplotlib numpy +``` + +### 步驟 3: 運行 Benchmark + +```bash +# 設定環境變數 +export LD_LIBRARY_PATH=extern/faiss/build/install/lib:$LD_LIBRARY_PATH + +# SIFT1M 測試 +python3 benchmark/comprehensive_bench.py \ + --base data/sift/sift_base.fvecs \ + --query data/sift/sift_query.fvecs \ + --groundtruth data/sift/sift_groundtruth.ivecs \ + --nlist 1024 \ + --nprobe-list "1,2,4,8,16,32,64,128,256" \ + --k-list "1,10,100" \ + --index-file sift_index.bin \ + --output-dir benchmark_results + +# GIST1M 測試(可選,要注意會花相當多時間) +python3 benchmark/comprehensive_bench.py \ + --base data/gist/gist_base.fvecs \ + --query data/gist/gist_query.fvecs \ + --groundtruth data/gist/gist_groundtruth.ivecs \ + --nlist 1024 \ + --nprobe-list "1,4,16,64,256,512" \ + --k-list "1,10,100" \ + --index-file gist_index.bin \ + --output-dir benchmark_results +``` + +### 步驟 4: 生成報告和圖表 + +```bash +# 生成 Recall-QPS 曲線 +python3 benchmark/plot_tradeoff.py benchmark_results/*.json +``` + +輸出文件: +- `recall_qps_tradeoff.png` - Recall vs QPS 曲線(3 個子圖,對應 k=1,10,100) +- `latency_distribution.png` - 延遲分析圖 +- `benchmark_report.txt` - 文字報告 + +--- + +## 📊 輸出指標說明 + +### Console 輸出範例 + +``` +====================================================================== +Testing: nlist=1024, nprobe=16 +====================================================================== +Measuring batch QPS (k=100)... + QPS (batch): 2450.32 + Latency - Mean: 0.408 ms + Latency - p50: 0.385 ms + Latency - p95: 0.612 ms + Latency - p99: 0.758 ms +Computing Recall@k... + Recall@1: 84.52% + Recall@10: 95.28% + Recall@100: 99.15% + +====================================================================== +SUMMARY: Recall-QPS Trade-off +====================================================================== +nprobe QPS p50(ms) p95(ms) R@1 R@10 R@100 +---------------------------------------------------------------------- +1 12450.3 0.080 0.125 32.15 42.58 58.23 +2 8920.5 0.112 0.185 52.34 65.87 78.45 +4 5630.2 0.178 0.295 68.92 82.15 89.67 +8 3580.1 0.279 0.448 79.45 91.23 95.82 +16 2450.3 0.408 0.612 84.52 95.28 98.15 +32 1680.5 0.595 0.891 87.89 97.45 99.32 + +✅ Target achieved: Recall@10 = 95.28% >= 95% + Best config: nprobe=16, QPS=2450.3 +``` + +### JSON 輸出 + +```json +{ + "metadata": { + "dataset": "sift", + "n_base": 1000000, + "n_queries": 10000, + "dimension": 128, + "nlist": 1024, + "nprobe_list": [1, 2, 4, 8, 16, 32], + "k_values": [1, 10, 100], + "build_time_sec": 45.234, + "bytes_per_vector": 8.5, + "timestamp": "20251106_150000" + }, + "results": [ + { + "nlist": 1024, + "nprobe": 16, + "qps_batch": 2450.32, + "latency_mean_ms": 0.408, + "latency_p50_ms": 0.385, + "latency_p95_ms": 0.612, + "latency_p99_ms": 0.758, + "recall@1": 0.8452, + "recall@10": 0.9528, + "recall@100": 0.9915, + "memory_mb": 128.5, + "build_time_sec": 45.234, + "bytes_per_vector": 8.5 + } + ] +} +``` + +--- + +## 整體範例 + +### 以 openMP 為例 + +```bash +export LD_LIBRARY_PATH=extern/faiss/build/install/lib:$LD_LIBRARY_PATH + +# 1. 測試 OpenMP +## "Testing OpenMP version..." +git checkout feature/openMP +make clean && make + +python3 benchmark/comprehensive_bench.py \ + --base data/sift/sift_base.fvecs \ + --query data/sift/sift_query.fvecs \ + --groundtruth data/sift/sift_groundtruth.ivecs \ + --nlist 1024 \ + --nprobe-list "1,4,8,16,32,64" \ + --k-list "1,10,100" \ + --index-file sift_openmp.bin \ + --output-dir results_openmp + +# 2. 生成對比圖表 (但注意要指定正確的 json,或是把之前的 json 清理) +python3 benchmark/plot_tradeoff.py \ + results_baseline/sift*.json \ + results_openmp/sift*.json + +## "Done! Check recall_qps_tradeoff.png and benchmark_report.txt" +``` + +--- + +## 📈 預期結果 + +### Recall@10 ≥ 0.95 達成條件 + +根據文獻,對於 SIFT1M: + +| nlist | nprobe | 預期 Recall@10 | 預期 QPS (baseline) | +|-------|--------|----------------|---------------------| +| 1024 | 16 | ~95% | ~2000 | +| 1024 | 32 | ~97% | ~1200 | +| 2048 | 32 | ~96% | ~1500 | + +### OpenMP 加速比預期 + +| 指標 | Baseline | OpenMP (8核心) | 加速比 | +|------|----------|----------------|--------| +| QPS | 2000 | 8000-12000 | 4-6x | +| p95 latency | 0.5 ms | 0.15 ms | 3-4x | +| Build time | 45 s | 45 s | 1x (未優化) | + +--- + +## 🔍 故障排除 + +### 問題 1: Recall 太低 + +**症狀**:即使 nprobe=256 也達不到 95% + +**解決**: +```bash +# 增加 nlist +python3 comprehensive_bench.py ... --nlist 2048 +``` + +### 問題 2: QPS 沒有提升 + +**症狀**:OpenMP 版本 QPS 與 baseline 相同 + +**檢查**: +```bash +# 確認 OpenMP 編譯標誌 +cat Makefile | grep fopenmp + +# 確認運行時線程數 +export OMP_NUM_THREADS=8 +``` + +### 問題 3: 記憶體不足 + +**症狀**:`MemoryError` 或程序被殺 + +**解決**:使用索引文件避免重複建構 +```bash +# 先建構並保存索引 +python3 comprehensive_bench.py ... --index-file sift.bin + +# 後續測試重用索引(跳過 build) +python3 comprehensive_bench.py ... --index-file sift.bin +``` + +--- + +## 💡 進階使用 + +### 自定義 nprobe 掃描範圍 + +```bash +# 細粒度掃描(找到精確的 Recall@10=95% 點) +--nprobe-list "10,12,14,16,18,20,22,24" + +# 粗粒度掃描(快速探索) +--nprobe-list "1,8,64,512" +``` + +### 測試不同 nlist 配置 + +```bash +# 對比不同 nlist +for nlist in 512 1024 2048; do + python3 comprehensive_bench.py \ + ... \ + --nlist $nlist \ + --output-dir results_nlist${nlist} +done + +# 統一繪圖對比 +python3 plot_tradeoff.py results_nlist*/*.json +``` + +--- + +## 📚 相關文檔 + +- `comprehensive_bench.py --help` - 完整參數說明 +- `plot_tradeoff.py --help` - 繪圖工具說明 +- `ivf-bench.py` - 簡化版測試(向下兼容) + +--- + +## ✅ 檢查清單 + +完成評估前確認: + +- [ ] SIFT1M 數據集已下載 +- [ ] GIST1M 數據集已下載(可選) +- [ ] 已安裝 psutil, matplotlib +- [ ] 生成了 Recall-QPS 曲線圖 +- [ ] 確認 Recall@10 ≥ 95% 在合理的 QPS 下達成 +- [ ] 記錄了 OpenMP 加速比 +- [ ] 保存了所有 JSON 結果文件 diff --git a/benchmark/comprehensive_bench.py b/benchmark/comprehensive_bench.py new file mode 100755 index 0000000..d9ee184 --- /dev/null +++ b/benchmark/comprehensive_bench.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +""" +Comprehensive Benchmark for ZenANN IVFFlatIndex + +Measures all required metrics: +- Recall@k for k ∈ {1, 10, 100} +- QPS (Queries Per Second) +- Latency (p50, p95) +- Index build time +- Memory usage (bytes/vector) +- Recall-QPS trade-off curve across nprobe values + +Supports SIFT1M and GIST1M datasets. +""" + +import sys +import os +import time +import argparse +import numpy as np +import json +import csv +from pathlib import Path +from datetime import datetime + +# Add build directory to path +sys.path.insert(0, os.path.abspath(os.path.join(__file__, '..', '..', 'build'))) + + +def load_fvecs(filename): + """Load .fvecs format file""" + fv = np.fromfile(filename, dtype=np.float32) + if fv.size == 0: + return np.zeros((0, 0)) + dim = fv.view(np.int32)[0] + assert dim > 0 + fv = fv.reshape(-1, 1 + dim) + if not all(fv.view(np.int32)[:, 0] == dim): + raise IOError(f"Non-uniform vector sizes in {filename}") + fv = fv[:, 1:] + return fv.copy() + + +def load_ivecs(filename): + """Load .ivecs format file""" + fv = np.fromfile(filename, dtype=np.int32) + if fv.size == 0: + return np.zeros((0, 0), dtype=np.int32) + dim = fv.view(np.int32)[0] + assert dim > 0 + fv = fv.reshape(-1, 1 + dim) + fv = fv[:, 1:] + return fv + + +def compute_recall_at_k(predicted, groundtruth, k): + """Compute Recall@k""" + n_queries = len(predicted) + recalls = [] + + for i in range(n_queries): + # Get top-k from ground truth + true_set = set(groundtruth[i, :k]) + # Get top-k from predictions (handle variable length) + pred_k = min(k, len(predicted[i])) + pred_set = set(predicted[i][:pred_k]) + # Calculate recall + intersection = len(true_set & pred_set) + recall = intersection / k + recalls.append(recall) + + return np.mean(recalls) + + +def measure_latencies(index, queries, k, nprobe): + """Measure per-query latencies for percentile calculation""" + latencies = [] + + for query in queries: + t0 = time.perf_counter() + result = index.search(query.tolist(), k, nprobe) + t1 = time.perf_counter() + latencies.append((t1 - t0) * 1000) # Convert to milliseconds + + return np.array(latencies) + + +def get_memory_usage_mb(): + """Get current process memory usage in MB""" + import psutil + process = psutil.Process() + return process.memory_info().rss / 1024 / 1024 + + +def benchmark_single_config(index, queries, gt, nlist, nprobe, k_values, n_base_vectors): + """ + Benchmark a single configuration (nlist, nprobe) + Returns metrics for all k values + """ + import zenann + + print(f"\n{'='*70}") + print(f"Testing: nlist={nlist}, nprobe={nprobe}") + print(f"{'='*70}") + + # 1. Measure batch QPS (use max k) + max_k = max(k_values) + print(f"Measuring batch QPS (k={max_k})...") + + t0 = time.perf_counter() + all_results = index.search_batch(queries.tolist(), max_k, nprobe) + t_batch = time.perf_counter() - t0 + qps_batch = len(queries) / t_batch + + # Collect results + results_array = [] + for res in all_results: + results_array.append(res.indices) + + # 2. Measure per-query latencies (sample for p50/p95) + print(f"Measuring latencies (sampling 1000 queries)...") + sample_size = min(1000, len(queries)) + sample_indices = np.random.choice(len(queries), sample_size, replace=False) + sample_queries = queries[sample_indices] + + latencies = measure_latencies(index, sample_queries, max_k, nprobe) + p50_latency = np.percentile(latencies, 50) + p95_latency = np.percentile(latencies, 95) + p99_latency = np.percentile(latencies, 99) + mean_latency = np.mean(latencies) + + print(f" QPS (batch): {qps_batch:.2f}") + print(f" Latency - Mean: {mean_latency:.3f} ms") + print(f" Latency - p50: {p50_latency:.3f} ms") + print(f" Latency - p95: {p95_latency:.3f} ms") + print(f" Latency - p99: {p99_latency:.3f} ms") + + # 3. Compute Recall@k for multiple k values + print(f"Computing Recall@k...") + recalls = {} + + for k in k_values: + # Pad results to have at least k elements + padded_results = [] + for res in results_array: + if len(res) < k: + padded = list(res) + [-1] * (k - len(res)) + else: + padded = res[:k] + padded_results.append(padded) + + recall = compute_recall_at_k(padded_results, gt, k) + recalls[f'recall@{k}'] = recall + print(f" Recall@{k}: {recall*100:.2f}%") + + # 4. Measure memory usage + mem_after = get_memory_usage_mb() + + return { + 'nlist': nlist, + 'nprobe': nprobe, + 'qps_batch': qps_batch, + 'latency_mean_ms': mean_latency, + 'latency_p50_ms': p50_latency, + 'latency_p95_ms': p95_latency, + 'latency_p99_ms': p99_latency, + **recalls, # Unpack all recall values + 'memory_mb': mem_after, + } + + +def main(args): + import zenann + + print("="*70) + print("ZenANN Comprehensive Benchmark") + print("="*70) + + # 1. Load dataset + print(f"\nLoading dataset...") + base = load_fvecs(args.base) + queries = load_fvecs(args.query) + gt = load_ivecs(args.groundtruth) + + print(f" Base vectors: {base.shape}") + print(f" Queries: {queries.shape}") + print(f" Ground truth: {gt.shape}") + + n_base = base.shape[0] + dim = base.shape[1] + + # 2. Build or load index + mem_before = get_memory_usage_mb() + + if args.index_file and os.path.exists(args.index_file): + print(f"\nLoading index from {args.index_file}...") + index = zenann.IVFFlatIndex.read_index(args.index_file) + print("Index loaded.") + build_time = 0 # Not measured when loading + else: + print(f"\nBuilding IVF index (nlist={args.nlist})...") + index = zenann.IVFFlatIndex(dim=dim, nlist=args.nlist, nprobe=1) + + t0 = time.perf_counter() + index.build(base) + build_time = time.perf_counter() - t0 + + print(f" Build time: {build_time:.3f} s") + + if args.index_file: + print(f" Saving index to {args.index_file}...") + index.write_index(args.index_file) + + mem_after = get_memory_usage_mb() + index_memory_mb = mem_after - mem_before + bytes_per_vector = (index_memory_mb * 1024 * 1024) / n_base + + print(f"\nIndex memory usage:") + print(f" Total: {index_memory_mb:.2f} MB") + print(f" Per vector: {bytes_per_vector:.2f} bytes/vector") + print(f" Ratio to raw data: {bytes_per_vector / (dim * 4):.2f}x") + + # 3. Parse nprobe values + nprobe_list = [int(x.strip()) for x in args.nprobe_list.split(',')] + k_values = [int(x.strip()) for x in args.k_list.split(',')] + + print(f"\nTesting configurations:") + print(f" nprobe values: {nprobe_list}") + print(f" k values: {k_values}") + + # 4. Run benchmarks for each nprobe + results = [] + + for nprobe in nprobe_list: + result = benchmark_single_config( + index, queries, gt, args.nlist, nprobe, k_values, n_base + ) + result['build_time_sec'] = build_time + result['bytes_per_vector'] = bytes_per_vector + results.append(result) + + # 5. Save results + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + dataset_name = Path(args.base).stem.split('_')[0] # e.g., "sift" from "sift_base.fvecs" + + # Save as JSON + json_file = output_dir / f'{dataset_name}_nlist{args.nlist}_{timestamp}.json' + with open(json_file, 'w') as f: + json.dump({ + 'metadata': { + 'dataset': dataset_name, + 'n_base': n_base, + 'n_queries': len(queries), + 'dimension': dim, + 'nlist': args.nlist, + 'nprobe_list': nprobe_list, + 'k_values': k_values, + 'build_time_sec': build_time, + 'bytes_per_vector': bytes_per_vector, + 'timestamp': timestamp, + }, + 'results': results + }, f, indent=2) + + print(f"\n{'='*70}") + print(f"Results saved to: {json_file}") + + # Save as CSV + csv_file = output_dir / f'{dataset_name}_nlist{args.nlist}_{timestamp}.csv' + if results: + with open(csv_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=results[0].keys()) + writer.writeheader() + writer.writerows(results) + print(f"Results saved to: {csv_file}") + + # 6. Print summary table + print(f"\n{'='*70}") + print("SUMMARY: Recall-QPS Trade-off") + print(f"{'='*70}") + print(f"{'nprobe':<8} {'QPS':<10} {'p50(ms)':<10} {'p95(ms)':<10} ", end='') + for k in k_values: + print(f"{'R@'+str(k):<10} ", end='') + print() + print("-"*70) + + for r in results: + print(f"{r['nprobe']:<8} {r['qps_batch']:<10.1f} {r['latency_p50_ms']:<10.3f} {r['latency_p95_ms']:<10.3f} ", end='') + for k in k_values: + recall = r[f'recall@{k}'] * 100 + print(f"{recall:<10.2f} ", end='') + print() + + print("="*70) + + # Check if Recall@10 >= 0.95 is achieved + max_recall_10 = max(r['recall@10'] for r in results) + if max_recall_10 >= 0.95: + print(f"\n✅ Target achieved: Recall@10 = {max_recall_10*100:.2f}% >= 95%") + best_config = max(results, key=lambda x: x['recall@10']) + print(f" Best config: nprobe={best_config['nprobe']}, QPS={best_config['qps_batch']:.1f}") + else: + print(f"\n⚠️ Target not met: Best Recall@10 = {max_recall_10*100:.2f}% < 95%") + print(f" Consider increasing nprobe or nlist") + + print() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Comprehensive benchmark for ZenANN IVFFlatIndex', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Example usage: + # SIFT1M benchmark + python3 comprehensive_bench.py \\ + --base data/sift/sift_base.fvecs \\ + --query data/sift/sift_query.fvecs \\ + --groundtruth data/sift/sift_groundtruth.ivecs \\ + --nlist 1024 \\ + --nprobe-list "1,2,4,8,16,32,64,128" \\ + --k-list "1,10,100" + + # GIST1M benchmark + python3 comprehensive_bench.py \\ + --base data/gist/gist_base.fvecs \\ + --query data/gist/gist_query.fvecs \\ + --groundtruth data/gist/gist_groundtruth.ivecs \\ + --nlist 1024 \\ + --nprobe-list "1,4,16,64,256" \\ + --k-list "1,10,100" + """ + ) + + parser.add_argument('--base', required=True, help='Path to base.fvecs') + parser.add_argument('--query', required=True, help='Path to query.fvecs') + parser.add_argument('--groundtruth', required=True, help='Path to groundtruth.ivecs') + parser.add_argument('--nlist', type=int, default=1024, help='Number of IVF clusters (default: 1024)') + parser.add_argument('--nprobe-list', type=str, default='1,4,8,16,32,64', + help='Comma-separated nprobe values to test (default: 1,4,8,16,32,64)') + parser.add_argument('--k-list', type=str, default='1,10,100', + help='Comma-separated k values for Recall@k (default: 1,10,100)') + parser.add_argument('--index-file', default=None, + help='Path to save/load index (optional, speeds up repeated tests)') + parser.add_argument('--output-dir', default='benchmark_results', + help='Output directory for results (default: benchmark_results)') + + args = parser.parse_args() + + # Check dependencies + try: + import psutil + except ImportError: + print("ERROR: psutil module required for memory measurement") + print("Install with: pip install psutil") + sys.exit(1) + + main(args) diff --git a/benchmark/plot_tradeoff.py b/benchmark/plot_tradeoff.py new file mode 100755 index 0000000..4880f39 --- /dev/null +++ b/benchmark/plot_tradeoff.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +""" +Plot Recall-QPS trade-off curves from benchmark results +""" + +import json +import sys +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path + + +def load_benchmark_result(json_file): + """Load benchmark result from JSON file""" + with open(json_file) as f: + data = json.load(f) + return data + + +def plot_recall_qps_tradeoff(results_files, output_file='recall_qps_tradeoff.png'): + """ + Plot Recall vs QPS trade-off curves + + Args: + results_files: List of JSON result files to compare + output_file: Output image file path + """ + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + fig.suptitle('Recall-QPS Trade-off Curves', fontsize=16, fontweight='bold') + + k_values = [1, 10, 100] + colors = plt.cm.tab10(np.linspace(0, 1, len(results_files))) + + for idx, k in enumerate(k_values): + ax = axes[idx] + + for file_idx, result_file in enumerate(results_files): + data = load_benchmark_result(result_file) + metadata = data['metadata'] + results = data['results'] + + # Extract data + nprobes = [r['nprobe'] for r in results] + qps_values = [r['qps_batch'] for r in results] + recall_values = [r[f'recall@{k}'] * 100 for r in results] + + # Create label + dataset = metadata.get('dataset', 'unknown') + nlist = metadata.get('nlist', '?') + label = f"{dataset} (nlist={nlist})" + + # Plot + ax.plot(recall_values, qps_values, 'o-', + color=colors[file_idx], label=label, linewidth=2, markersize=8) + + # Annotate nprobe values + for i, nprobe in enumerate(nprobes): + if i % 2 == 0: # Only annotate every other point to avoid clutter + ax.annotate(f'np={nprobe}', + (recall_values[i], qps_values[i]), + textcoords="offset points", + xytext=(0, 10), ha='center', fontsize=8) + + ax.set_xlabel(f'Recall@{k} (%)', fontsize=12) + ax.set_ylabel('QPS (queries/sec)', fontsize=12) + ax.set_title(f'Recall@{k} vs QPS', fontsize=14, fontweight='bold') + ax.grid(True, alpha=0.3) + ax.legend(fontsize=10) + + # Add target line for Recall@10 + if k == 10: + ax.axvline(x=95, color='red', linestyle='--', linewidth=2, label='Target (95%)') + ax.legend(fontsize=10) + + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Plot saved to: {output_file}") + plt.show() + + +def plot_latency_distribution(results_files, output_file='latency_distribution.png'): + """ + Plot latency percentiles across nprobe values + """ + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + fig.suptitle('Latency Analysis', fontsize=16, fontweight='bold') + + colors = plt.cm.tab10(np.linspace(0, 1, len(results_files))) + + for file_idx, result_file in enumerate(results_files): + data = load_benchmark_result(result_file) + metadata = data['metadata'] + results = data['results'] + + nprobes = [r['nprobe'] for r in results] + p50 = [r['latency_p50_ms'] for r in results] + p95 = [r['latency_p95_ms'] for r in results] + p99 = [r['latency_p99_ms'] for r in results] + mean_latency = [r['latency_mean_ms'] for r in results] + + dataset = metadata.get('dataset', 'unknown') + nlist = metadata.get('nlist', '?') + label = f"{dataset} (nlist={nlist})" + + # Plot 1: Latency vs nprobe + ax = axes[0] + ax.plot(nprobes, mean_latency, 'o-', color=colors[file_idx], + label=f'{label} (mean)', linewidth=2) + ax.plot(nprobes, p95, 's--', color=colors[file_idx], + label=f'{label} (p95)', linewidth=1.5, alpha=0.7) + + # Plot 2: p50 vs p95 + ax = axes[1] + ax.plot(p50, p95, 'o-', color=colors[file_idx], + label=label, linewidth=2, markersize=8) + + # Annotate nprobe values + for i, nprobe in enumerate(nprobes): + if i % 2 == 0: + ax.annotate(f'np={nprobe}', + (p50[i], p95[i]), + textcoords="offset points", + xytext=(5, 5), ha='left', fontsize=8) + + # Configure plot 1 + axes[0].set_xlabel('nprobe', fontsize=12) + axes[0].set_ylabel('Latency (ms)', fontsize=12) + axes[0].set_title('Latency vs nprobe', fontsize=14, fontweight='bold') + axes[0].set_xscale('log') + axes[0].set_yscale('log') + axes[0].grid(True, alpha=0.3) + axes[0].legend(fontsize=9) + + # Configure plot 2 + axes[1].set_xlabel('p50 Latency (ms)', fontsize=12) + axes[1].set_ylabel('p95 Latency (ms)', fontsize=12) + axes[1].set_title('p50 vs p95 Latency', fontsize=14, fontweight='bold') + axes[1].grid(True, alpha=0.3) + axes[1].legend(fontsize=10) + + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Plot saved to: {output_file}") + plt.show() + + +def generate_report(results_files, output_file='benchmark_report.txt'): + """Generate a text report comparing all results""" + with open(output_file, 'w') as f: + f.write("="*80 + "\n") + f.write("BENCHMARK COMPARISON REPORT\n") + f.write("="*80 + "\n\n") + + for result_file in results_files: + data = load_benchmark_result(result_file) + metadata = data['metadata'] + results = data['results'] + + dataset = metadata.get('dataset', 'unknown') + nlist = metadata.get('nlist', '?') + build_time = metadata.get('build_time_sec', '?') + bytes_per_vec = metadata.get('bytes_per_vector', '?') + + f.write(f"\nDataset: {dataset}\n") + f.write(f" nlist: {nlist}\n") + f.write(f" Build time: {build_time:.2f} s\n") + f.write(f" Memory: {bytes_per_vec:.2f} bytes/vector\n") + f.write(f"\n {'nprobe':<8} {'QPS':<10} {'p50(ms)':<10} {'p95(ms)':<10} {'R@1':<8} {'R@10':<8} {'R@100':<8}\n") + f.write(" " + "-"*70 + "\n") + + for r in results: + f.write(f" {r['nprobe']:<8} {r['qps_batch']:<10.1f} " + f"{r['latency_p50_ms']:<10.3f} {r['latency_p95_ms']:<10.3f} " + f"{r.get('recall@1', 0)*100:<8.2f} " + f"{r.get('recall@10', 0)*100:<8.2f} " + f"{r.get('recall@100', 0)*100:<8.2f}\n") + + # Find best config for Recall@10 >= 95% + candidates = [r for r in results if r.get('recall@10', 0) >= 0.95] + if candidates: + best = max(candidates, key=lambda x: x['qps_batch']) + f.write(f"\n ✅ Best config (R@10≥95%): nprobe={best['nprobe']}, " + f"QPS={best['qps_batch']:.1f}, R@10={best['recall@10']*100:.2f}%\n") + else: + f.write(f"\n ⚠️ No config achieves R@10≥95%\n") + + f.write("\n" + "-"*80 + "\n") + + print(f"Report saved to: {output_file}") + + +def main(): + if len(sys.argv) < 2: + print("Usage: python3 plot_tradeoff.py [result2.json ...]") + print("\nExample:") + print(" python3 plot_tradeoff.py benchmark_results/sift_nlist1024_*.json") + sys.exit(1) + + result_files = sys.argv[1:] + + # Verify files exist + for f in result_files: + if not Path(f).exists(): + print(f"Error: File not found: {f}") + sys.exit(1) + + print(f"Plotting {len(result_files)} result file(s)...") + + # Generate plots + plot_recall_qps_tradeoff(result_files, 'recall_qps_tradeoff.png') + plot_latency_distribution(result_files, 'latency_distribution.png') + + # Generate report + generate_report(result_files, 'benchmark_report.txt') + + print("\nDone!") + + +if __name__ == '__main__': + main()