Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions embodichain/agents/rl/algo/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ def _compute_step_group_advantages(
return advantages.view(n_envs, t_steps) * seq_mask

def update(self, rollout: TensorDict) -> Dict[str, float]:
raw_obs = getattr(rollout, "raw_obs", None)
chunk_step = getattr(rollout, "chunk_step", None)
rollout = rollout.clone()
if raw_obs is not None:
rollout.raw_obs = raw_obs
if chunk_step is not None:
rollout.chunk_step = chunk_step
num_envs = rollout.batch_size[0]
if num_envs % self.cfg.group_size != 0:
raise ValueError(
Expand Down Expand Up @@ -147,7 +153,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]:
advantages = batch["advantage"].detach()
seq_mask_batch = batch["seq_mask"].float()

eval_batch = self.policy.evaluate_actions(batch)
eval_batch = self.policy.evaluate_actions(
batch, rollout=rollout, num_envs=num_envs
)
logprobs = eval_batch["sample_log_prob"]
entropy = eval_batch["entropy"]
ratio = (logprobs - old_logprobs).exp()
Expand All @@ -166,7 +174,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]:

if self.ref_policy is not None:
with torch.no_grad():
ref_batch = self.ref_policy.evaluate_actions(batch)
ref_batch = self.ref_policy.evaluate_actions(
batch, rollout=rollout, num_envs=num_envs
)
ref_logprobs = ref_batch["sample_log_prob"]
log_ref_over_pi = ref_logprobs - logprobs
kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0
Expand Down
11 changes: 10 additions & 1 deletion embodichain/agents/rl/buffer/standard_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def __init__(
obs_dim: int,
action_dim: int,
device: torch.device,
use_raw_obs: bool = False,
) -> None:
self.num_envs = num_envs
self.rollout_len = rollout_len
self.obs_dim = obs_dim
self.action_dim = action_dim
self.device = device
self.use_raw_obs = use_raw_obs
self._rollout = self._allocate_rollout()
self._is_full = False

Expand All @@ -54,6 +56,8 @@ def start_rollout(self) -> TensorDict:
if self._is_full:
raise RuntimeError("RolloutBuffer already contains a rollout.")
self._clear_dynamic_fields()
if self.use_raw_obs:
self._rollout.raw_obs = [None] * (self.rollout_len + 1)
return self._rollout

def add(self, rollout: TensorDict) -> None:
Expand Down Expand Up @@ -93,7 +97,7 @@ def is_full(self) -> bool:

