Skip to content

Commit 10b489a

Browse files
author
Your Name
committed
d
1 parent 3c9890e commit 10b489a

File tree

2 files changed

+112
-9
lines changed

2 files changed

+112
-9
lines changed

examples/droid_h5/droid_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def main():
713713
parser.add_argument(
714714
"--num-trajectories",
715715
type=int,
716-
default=30,
716+
default=100,
717717
help="Number of trajectories to randomly select (default: 30)"
718718
)
719719
parser.add_argument(

examples/droid_h5/simple_vlm_processing.py

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,72 @@ def extract_frames_from_mp4(mp4_path: str, max_frames: int = 10) -> List[np.ndar
9191
return frames
9292

9393

94+
def create_state_visualization(data: Dict[str, Any], max_frames: int = 10) -> List[np.ndarray]:
95+
"""
96+
Create visualization images from trajectory state data when no camera images are available.
97+
98+
Args:
99+
data: Trajectory data dictionary
100+
max_frames: Maximum number of visualization frames to create
101+
102+
Returns:
103+
List of visualization images as numpy arrays
104+
"""
105+
try:
106+
# Find state-related keys (joint positions, gripper states, etc.)
107+
state_keys = [k for k in data.keys() if any(term in k.lower() for term in
108+
['state', 'joint', 'position', 'gripper', 'action', 'pose'])]
109+
110+
if not state_keys:
111+
print(f" ⚠️ No state data found for visualization")
112+
return []
113+
114+
# Use the first available state key
115+
state_key = state_keys[0]
116+
state_data = data[state_key]
117+
118+
print(f" 📊 Creating state visualization from {state_key}")
119+
120+
if len(state_data) == 0:
121+
return []
122+
123+
# Select frames to visualize
124+
num_frames = min(max_frames, len(state_data))
125+
if len(state_data) > num_frames:
126+
indices = np.linspace(0, len(state_data) - 1, num_frames, dtype=int)
127+
else:
128+
indices = list(range(len(state_data)))
129+
130+
# Create simple plot-based visualizations
131+
visualizations = []
132+
for i, idx in enumerate(indices):
133+
fig, ax = plt.subplots(figsize=(8, 6))
134+
135+
state_vec = state_data[idx] if hasattr(state_data[idx], '__len__') else [state_data[idx]]
136+
137+
# Create a simple bar plot of the state values
138+
ax.bar(range(len(state_vec)), state_vec)
139+
ax.set_title(f'State at timestep {idx} ({i+1}/{num_frames})')
140+
ax.set_xlabel('State dimension')
141+
ax.set_ylabel('Value')
142+
ax.grid(True)
143+
144+
# Convert plot to image
145+
fig.canvas.draw()
146+
buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
147+
buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,))
148+
149+
visualizations.append(buf.copy())
150+
plt.close(fig)
151+
152+
print(f" ✅ Created {len(visualizations)} state visualizations")
153+
return visualizations
154+
155+
except Exception as e:
156+
print(f" ❌ Failed to create state visualization: {e}")
157+
return []
158+
159+
94160
def find_video_files_in_trajectory(trajectory_dir: str, video_path_key: str = None) -> List[str]:
95161
"""
96162
Find MP4 video files in a DROID trajectory directory.
@@ -128,13 +194,28 @@ def find_video_files_in_trajectory(trajectory_dir: str, video_path_key: str = No
128194

129195
if not video_files:
130196
# Fallback to original logic - find all MP4 files
131-
mp4_pattern = os.path.join(trajectory_dir, "recordings", "MP4", "*.mp4")
132-
video_files = glob.glob(mp4_pattern)
133-
134-
# Filter out stereo files (we want the mono camera feeds)
135-
video_files = [f for f in video_files if '-stereo.mp4' not in f]
197+
# Try multiple potential directories
198+
potential_dirs = [
199+
os.path.join(trajectory_dir, "recordings", "MP4"),
200+
os.path.join(trajectory_dir, "recordings"),
201+
trajectory_dir
202+
]
203+
204+
for search_dir in potential_dirs:
205+
if os.path.exists(search_dir):
206+
mp4_pattern = os.path.join(search_dir, "*.mp4")
207+
found_files = glob.glob(mp4_pattern)
208+
209+
# Filter out stereo files (we want the mono camera feeds)
210+
found_files = [f for f in found_files if '-stereo.mp4' not in f]
211+
212+
if found_files:
213+
video_files = found_files
214+
print(f" 📁 Found {len(video_files)} video files in {search_dir}: {[os.path.basename(f) for f in video_files]}")
215+
break
136216

137-
print(f" 📁 Found {len(video_files)} video files: {[os.path.basename(f) for f in video_files]}")
217+
if not video_files:
218+
print(f" ⚠️ No video files found in any potential directory")
138219

139220
return video_files
140221

@@ -192,11 +273,32 @@ def process_single_trajectory(
192273
images = extract_frames_from_mp4(primary_video, max_frames=10)
193274

194275
if not images:
195-
print(f" ⚠️ Failed to extract frames from video, falling back to state visualization")
276+
print(f" ⚠️ Failed to extract frames from video, trying HDF5 fallback")
196277
use_state_visualization = True
197278
else:
198-
print(f" ⚠️ No video files found in DROID directory")
279+
print(f" ⚠️ No video files found in DROID directory, trying HDF5 fallback")
199280
use_state_visualization = True
281+
282+
# Try to load images from HDF5 as fallback
283+
hdf5_file = os.path.join(trajectory_path, "trajectory.h5")
284+
if os.path.exists(hdf5_file):
285+
try:
286+
print(f" 📂 Attempting to load images from HDF5 fallback")
287+
traj = Trajectory(hdf5_file, mode="r")
288+
data = traj.load()
289+
traj.close()
290+
291+
# Look for any image keys
292+
image_keys = [k for k in data.keys() if 'image' in k.lower()]
293+
if image_keys:
294+
fallback_key = image_keys[0]
295+
images = data[fallback_key]
296+
use_state_visualization = False
297+
print(f" 📷 Found fallback images: {fallback_key} with {len(images)} frames")
298+
299+
except Exception as hdf5_e:
300+
print(f" ⚠️ HDF5 fallback also failed: {hdf5_e}")
301+
# Keep use_state_visualization = True
200302

201303
# Try to extract language instruction from HDF5 file
202304
hdf5_file = os.path.join(trajectory_path, "trajectory.h5")
@@ -227,6 +329,7 @@ def process_single_trajectory(
227329

228330
except Exception as e:
229331
print(f" ⚠️ Could not load language instruction from HDF5: {e}")
332+
# Continue without language instruction rather than failing completely
230333

231334
else:
232335
# Traditional trajectory file format

0 commit comments

Comments
 (0)