Skip to content

Commit 0e38eb9

Browse files
committed
feat(wrapper)!: added hdf5 recorder wrapper
1 parent 6c92668 commit 0e38eb9

1 file changed

Lines changed: 204 additions & 2 deletions

File tree

python/rcsss/envs/wrappers.py

Lines changed: 204 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
from datetime import datetime
32
import os
43
from pathlib import Path
@@ -13,8 +12,11 @@
1312
from PIL import Image
1413
from rcsss.camera.hw import BaseHardwareCameraSet
1514

15+
import subprocess
16+
import h5py
1617

17-
class StorageWrapper(gym.Wrapper):
18+
19+
class StorageWrapperNumpy(gym.Wrapper):
1820
# TODO: this should also record the instruction
1921
FILE = "episode_{}.npz"
2022
GIF = "{}_episode_{}_{}.gif"
@@ -135,6 +137,206 @@ def log_files(self, file2content: dict[str, str]):
135137
f.write(content)
136138

137139

140+
# TODO: GIFs should not be created after each episode; there should rather be a tool
141+
# to create them from a dataset, how about video?
142+
class StorageWrapperHDF5(gym.Wrapper):
    """Gym wrapper that records every environment step into a single HDF5 file.

    The file ``<path>/data.h5`` is opened in append mode. Inside it, one group
    is kept per language instruction and, below that, one group per episode
    (named after the wall-clock time at which ``reset`` was called). Each step
    appends one row to the datasets of the current episode group:

    * ``action``       -- the action passed to ``step``
    * ``observation``  -- the observation that was current *before* the step
    * ``timestamp``    -- ``datetime.now().timestamp()`` taken after the step

    Nested dicts become nested HDF5 groups; leaves become resizable datasets
    whose first axis is the step index.

    When ``gif`` is enabled, one GIF per known camera is written to
    ``<path>/gifs`` every time an episode is flushed.
    """

    FILE = "data.h5"
    GIF = "{}_{}.gif"
    FOLDER = "experiment_{}"
    GIF_DURATION_S = 0.5
    # Camera names looked up under ``observation/frames/<name>/rgb`` when rendering GIFs.
    GIF_CAMERA_KEYS = ("side", "right_side", "bird_eye", "left_side", "front")

    def __init__(
        self,
        env: gym.Env,
        path: str,
        instruction: str | None = None,
        description: str | None = None,
        gif: bool = True,
        camera_set: BaseHardwareCameraSet | None = None,
    ):
        """Open (or create) the HDF5 file and the group of the given instruction.

        Args:
            env: Environment to wrap.
            path: Directory that receives ``data.h5`` and the ``gifs`` folder;
                created if missing.
            instruction: Language instruction used as the top-level group
                name. Prompted for on stdin when ``None``.
            description: Free-text experiment description stored on every
                episode group. Prompted for on stdin when ``None``.
            gif: Whether to render GIFs when an episode is flushed.
            camera_set: Optional camera set whose ongoing video recording is
                stopped when an episode is flushed.
        """
        super().__init__(env)
        self.episode_count = 0
        self.step_count = 0
        self.timestamp = str(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        self.gif = gif
        self.prev_obs: dict | None = None
        # Cache of dataset handles of the current episode, keyed by full HDF5 path.
        self.datasets = {}
        self.camera_set = camera_set
        # Set by ``reset``; ``None`` means no episode has been started yet.
        self.episode_group = None

        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)

        if description is None:
            description = input("Please enter a description for this experiment: ")
        self.description = description

        if instruction is None:
            instruction = input("Instruction: ")
        self.language_instruction = str(instruction)

        # Append mode keeps episodes recorded by earlier runs.
        self.h5file = h5py.File(self.path / self.FILE, "a")
        self.instruction_group = self.h5file.require_group(self.language_instruction)

        self.gif_path = self.path / "gifs"
        if self.gif:
            self.gif_path.mkdir(parents=True, exist_ok=True)

    def append_to_hdf5(self, group, data_dict, index):
        """Write ``data_dict`` as row ``index`` below ``group``.

        Args:
            group: HDF5 group that receives the data.
            data_dict: Mapping of names to values. Dict values recurse into a
                subgroup of the same name; every other value is stored in a
                dataset of the same name.
            index: Row (first-axis position) to write to. Datasets are grown
                as needed.

        Raises:
            ValueError: If a value is neither a string, a scalar nor
                convertible to a non-object numpy array.
        """
        for key, value in data_dict.items():
            if isinstance(value, dict):
                self.append_to_hdf5(group.require_group(key), value, index)
                continue

            full_dataset_path = group.name + "/" + key
            dataset = self.datasets.get(full_dataset_path)
            if dataset is None:
                # First time this leaf is seen in the current episode:
                # derive dtype and per-row shape from the value itself.
                if isinstance(value, str):
                    dtype = h5py.string_dtype(encoding="utf-8")
                    shape = ()
                elif np.isscalar(value):
                    dtype = type(value)
                    shape = ()
                elif isinstance(value, np.ndarray):
                    dtype = value.dtype
                    shape = value.shape
                else:
                    try:
                        value = np.array(value)
                    except Exception as e:
                        raise ValueError(f"Unsupported data type for key '{key}': {type(value)}") from e
                    if value.dtype.kind == "O":
                        # e.g. ``None``: numpy falls back to an object array,
                        # which cannot be stored as a plain dataset.
                        raise ValueError(f"Unsupported data type for key '{key}': {type(value)}")
                    dtype = value.dtype
                    shape = value.shape
                # Unlimited first axis so the dataset can grow by one row per step.
                dataset = group.create_dataset(
                    key,
                    shape=(index + 1,) + shape,
                    maxshape=(None,) + shape,
                    chunks=True,
                    dtype=dtype,
                )
                self.datasets[full_dataset_path] = dataset
            elif dataset.shape[0] <= index:
                dataset.resize(index + 1, axis=0)

            # ``np.isscalar`` is also true for ``str``.
            if np.isscalar(value):
                dataset[index] = value
            else:
                dataset[index, ...] = value

    def _write_gifs(self):
        """Render one GIF per known camera from the current episode's frames.

        Frames whose timestamp is less than ``GIF_DURATION_S`` after the
        previously kept frame are skipped. The episode counter is part of the
        file name so that GIFs of different episodes do not overwrite each
        other.
        """
        timestamps = self.episode_group["timestamp"][...]
        for key in self.GIF_CAMERA_KEYS:
            img_dataset_path = f"observation/frames/{key}/rgb"
            if img_dataset_path not in self.episode_group:
                continue
            frames = self.episode_group[img_dataset_path]
            usable = min(len(frames), len(timestamps))

            imgs = []
            previous_timestamp = 0
            for idx, timestamp in enumerate(timestamps[:usable]):
                if timestamp - previous_timestamp < self.GIF_DURATION_S:
                    continue
                previous_timestamp = timestamp
                # Only frames that are kept are read from the file.
                imgs.append(Image.fromarray(frames[idx]))

            if imgs:
                gif_name = self.GIF.format(self.timestamp, f"episode_{self.episode_count}_{key}")
                imgs[0].save(
                    self.gif_path / gif_name,
                    save_all=True,
                    append_images=imgs[1:],
                    duration=self.GIF_DURATION_S * 1000,
                    loop=0,
                )

    def flush(self):
        """Writes data to disk and generates GIFs if enabled.

        Does nothing when no step has been recorded since the last ``reset``.
        """
        if self.step_count == 0 or self.episode_group is None:
            return
        self.h5file.flush()
        if self.camera_set is not None and self.camera_set.recording_ongoing():
            self.camera_set.stop_video()
        if self.gif:
            self._write_gifs()
        # Dataset handles belong to the finished episode; start fresh next time.
        self.datasets = {}
        self.episode_count += 1

    def step(self, action: dict) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and record (action, previous observation, time).

        Raises:
            RuntimeError: If called before ``reset`` has created an episode.
        """
        if self.episode_group is None:
            msg = "reset() must be called before step() so that an episode group exists."
            raise RuntimeError(msg)
        obs, reward, terminated, truncated, info = super().step(action)
        # The stored observation is delayed by one step: it is the one that
        # was current when ``action`` was issued, not the one it produced.
        act_obs = {"action": action, "observation": self.prev_obs, "timestamp": datetime.now().timestamp()}
        self.prev_obs = obs
        self.append_to_hdf5(self.episode_group, act_obs, self.step_count)
        self.step_count += 1
        return obs, reward, terminated, truncated, info

    @staticmethod
    def _git_output(*args: str) -> str:
        """Return the output of ``git <args>``, or ``""`` if git cannot be run.

        Recording must not fail just because the code does not live in a git
        checkout or git is not installed, so this is best effort. git's own
        stderr is left untouched so the reason stays visible on the console.
        """
        try:
            raw = subprocess.check_output(["git", *args])
        except (subprocess.CalledProcessError, OSError):
            return ""
        return raw.decode("utf-8", errors="replace")

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
        """Finish the running episode, start a new episode group and reset the env.

        The new group is named after the current time (microsecond
        resolution) and carries the git state, the experiment description and
        the language instruction as attributes.
        """
        self.flush()
        self.step_count = 0
        self.prev_obs = None

        episode_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
        self.episode_group = self.instruction_group.create_group(episode_name)
        self.datasets = {}

        attrs = self.episode_group.attrs
        attrs["git_diff"] = self._git_output("diff", "--submodule=diff")
        attrs["git_commit_id"] = self._git_output("log", "--format=%H", "-n", "1").strip()
        attrs["git_submodule_status"] = self._git_output("submodule", "status")
        attrs["description"] = self.description
        attrs["language_instruction"] = self.language_instruction

        result = super().reset(seed=seed, options=options)
        # First recorded step pairs its action with this initial observation.
        self.prev_obs = result[0]
        return result

    def close(self):
        """Flush the running episode, close the HDF5 file and the wrapped env."""
        try:
            self.flush()
        finally:
            # Release the file handle even if flushing (e.g. GIF rendering)
            # failed, and make sure the same episode is not flushed twice.
            self.step_count = 0
            self.h5file.close()
        return super().close()

    @property
    def logger_dir(self):
        """Directory that holds the HDF5 file and auxiliary log files."""
        return self.path

    def log_files(self, file2content: dict[str, str]):
        """Write each ``content`` to the file ``<path>/<file name>``.

        Args:
            file2content: Mapping of file name (relative to ``path``) to the
                text to store in it. Existing files are overwritten.
        """
        for fn, content in file2content.items():
            with open(self.path / fn, "w") as f:
                f.write(content)
338+
339+
138340
def listdict2dictlist(LD):
    """Turn a list of dicts into a dict of lists.

    The keys are taken from the first dict; every dict in ``LD`` must contain
    all of them. An empty list yields an empty dict.

    Args:
        LD: Sequence of dicts sharing the keys of its first element.

    Returns:
        Dict mapping each key of ``LD[0]`` to the list of that key's values,
        in the order of ``LD``.

    Raises:
        KeyError: If a later dict lacks one of the first dict's keys.
    """
    if not LD:
        # Nothing to take the keys from.
        return {}
    return {k: [dic[k] for dic in LD] for k in LD[0]}
140342

0 commit comments

Comments
 (0)