def _allocate_rollout(self) -> TensorDict:
"""Preallocate rollout storage with uniform `[num_envs, time + 1]` shape."""
return TensorDict(
td = TensorDict(
{
"obs": torch.empty(
self.num_envs,
Expand Down Expand Up @@ -149,12 +153,17 @@ def _allocate_rollout(self) -> TensorDict:
batch_size=[self.num_envs, self.rollout_len + 1],
device=self.device,
)
return td

def _clear_dynamic_fields(self) -> None:
    """Remove algorithm-added fields so the shared rollout can be reused."""
    rollout = self._rollout
    # Tensors written into the TensorDict by the algorithm (advantage
    # estimation, return computation, etc.) — drop any that are present.
    stale_keys = ("advantage", "return", "seq_mask", "seq_return", "entropy")
    for stale_key in stale_keys:
        if stale_key in rollout.keys():
            del rollout[stale_key]
    # Non-tensor payloads attached as plain attributes on the TensorDict
    # instance; raw_obs only exists when raw-observation capture is enabled.
    stale_attrs = ["chunk_step"]
    if self.use_raw_obs:
        stale_attrs.insert(0, "raw_obs")
    for stale_attr in stale_attrs:
        if hasattr(rollout, stale_attr):
            delattr(rollout, stale_attr)
    self._reset_padding_slot()

def _reset_padding_slot(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions embodichain/agents/rl/buffer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict:
if key in rollout.keys():
td[key] = rollout[key][:, :-1]

if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None:
td["chunk_step"] = rollout.chunk_step

if flatten:
return td.reshape(num_envs * time_dim)
return td
Expand All @@ -72,6 +75,9 @@ def iterate_minibatches(
) -> Iterator[TensorDict]:
"""Yield shuffled minibatches from a flattened rollout."""
total = rollout.batch_size[0]
indices = torch.randperm(total, device=device)
indices = torch.randperm(total)
for start in range(0, total, batch_size):
yield rollout[indices[start : start + batch_size]]
batch_indices = indices[start : start + batch_size]
batch = rollout[batch_indices].clone()
batch["_indices"] = batch_indices
yield batch
131 changes: 120 additions & 11 deletions embodichain/agents/rl/collector/sync_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,119 @@ def collect(
if self._supports_shared_rollout:
self.env.set_rollout_buffer(rollout)

initial_obs = flatten_dict_observation(self.obs_td)
rollout["obs"][:, 0] = initial_obs
for step_idx in range(num_steps):
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None

if use_raw_obs:
if raw_obs_list is None:
raise ValueError(
"Policy requires raw observations, "
"but the provided rollout TensorDict has no 'raw_obs' buffer. "
"Create the rollout via RolloutBuffer or "
"start_rollout so that 'raw_obs' is allocated."
)
try:
raw_obs_len = len(raw_obs_list)
except TypeError:
raise ValueError(
"Rollout field 'raw_obs' must be an indexable sequence of length "
f"{num_steps + 1} when policy.use_raw_obs=True."
)
expected_len = num_steps + 1
if raw_obs_len != expected_len:
raise ValueError(
"Rollout 'raw_obs' length mismatch: "
f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. "
"Ensure the rollout was created with use_raw_obs=True and "
"its time dimension matches the requested num_steps."
)

action_chunk_size = getattr(self.policy, "action_chunk_size", 0)
use_action_chunk = (
getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0
)
cached_chunk = None

if use_action_chunk:
rollout.chunk_step = torch.zeros(
self.env.num_envs,
num_steps,
dtype=torch.long,
device=self.device,
)
step_td = self.policy.get_action(step_td)

if use_raw_obs and raw_obs_list is not None:
raw_obs_list[0] = self.obs_td
rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)
else:
rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)

for step_idx in range(num_steps):
step_in_chunk = step_idx % action_chunk_size if use_action_chunk else 0

# At chunk boundary, or cached invalidated by env reset, we need a new chunk
need_new_chunk = use_action_chunk and (
step_in_chunk == 0 or cached_chunk is None
)

if need_new_chunk:
if use_raw_obs and raw_obs_list is not None:
step_td = TensorDict(
{"obs": raw_obs_list[step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
step_td = self.policy.get_action(step_td)
cached_chunk = step_td["action_chunk"]
action = step_td["action"]
effective_step_in_chunk = 0
elif use_action_chunk and cached_chunk is not None:
action = cached_chunk[:, step_in_chunk]
effective_step_in_chunk = step_in_chunk
step_td = TensorDict(
{
"action": action,
"sample_log_prob": torch.zeros(
action.shape[0], device=self.device, dtype=torch.float32
),
"value": torch.zeros(
action.shape[0], device=self.device, dtype=torch.float32
),
},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
if use_raw_obs and raw_obs_list is not None:
step_td = TensorDict(
{"obs": raw_obs_list[step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
else:
step_td = TensorDict(
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
device=self.device,
)
step_td = self.policy.get_action(step_td)
action = step_td["action"]

next_obs, reward, terminated, truncated, env_info = self.env.step(
self._to_action_dict(step_td["action"])
self._to_action_dict(action)
)
next_obs_td = dict_to_tensordict(next_obs, self.device)
if use_action_chunk:
rollout.chunk_step[:, step_idx] = effective_step_in_chunk
# Invalidate cached_chunk on any env reset to avoid using old chunk for new episode
if (terminated | truncated).any():
cached_chunk = None
self._write_step(
rollout=rollout,
step_idx=step_idx,
Expand All @@ -95,7 +194,11 @@ def collect(
terminated=terminated,
truncated=truncated,
)
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)
if use_raw_obs and raw_obs_list is not None:
raw_obs_list[step_idx + 1] = next_obs_td
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)
else:
rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)

if on_step_callback is not None:
on_step_callback(rollout[:, step_idx], env_info)
Expand All @@ -107,7 +210,12 @@ def collect(

def _attach_final_value(self, rollout: TensorDict) -> None:
"""Populate the bootstrap value for the final observed state."""
final_obs = rollout["obs"][:, -1]
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
if use_raw_obs and raw_obs_list is not None:
final_obs = raw_obs_list[-1]
else:
final_obs = rollout["obs"][:, -1]
last_next_td = TensorDict(
{"obs": final_obs},
batch_size=[rollout.batch_size[0]],
Expand Down Expand Up @@ -155,8 +263,9 @@ def _write_env_step(

def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None:
"""Validate rollout layout expected by the collector."""
obs_dim = rollout["obs"].shape[-1]
expected_shapes = {
"obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim),
"obs": (self.env.num_envs, num_steps + 1, obs_dim),
"action": (self.env.num_envs, num_steps + 1, self.policy.action_dim),
"sample_log_prob": (self.env.num_envs, num_steps + 1),
"value": (self.env.num_envs, num_steps + 1),
Expand Down
21 changes: 19 additions & 2 deletions embodichain/agents/rl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import inspect
from typing import Dict, Type
from typing import Any, Dict, Optional, Type

from gymnasium import spaces
import torch
Expand All @@ -26,6 +26,7 @@
from .actor_only import ActorOnly
from .policy import Policy
from .mlp import MLP
from .vla_policy import VLAPolicy

# In-module policy registry
_POLICY_REGISTRY: Dict[str, Type[Policy]] = {}
Expand Down Expand Up @@ -63,13 +64,16 @@ def build_policy(
device: torch.device,
actor: torch.nn.Module | None = None,
critic: torch.nn.Module | None = None,
env: Optional[Any] = None,
) -> Policy:
"""Build a policy from config using spaces for extensibility.

Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while
custom policies may accept richer `obs_space` / `action_space` inputs.
For vla_policy, pass env to enable set_env and _load_vla initialization.
"""
name = policy_block["name"].lower()

if name not in _POLICY_REGISTRY:
available = ", ".join(get_registered_policy_names())
raise ValueError(
Expand Down Expand Up @@ -119,7 +123,18 @@ def build_policy(
build_kwargs["actor"] = actor
if "critic" in init_params and critic is not None:
build_kwargs["critic"] = critic
return policy_cls(**build_kwargs)
if "policy_cfg" in init_params:
build_kwargs["policy_cfg"] = policy_block
policy = policy_cls(**build_kwargs)
if name == "vla_policy":
if env is None:
raise ValueError(
"VLAPolicy requires an 'env' argument to be passed to build_policy "
"so that set_env and _load_vla can be called before use."
)
policy.set_env(env)
policy._load_vla()
return policy


def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
Expand All @@ -143,10 +158,12 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
# default registrations
register_policy("actor_critic", ActorCritic)
register_policy("actor_only", ActorOnly)
register_policy("vla_policy", VLAPolicy)

__all__ = [
"ActorCritic",
"ActorOnly",
"VLAPolicy",
"register_policy",
"get_registered_policy_names",
"build_policy",
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/rl/models/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict:
tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1)
return tensordict

def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict:
obs = tensordict["obs"]
action = tensordict["action"]
dist = self._distribution(obs)
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/rl/models/actor_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict:
)
return tensordict

def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict:
obs = tensordict["obs"]
action = tensordict["action"]
dist = self._distribution(obs)
Expand Down
Loading
Loading