Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions embodichain/lab/gym/envs/managers/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ class FunctorCfg:
in the :class:`SceneEntityCfg` object.
"""

extra: dict[str, Any] = dict()
"""Extra metadata about the functor. Defaults to an empty dict.

This can be used to store additional configuration information such as the output shape
of observation functors, which can be used for pre-allocating buffers.

For observation functors, common keys include:
- ``shape``: A tuple defining the output shape of the functor (excluding num_envs dimension).
"""


@configclass
class EventCfg(FunctorCfg):
Expand Down
37 changes: 37 additions & 0 deletions embodichain/lab/gym/utils/gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,12 @@ def init_rollout_buffer_from_config(
) -> TensorDict:
"""Initialize a rollout buffer based on the environment configuration.

The function creates a rollout buffer containing:
- Basic observations: ``robot/qpos``, ``robot/qvel``, ``robot/qf``
- Sensor observations: ``sensor/<uid>`` for each sensor in config
- Extra observations: Custom observations from observation functors in ``add`` mode
that have a ``shape`` specified in their ``extra`` parameter

Args:
config (dict): The environment configuration dictionary.
max_episode_steps (int): The number of steps in an episode.
Expand All @@ -975,6 +981,32 @@ def init_rollout_buffer_from_config(
TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'.
"""

# TODO: Currently we use this method to pre-allocate a rollout buffer with fixed size for simplicity.
# Parse extra observations from observation functors in 'add' mode
extra_obs_desc = {}
env_config = config.get("env", {})
if "observations" in env_config:
for obs_name, obs_params in env_config["observations"].items():
obs_mode = obs_params.get("mode", "modify")
if obs_mode == "add":
obs_extra = obs_params.get("extra", {})
shape = obs_extra.get("shape", None)
if shape is not None:
# Ensure shape is a tuple
if isinstance(shape, list):
shape = tuple(shape)

key = obs_params.get(
"name", obs_name
) # Use 'name' if provided, otherwise use obs_name

# Create buffer with shape (batch_size, max_episode_steps, *shape)
extra_obs_desc[key] = torch.zeros(
(batch_size, max_episode_steps, *shape),
dtype=torch.float32,
device=device,
)

# Parse sensor
sensor_desc = {}
for cfg in config.get("sensor", []):
Expand Down Expand Up @@ -1095,4 +1127,9 @@ def init_rollout_buffer_from_config(
sensor_desc, batch_size=[batch_size, max_episode_steps], device=device
)

# Add extra observations from functors in 'add' mode
if extra_obs_desc:
for obs_name, obs_tensor in extra_obs_desc.items():
assign_data_to_dict(rollout_buffer["obs"], obs_name, obs_tensor)

return rollout_buffer
Loading
Loading