Skip to content

Commit 83b235e

Browse files
author
Your Name
committed
Add support for analyzing multiple images in VLM processing
- Updated `.gitignore` to include `eval_runs/`.
- Introduced the `make_image_grid` function for creating a tiled grid image from a list of RGB images.
- Enhanced `process_single_trajectory` to support two methods for passing frames to the VLM: as a stream or as a concatenated grid.
- Modified `VLMService` to analyze multiple images together with a single prompt.
- Updated the command-line arguments to allow configuration of frame sampling and the passing method.
- Improved documentation and comments to clarify the new functionality.
1 parent 10b489a commit 83b235e

File tree

5 files changed

+424
-52
lines changed

5 files changed

+424
-52
lines changed

examples/droid_h5/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
results/
2-
output/
2+
output/
3+
eval_runs/
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Evaluate VLM configurations on DROID trajectories.
4+
5+
Features:
6+
- Download trajectories once, reuse across runs
7+
- Vary number of evenly sampled frames (e.g., 4, 8, 16, 32)
8+
- Vary passing method: 'stream' (per-frame) vs 'concat' (tiled grid)
9+
- Vary camera video path keys (e.g., 'ext1_mp4_path', 'wrist_mp4_path')
10+
- Save per-run outputs into distinct folders
11+
- Produce a summary CSV of accuracy per configuration
12+
13+
Usage examples:
14+
python evaluate_vlm_configs.py \
15+
--paths-file results/all_droid_trajectory_paths.txt \
16+
--num-trajectories 50 \
17+
--eval-root ./eval_runs \
18+
--frame-counts 4 8 16 32 \
19+
--passing-methods stream concat \
20+
--video-path-keys ext1_mp4_path wrist_mp4_path
21+
22+
# Or specify GCS trajectories directly
23+
python evaluate_vlm_configs.py \
24+
--trajectories gs://.../success/... gs://.../failure/... \
25+
--eval-root ./eval_runs
26+
"""
27+
28+
import argparse
29+
import csv
30+
import json
31+
import os
32+
import random
33+
import time
34+
from pathlib import Path
35+
from typing import Dict, List, Optional, Tuple
36+
37+
import numpy as np
38+
39+
# Local imports
40+
from simple_vlm_processing import process_trajectories_parallel
41+
from droid_pipeline import download_trajectories
42+
43+
44+
def load_paths(paths_file: str) -> List[str]:
    """Read trajectory paths from a text file, one path per line.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped. The file is decoded as UTF-8; a leading byte-order mark (as
    written by some editors) is dropped so it cannot end up glued to the
    first path.

    Args:
        paths_file: Path of the text file to read.

    Returns:
        The non-empty, stripped lines in file order, or an empty list when
        the file cannot be opened, read, or decoded (a message is printed).
    """
    try:
        # 'utf-8-sig' reads plain UTF-8 as well as BOM-prefixed UTF-8.
        with open(paths_file, 'r', encoding='utf-8-sig') as f:
            stripped_lines = [line.strip() for line in f]
    except (OSError, UnicodeError) as e:
        # Only I/O and decoding problems are treated as "no paths";
        # programming errors are allowed to surface.
        print(f"❌ Failed to load paths from {paths_file}: {e}")
        return []
    return [line for line in stripped_lines if line]
51+
52+
53+
def sample_paths(paths: List[str], k: Optional[int], balance: Optional[float], seed: Optional[int]) -> List[str]:
    """Pick up to ``k`` paths, optionally balancing success vs. failure.

    A path counts as a success when its lower-cased text contains
    ``'success'``; otherwise it counts as a failure when it contains
    ``'failure'``. Each path therefore belongs to at most one bucket and can
    never be selected twice.

    Args:
        paths: Candidate trajectory paths.
        k: Number of paths wanted. ``None``, a non-positive value, or a value
            of at least ``len(paths)`` returns a copy of all paths.
        balance: Target fraction of successes in the sample. ``None`` means
            plain uniform sampling. Values outside ``[0, 1]`` are clamped.
        seed: When not ``None``, seeds the module-level ``random`` generator
            before sampling.

    Returns:
        The selected paths. With ``balance`` set, successes come first, then
        failures, then any paths used to top the sample up to ``k``.
    """
    if seed is not None:
        random.seed(seed)
    if k is None or k <= 0 or k >= len(paths):
        return list(paths)
    if balance is None:
        return random.sample(paths, k)

    # Assign every path to at most one bucket ("success" wins), so a path
    # mentioning both words cannot be drawn from both buckets.
    success_paths: List[str] = []
    failure_paths: List[str] = []
    for p in paths:
        lowered = p.lower()
        if 'success' in lowered:
            success_paths.append(p)
        elif 'failure' in lowered:
            failure_paths.append(p)

    # Clamp the ratio so neither quota can go negative, which would make
    # random.sample raise ValueError.
    ratio = min(max(balance, 0.0), 1.0)
    k_success = int(round(k * ratio))
    k_failure = k - k_success

    chosen = random.sample(success_paths, min(k_success, len(success_paths)))
    chosen += random.sample(failure_paths, min(k_failure, len(failure_paths)))

    # If one bucket was too small, top up from whatever has not been taken.
    if len(chosen) < k:
        taken = set(chosen)
        remaining = [p for p in paths if p not in taken]
        chosen += random.sample(remaining, min(k - len(chosen), len(remaining)))
    return chosen
70+
71+
72+
def infer_label_from_gcs_path(gcs_path: str) -> Optional[bool]:
    """Derive a success/failure label from the text of a GCS path.

    The match is a case-insensitive substring test; 'success' is checked
    before 'failure'.

    Returns:
        True if the path mentions 'success', False if it mentions 'failure',
        and None when it mentions neither.
    """
    lowered = gcs_path.lower()
    # Order matters: the first marker found decides the label.
    for marker, label in (('success', True), ('failure', False)):
        if marker in lowered:
            return label
    return None
79+
80+
81+
def build_ground_truth_by_name(gcs_paths: List[str]) -> Dict[str, bool]:
    """Map each trajectory name to the label inferred from its GCS path.

    The trajectory name is the last '/'-separated component of the path
    (trailing slashes ignored). Paths for which no label can be inferred are
    left out. When two paths share a name, the later one wins.
    """
    named_labels = (
        (path.rstrip('/').split('/')[-1], infer_label_from_gcs_path(path))
        for path in gcs_paths
    )
    return {name: label for name, label in named_labels if label is not None}
89+
90+
91+
def compute_accuracy(results: Dict[str, Dict], gt_by_name: Dict[str, bool]) -> Tuple[int, int, int, float]:
    """Score VLM results against ground-truth labels.

    Results are keyed by local trajectory path; the path's basename (trailing
    slashes ignored) is looked up in ``gt_by_name``. Entries without a
    ground-truth label are ignored entirely.

    Returns:
        ``(total, predicted, correct, accuracy)`` where ``total`` counts
        labelled entries, ``predicted`` counts those whose result has a truthy
        ``'success'`` field, ``correct`` counts predictions whose
        ``bool(vlm_prediction)`` equals the label, and ``accuracy`` is
        ``correct / predicted`` (0.0 when nothing was predicted).
    """
    n_labelled = 0
    n_scored = 0
    n_hits = 0
    for path, outcome in results.items():
        name = os.path.basename(path.rstrip('/'))
        if name in gt_by_name:
            n_labelled += 1
            # Only runs that completed contribute a prediction.
            if outcome.get('success', False):
                n_scored += 1
                guess = bool(outcome.get('vlm_prediction', False))
                if guess == gt_by_name[name]:
                    n_hits += 1
    if n_scored > 0:
        accuracy = n_hits / n_scored
    else:
        accuracy = 0.0
    return n_labelled, n_scored, n_hits, accuracy
108+
109+
110+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate VLM configs on DROID trajectories")
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--paths-file", default="results/all_droid_trajectory_paths.txt",
                       help="File containing GCS trajectory paths")
    group.add_argument("--trajectories", nargs='+', help="GCS paths to DROID trajectory directories")

    parser.add_argument("--num-trajectories", type=int, help="Number of trajectories to sample")
    parser.add_argument("--balance", type=float, help="Success ratio target in sampling, e.g., 0.5")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--max-workers", type=int, default=4, help="Parallel workers for VLM")
    parser.add_argument("--eval-root", default="./eval_runs", help="Root folder for evaluation outputs")

    parser.add_argument("--frame-counts", type=int, nargs='+', default=[4, 8, 16, 32],
                        help="Frame counts to evaluate")
    parser.add_argument("--passing-methods", nargs='+', default=["stream", "concat"],
                        choices=["stream", "concat"], help="Passing methods to evaluate")
    parser.add_argument("--video-path-keys", nargs='*', default=None,
                        help="Video path keys from metadata (e.g., ext1_mp4_path wrist_mp4_path). If omitted, auto-detect.")

    parser.add_argument("--language-key", default="metadata/language_instruction",
                        help="Language key to extract from HDF5 fallback")
    parser.add_argument("--question", default="Is this trajectory successful?",
                        help="VLM question")
    return parser


def _expand_configs(passing_methods: List[str], frame_counts: List[int],
                    video_path_keys: Optional[List[str]]) -> List[Tuple[str, int, Optional[str]]]:
    """Return the unique (method, frame count, camera key) combinations in CLI order."""
    # No keys given (None or empty list) means a single "auto-detect" entry.
    cam_keys = video_path_keys or [None]
    combos = [(m, n, c) for m in passing_methods for n in frame_counts for c in cam_keys]
    # Repeated CLI values would re-run an identical configuration and
    # overwrite its run directory; keep the first occurrence only.
    return list(dict.fromkeys(combos))


def _write_csv(path: Path, header: List[str], rows: List[List]) -> None:
    """Write a header row followed by ``rows`` to ``path`` as CSV."""
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def main() -> int:
    """Run every requested VLM configuration and record its accuracy.

    Steps: parse arguments, resolve and sample the GCS trajectory paths,
    download them once under ``<eval-root>/droid_trajectories``, write
    ``ground_truth.csv``, then for each (passing method, frame count, camera
    key) combination call ``process_trajectories_parallel`` and write
    ``vlm_results.json`` and ``metrics.csv`` into
    ``<eval-root>/runs/<run name>``. ``summary.csv`` collects one row per
    completed run and is written even if a later run raises.

    Returns:
        0 on success, 1 when no paths are available or every download failed.
        Invalid ``--balance`` / ``--frame-counts`` values exit through
        ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject values that would otherwise fail later, after downloads started.
    if args.balance is not None and not 0.0 <= args.balance <= 1.0:
        parser.error("--balance must be between 0 and 1")
    if any(n <= 0 for n in args.frame_counts):
        parser.error("--frame-counts values must be positive integers")

    # Resolve GCS paths
    if args.trajectories:
        gcs_paths = list(args.trajectories)
    else:
        gcs_paths = load_paths(args.paths_file)
    if not gcs_paths:
        print("❌ No GCS trajectory paths provided or loaded")
        return 1

    # Sample
    gcs_paths = sample_paths(gcs_paths, args.num_trajectories, args.balance, args.seed)
    print(f"📊 Using {len(gcs_paths)} trajectories for evaluation")

    # Prepare eval root (creating runs/ also creates eval_root itself)
    eval_root = Path(args.eval_root)
    runs_root = eval_root / "runs"
    downloads_root = eval_root / "droid_trajectories"
    os.makedirs(runs_root, exist_ok=True)

    # Download once
    print("\n📥 Downloading trajectories once for reuse...")
    successful_local_paths, failed = download_trajectories(
        gcs_paths, str(downloads_root), max_workers=args.max_workers
    )
    if not successful_local_paths:
        print("❌ Download failed for all trajectories")
        return 1
    print(f"✅ Downloaded {len(successful_local_paths)} trajectories; {len(failed)} failed")

    # Ground truth by traj_name, persisted for later inspection
    gt_by_name = build_ground_truth_by_name(gcs_paths)
    _write_csv(
        eval_root / "ground_truth.csv",
        ["trajectory_name", "label_success"],
        [[name, int(label)] for name, label in sorted(gt_by_name.items())],
    )

    # Evaluate configurations
    metrics_header = ["method", "frames", "camera_key", "total", "predicted", "correct", "accuracy"]
    summary_rows: List[List] = []
    configs = _expand_configs(args.passing_methods, args.frame_counts, args.video_path_keys)

    start_all = time.time()
    try:
        for (method, n, cam_key) in configs:
            run_name = f"method={method}_frames={n}" + (f"_cam={cam_key}" if cam_key else "")
            run_out_dir = runs_root / run_name
            os.makedirs(run_out_dir, exist_ok=True)

            print(f"\n🚀 Run: {run_name}")
            results = process_trajectories_parallel(
                trajectory_paths=successful_local_paths,
                image_key="",  # not used for DROID directories when MP4s present
                language_key=args.language_key,
                question=args.question,
                max_workers=args.max_workers,
                output_dir=str(run_out_dir),
                video_path_key=cam_key,
                num_frames=n,
                passing_method=method,
                concat_grid_cols=None
            )

            # Persist raw results
            with open(run_out_dir / "vlm_results.json", 'w') as f:
                json.dump(results, f, indent=2)

            total, predicted, correct, acc = compute_accuracy(results, gt_by_name)
            print(f"📈 Accuracy: {acc:.3f} ({correct}/{predicted}) | total {total}")

            # Save metrics per run; the summary row is the same plus run_dir
            metrics_row = [method, n, cam_key or "auto", total, predicted, correct, f"{acc:.6f}"]
            _write_csv(run_out_dir / "metrics.csv", metrics_header, [metrics_row])
            summary_rows.append(metrics_row + [str(run_out_dir)])
    finally:
        # Write overall summary, keeping finished runs if a later one raised.
        _write_csv(eval_root / "summary.csv", metrics_header + ["run_dir"], summary_rows)

    elapsed = time.time() - start_all
    print(f"\n🎉 Evaluation complete in {elapsed/60:.1f} minutes")
    print(f"📁 Outputs in: {eval_root}")
    return 0
239+
240+
241+
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
243+
244+

0 commit comments

Comments
 (0)