|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Evaluate VLM configurations on DROID trajectories. |
| 4 | +
|
| 5 | +Features: |
| 6 | +- Download trajectories once, reuse across runs |
| 7 | +- Vary number of evenly sampled frames (e.g., 4, 8, 16, 32) |
| 8 | +- Vary passing method: 'stream' (per-frame) vs 'concat' (tiled grid) |
| 9 | +- Vary camera video path keys (e.g., 'ext1_mp4_path', 'wrist_mp4_path') |
| 10 | +- Save per-run outputs into distinct folders |
| 11 | +- Produce a summary CSV of accuracy per configuration |
| 12 | +
|
| 13 | +Usage examples: |
| 14 | + python evaluate_vlm_configs.py \ |
| 15 | + --paths-file results/all_droid_trajectory_paths.txt \ |
| 16 | + --num-trajectories 50 \ |
| 17 | + --eval-root ./eval_runs \ |
| 18 | + --frame-counts 4 8 16 32 \ |
| 19 | + --passing-methods stream concat \ |
| 20 | + --video-path-keys ext1_mp4_path wrist_mp4_path |
| 21 | +
|
| 22 | + # Or specify GCS trajectories directly |
| 23 | + python evaluate_vlm_configs.py \ |
| 24 | + --trajectories gs://.../success/... gs://.../failure/... \ |
| 25 | + --eval-root ./eval_runs |
| 26 | +""" |
| 27 | + |
| 28 | +import argparse |
| 29 | +import csv |
| 30 | +import json |
| 31 | +import os |
| 32 | +import random |
| 33 | +import time |
| 34 | +from pathlib import Path |
| 35 | +from typing import Dict, List, Optional, Tuple |
| 36 | + |
| 37 | +import numpy as np |
| 38 | + |
| 39 | +# Local imports |
| 40 | +from simple_vlm_processing import process_trajectories_parallel |
| 41 | +from droid_pipeline import download_trajectories |
| 42 | + |
| 43 | + |
def load_paths(paths_file: str) -> List[str]:
    """Read one GCS trajectory path per line from *paths_file*.

    Blank lines (and surrounding whitespace) are ignored.  Returns an
    empty list after printing a diagnostic when the file cannot be read,
    so callers can fall back gracefully.
    """
    try:
        with open(paths_file, 'r') as f:
            return [line.strip() for line in f if line.strip()]
    except OSError as e:
        # Narrowed from `except Exception`: file-system trouble is the only
        # expected failure here; programming errors should not be swallowed.
        print(f"❌ Failed to load paths from {paths_file}: {e}")
        return []
| 51 | + |
| 52 | + |
def sample_paths(paths: List[str], k: Optional[int], balance: Optional[float], seed: Optional[int]) -> List[str]:
    """Sample up to *k* paths, optionally balancing success/failure classes.

    Args:
        paths: candidate GCS paths.
        k: number of paths to draw; None / <=0 / >= len(paths) returns a copy
           of everything.
        balance: target fraction of 'success' paths (e.g. 0.5); None means a
           plain uniform sample.
        seed: if given, seeds the module-level ``random`` state for
           reproducibility.

    Returns:
        A list of at most *k* paths (fewer only if *paths* runs out).
    """
    if seed is not None:
        random.seed(seed)
    if k is None or k <= 0 or k >= len(paths):
        return list(paths)
    if balance is None:
        return random.sample(paths, k)
    # Class membership is inferred from the path text itself.
    success_paths = [p for p in paths if 'success' in p.lower()]
    failure_paths = [p for p in paths if 'failure' in p.lower()]
    k_success = int(round(k * balance))
    k_failure = k - k_success
    chosen = random.sample(success_paths, min(k_success, len(success_paths)))
    chosen += random.sample(failure_paths, min(k_failure, len(failure_paths)))
    if len(chosen) < k:
        # One class ran short: top up from the leftovers.  Use a set for
        # membership — the previous list `in` check was O(k) per path.
        chosen_set = set(chosen)
        remaining = [p for p in paths if p not in chosen_set]
        chosen += random.sample(remaining, min(k - len(chosen), len(remaining)))
    return chosen
| 70 | + |
| 71 | + |
def infer_label_from_gcs_path(gcs_path: str) -> Optional[bool]:
    """Infer a success label from the path text.

    Returns True if the path mentions 'success', False if it mentions
    'failure' ('success' wins when both appear), and None otherwise.
    """
    lowered = gcs_path.lower()
    for marker, label in (('success', True), ('failure', False)):
        if marker in lowered:
            return label
    return None
| 79 | + |
| 80 | + |
def build_ground_truth_by_name(gcs_paths: List[str]) -> Dict[str, bool]:
    """Map trajectory directory names to success labels inferred from paths.

    The trajectory name is the final path component (trailing slashes
    stripped).  Paths containing neither 'success' nor 'failure' carry no
    inferable label and are omitted; 'success' takes precedence when both
    markers appear.
    """
    gt: Dict[str, bool] = {}
    for path in gcs_paths:
        lowered = path.lower()
        if 'success' in lowered:
            label = True
        elif 'failure' in lowered:
            label = False
        else:
            continue  # unlabelable path — skip
        gt[path.rstrip('/').split('/')[-1]] = label
    return gt
| 89 | + |
| 90 | + |
def compute_accuracy(results: Dict[str, Dict], gt_by_name: Dict[str, bool]) -> Tuple[int, int, int, float]:
    """Score VLM predictions against path-derived ground truth.

    Only trajectories whose basename appears in *gt_by_name* count.

    Returns:
        (total, predicted, correct, accuracy) — *total* is trajectories with
        ground truth, *predicted* those the VLM processed successfully, and
        *accuracy* is correct/predicted (0.0 when nothing was predicted).
    """
    total = predicted = correct = 0
    for local_path, res in results.items():
        name = os.path.basename(local_path.rstrip('/'))
        label = gt_by_name.get(name)
        if label is None:
            # No ground truth for this trajectory — excluded entirely.
            continue
        total += 1
        if res.get('success', False):
            predicted += 1
            if bool(res.get('vlm_prediction', False)) == label:
                correct += 1
    accuracy = correct / predicted if predicted else 0.0
    return total, predicted, correct, accuracy
| 108 | + |
| 109 | + |
def main() -> int:
    """Entry point: download DROID trajectories once, then evaluate every
    (passing-method, frame-count, camera-key) configuration and write a
    per-run metrics file plus an overall summary CSV.

    Returns a process exit status: 1 when no paths could be resolved or
    every download failed, 0 otherwise.
    """
    parser = argparse.ArgumentParser(description="Evaluate VLM configs on DROID trajectories")
    # --paths-file and --trajectories are alternative input sources; at most
    # one may be given (neither is required since --paths-file has a default).
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--paths-file", default="results/all_droid_trajectory_paths.txt",
                       help="File containing GCS trajectory paths")
    group.add_argument("--trajectories", nargs='+', help="GCS paths to DROID trajectory directories")

    parser.add_argument("--num-trajectories", type=int, help="Number of trajectories to sample")
    parser.add_argument("--balance", type=float, help="Success ratio target in sampling, e.g., 0.5")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--max-workers", type=int, default=4, help="Parallel workers for VLM")
    parser.add_argument("--eval-root", default="./eval_runs", help="Root folder for evaluation outputs")

    # The evaluation grid is the cross product of these three option lists.
    parser.add_argument("--frame-counts", type=int, nargs='+', default=[4, 8, 16, 32],
                        help="Frame counts to evaluate")
    parser.add_argument("--passing-methods", nargs='+', default=["stream", "concat"],
                        choices=["stream", "concat"], help="Passing methods to evaluate")
    parser.add_argument("--video-path-keys", nargs='*', default=None,
                        help="Video path keys from metadata (e.g., ext1_mp4_path wrist_mp4_path). If omitted, auto-detect.")

    parser.add_argument("--language-key", default="metadata/language_instruction",
                        help="Language key to extract from HDF5 fallback")
    parser.add_argument("--question", default="Is this trajectory successful?",
                        help="VLM question")

    args = parser.parse_args()

    # Resolve GCS paths: explicit --trajectories wins over the paths file.
    if args.trajectories:
        gcs_paths = list(args.trajectories)
    else:
        gcs_paths = load_paths(args.paths_file)
    if not gcs_paths:
        print("❌ No GCS trajectory paths provided or loaded")
        return 1

    # Sample (optionally class-balanced; no-op when --num-trajectories unset).
    gcs_paths = sample_paths(gcs_paths, args.num_trajectories, args.balance, args.seed)
    print(f"📊 Using {len(gcs_paths)} trajectories for evaluation")

    # Prepare eval root; makedirs on runs_root also creates eval_root itself.
    eval_root = Path(args.eval_root)
    runs_root = eval_root / "runs"
    downloads_root = eval_root / "droid_trajectories"
    os.makedirs(runs_root, exist_ok=True)

    # Download once — every configuration below reuses the same local copies.
    print("\n📥 Downloading trajectories once for reuse...")
    successful_local_paths, failed = download_trajectories(gcs_paths, str(downloads_root), max_workers=args.max_workers)
    if not successful_local_paths:
        print("❌ Download failed for all trajectories")
        return 1
    print(f"✅ Downloaded {len(successful_local_paths)} trajectories; {len(failed)} failed")

    # Ground truth keyed by trajectory name ('success'/'failure' in the path).
    gt_by_name = build_ground_truth_by_name(gcs_paths)
    # Persist ground truth CSV for later inspection.
    with open(eval_root / "ground_truth.csv", 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["trajectory_name", "label_success"])
        for name, label in sorted(gt_by_name.items()):
            writer.writerow([name, int(label)])

    # Build the evaluation grid: method x frame-count x (camera key or auto).
    summary_rows = []
    configs = []
    for method in args.passing_methods:
        for n in args.frame_counts:
            if args.video_path_keys is None or len(args.video_path_keys) == 0:
                # A None camera key tells the pipeline to auto-detect.
                configs.append((method, n, None))
            else:
                for cam_key in args.video_path_keys:
                    configs.append((method, n, cam_key))

    start_all = time.time()
    for (method, n, cam_key) in configs:
        # Run name doubles as the output folder name, so each config's
        # artifacts land in a distinct directory.
        run_name = f"method={method}_frames={n}" + (f"_cam={cam_key}" if cam_key else "")
        run_out_dir = runs_root / run_name
        os.makedirs(run_out_dir, exist_ok=True)

        print(f"\n🚀 Run: {run_name}")
        results = process_trajectories_parallel(
            trajectory_paths=successful_local_paths,
            image_key="",  # not used for DROID directories when MP4s present
            language_key=args.language_key,
            question=args.question,
            max_workers=args.max_workers,
            output_dir=str(run_out_dir),
            video_path_key=cam_key,
            num_frames=n,
            passing_method=method,
            concat_grid_cols=None
        )

        # Persist raw results
        with open(run_out_dir / "vlm_results.json", 'w') as f:
            json.dump(results, f, indent=2)

        total, predicted, correct, acc = compute_accuracy(results, gt_by_name)
        print(f"📈 Accuracy: {acc:.3f} ({correct}/{predicted}) | total {total}")

        # Save metrics per run
        with open(run_out_dir / "metrics.csv", 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["method", "frames", "camera_key", "total", "predicted", "correct", "accuracy"])
            writer.writerow([method, n, cam_key or "auto", total, predicted, correct, f"{acc:.6f}"])

        summary_rows.append({
            "method": method,
            "frames": n,
            "camera_key": cam_key or "auto",
            "total": total,
            "predicted": predicted,
            "correct": correct,
            "accuracy": acc,
            "run_dir": str(run_out_dir)
        })

    # Write overall summary
    with open(eval_root / "summary.csv", 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["method", "frames", "camera_key", "total", "predicted", "correct", "accuracy", "run_dir"])
        for r in summary_rows:
            writer.writerow([r["method"], r["frames"], r["camera_key"], r["total"], r["predicted"], r["correct"], f"{r['accuracy']:.6f}", r["run_dir"]])

    elapsed = time.time() - start_all
    print(f"\n🎉 Evaluation complete in {elapsed/60:.1f} minutes")
    print(f"📁 Outputs in: {eval_root}")
    return 0
| 239 | + |
| 240 | + |
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
| 243 | + |
| 244 | + |
0 commit comments