@@ -91,6 +91,72 @@ def extract_frames_from_mp4(mp4_path: str, max_frames: int = 10) -> List[np.ndar
9191 return frames
9292
9393
94+ def create_state_visualization (data : Dict [str , Any ], max_frames : int = 10 ) -> List [np .ndarray ]:
95+ """
96+ Create visualization images from trajectory state data when no camera images are available.
97+
98+ Args:
99+ data: Trajectory data dictionary
100+ max_frames: Maximum number of visualization frames to create
101+
102+ Returns:
103+ List of visualization images as numpy arrays
104+ """
105+ try :
106+ # Find state-related keys (joint positions, gripper states, etc.)
107+ state_keys = [k for k in data .keys () if any (term in k .lower () for term in
108+ ['state' , 'joint' , 'position' , 'gripper' , 'action' , 'pose' ])]
109+
110+ if not state_keys :
111+ print (f" ⚠️ No state data found for visualization" )
112+ return []
113+
114+ # Use the first available state key
115+ state_key = state_keys [0 ]
116+ state_data = data [state_key ]
117+
118+ print (f" 📊 Creating state visualization from { state_key } " )
119+
120+ if len (state_data ) == 0 :
121+ return []
122+
123+ # Select frames to visualize
124+ num_frames = min (max_frames , len (state_data ))
125+ if len (state_data ) > num_frames :
126+ indices = np .linspace (0 , len (state_data ) - 1 , num_frames , dtype = int )
127+ else :
128+ indices = list (range (len (state_data )))
129+
130+ # Create simple plot-based visualizations
131+ visualizations = []
132+ for i , idx in enumerate (indices ):
133+ fig , ax = plt .subplots (figsize = (8 , 6 ))
134+
135+ state_vec = state_data [idx ] if hasattr (state_data [idx ], '__len__' ) else [state_data [idx ]]
136+
137+ # Create a simple bar plot of the state values
138+ ax .bar (range (len (state_vec )), state_vec )
139+ ax .set_title (f'State at timestep { idx } ({ i + 1 } /{ num_frames } )' )
140+ ax .set_xlabel ('State dimension' )
141+ ax .set_ylabel ('Value' )
142+ ax .grid (True )
143+
144+ # Convert plot to image
145+ fig .canvas .draw ()
146+ buf = np .frombuffer (fig .canvas .tostring_rgb (), dtype = np .uint8 )
147+ buf = buf .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
148+
149+ visualizations .append (buf .copy ())
150+ plt .close (fig )
151+
152+ print (f" ✅ Created { len (visualizations )} state visualizations" )
153+ return visualizations
154+
155+ except Exception as e :
156+ print (f" ❌ Failed to create state visualization: { e } " )
157+ return []
158+
159+
94160def find_video_files_in_trajectory (trajectory_dir : str , video_path_key : str = None ) -> List [str ]:
95161 """
96162 Find MP4 video files in a DROID trajectory directory.
@@ -128,13 +194,28 @@ def find_video_files_in_trajectory(trajectory_dir: str, video_path_key: str = No
128194
129195 if not video_files :
130196 # Fallback to original logic - find all MP4 files
131- mp4_pattern = os .path .join (trajectory_dir , "recordings" , "MP4" , "*.mp4" )
132- video_files = glob .glob (mp4_pattern )
133-
134- # Filter out stereo files (we want the mono camera feeds)
135- video_files = [f for f in video_files if '-stereo.mp4' not in f ]
197+ # Try multiple potential directories
198+ potential_dirs = [
199+ os .path .join (trajectory_dir , "recordings" , "MP4" ),
200+ os .path .join (trajectory_dir , "recordings" ),
201+ trajectory_dir
202+ ]
203+
204+ for search_dir in potential_dirs :
205+ if os .path .exists (search_dir ):
206+ mp4_pattern = os .path .join (search_dir , "*.mp4" )
207+ found_files = glob .glob (mp4_pattern )
208+
209+ # Filter out stereo files (we want the mono camera feeds)
210+ found_files = [f for f in found_files if '-stereo.mp4' not in f ]
211+
212+ if found_files :
213+ video_files = found_files
214+ print (f" 📁 Found { len (video_files )} video files in { search_dir } : { [os .path .basename (f ) for f in video_files ]} " )
215+ break
136216
137- print (f" 📁 Found { len (video_files )} video files: { [os .path .basename (f ) for f in video_files ]} " )
217+ if not video_files :
218+ print (f" ⚠️ No video files found in any potential directory" )
138219
139220 return video_files
140221
@@ -192,11 +273,32 @@ def process_single_trajectory(
192273 images = extract_frames_from_mp4 (primary_video , max_frames = 10 )
193274
194275 if not images :
195- print (f" ⚠️ Failed to extract frames from video, falling back to state visualization " )
276+ print (f" ⚠️ Failed to extract frames from video, trying HDF5 fallback " )
196277 use_state_visualization = True
197278 else :
198- print (f" ⚠️ No video files found in DROID directory" )
279+ print (f" ⚠️ No video files found in DROID directory, trying HDF5 fallback " )
199280 use_state_visualization = True
281+
282+ # Try to load images from HDF5 as fallback
283+ hdf5_file = os .path .join (trajectory_path , "trajectory.h5" )
284+ if os .path .exists (hdf5_file ):
285+ try :
286+ print (f" 📂 Attempting to load images from HDF5 fallback" )
287+ traj = Trajectory (hdf5_file , mode = "r" )
288+ data = traj .load ()
289+ traj .close ()
290+
291+ # Look for any image keys
292+ image_keys = [k for k in data .keys () if 'image' in k .lower ()]
293+ if image_keys :
294+ fallback_key = image_keys [0 ]
295+ images = data [fallback_key ]
296+ use_state_visualization = False
297+ print (f" 📷 Found fallback images: { fallback_key } with { len (images )} frames" )
298+
299+ except Exception as hdf5_e :
300+ print (f" ⚠️ HDF5 fallback also failed: { hdf5_e } " )
301+ # Keep use_state_visualization = True
200302
201303 # Try to extract language instruction from HDF5 file
202304 hdf5_file = os .path .join (trajectory_path , "trajectory.h5" )
@@ -227,6 +329,7 @@ def process_single_trajectory(
227329
228330 except Exception as e :
229331 print (f" ⚠️ Could not load language instruction from HDF5: { e } " )
332+ # Continue without language instruction rather than failing completely
230333
231334 else :
232335 # Traditional trajectory file format
0 commit comments