diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 12f7c32f..8f96aa42 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -112,7 +112,13 @@ def _compute_step_group_advantages( return advantages.view(n_envs, t_steps) * seq_mask def update(self, rollout: TensorDict) -> Dict[str, float]: + raw_obs = getattr(rollout, "raw_obs", None) + chunk_step = getattr(rollout, "chunk_step", None) rollout = rollout.clone() + if raw_obs is not None: + rollout.raw_obs = raw_obs + if chunk_step is not None: + rollout.chunk_step = chunk_step num_envs = rollout.batch_size[0] if num_envs % self.cfg.group_size != 0: raise ValueError( @@ -147,7 +153,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: advantages = batch["advantage"].detach() seq_mask_batch = batch["seq_mask"].float() - eval_batch = self.policy.evaluate_actions(batch) + eval_batch = self.policy.evaluate_actions( + batch, rollout=rollout, num_envs=num_envs + ) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() @@ -166,7 +174,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: if self.ref_policy is not None: with torch.no_grad(): - ref_batch = self.ref_policy.evaluate_actions(batch) + ref_batch = self.ref_policy.evaluate_actions( + batch, rollout=rollout, num_envs=num_envs + ) ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 2df69f86..a0dad10a 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -40,12 +40,14 @@ def __init__( obs_dim: int, action_dim: int, device: torch.device, + use_raw_obs: bool = False, ) -> None: self.num_envs = num_envs self.rollout_len = rollout_len self.obs_dim = 
def _clear_dynamic_fields(self) -> None:
    """Drop algorithm-added fields before reusing the shared rollout.

    Removes the per-update tensordict keys written by the algorithm
    (advantage / return / masks / entropy) plus the python-level
    ``raw_obs`` and ``chunk_step`` attributes attached by the collector,
    then resets the trailing padding slot.
    """
    transient_keys = ("advantage", "return", "seq_mask", "seq_return", "entropy")
    for field_name in transient_keys:
        if field_name in self._rollout.keys():
            del self._rollout[field_name]
    # raw_obs is a plain python attribute (list of raw observation
    # tensordicts) that only exists when the buffer was built with
    # use_raw_obs=True.
    if self.use_raw_obs and hasattr(self._rollout, "raw_obs"):
        delattr(self._rollout, "raw_obs")
    # chunk_step is attached by the collector when action chunking is on.
    if hasattr(self._rollout, "chunk_step"):
        delattr(self._rollout, "chunk_step")
    self._reset_padding_slot()
rollout.""" total = rollout.batch_size[0] - indices = torch.randperm(total, device=device) + indices = torch.randperm(total) for start in range(0, total, batch_size): - yield rollout[indices[start : start + batch_size]] + batch_indices = indices[start : start + batch_size] + batch = rollout[batch_indices].clone() + batch["_indices"] = batch_indices + yield batch diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 16c5b584..6e81c0f9 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -68,20 +68,119 @@ def collect( if self._supports_shared_rollout: self.env.set_rollout_buffer(rollout) - initial_obs = flatten_dict_observation(self.obs_td) - rollout["obs"][:, 0] = initial_obs - for step_idx in range(num_steps): - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + + if use_raw_obs: + if raw_obs_list is None: + raise ValueError( + "Policy requires raw observations, " + "but the provided rollout TensorDict has no 'raw_obs' buffer. " + "Create the rollout via RolloutBuffer or " + "start_rollout so that 'raw_obs' is allocated." + ) + try: + raw_obs_len = len(raw_obs_list) + except TypeError: + raise ValueError( + "Rollout field 'raw_obs' must be an indexable sequence of length " + f"{num_steps + 1} when policy.use_raw_obs=True." + ) + expected_len = num_steps + 1 + if raw_obs_len != expected_len: + raise ValueError( + "Rollout 'raw_obs' length mismatch: " + f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. " + "Ensure the rollout was created with use_raw_obs=True and " + "its time dimension matches the requested num_steps." 
+ ) + + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = ( + getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 + ) + cached_chunk = None + + if use_action_chunk: + rollout.chunk_step = torch.zeros( + self.env.num_envs, + num_steps, + dtype=torch.long, device=self.device, ) - step_td = self.policy.get_action(step_td) + + if use_raw_obs and raw_obs_list is not None: + raw_obs_list[0] = self.obs_td + rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) + else: + rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) + + for step_idx in range(num_steps): + step_in_chunk = step_idx % action_chunk_size if use_action_chunk else 0 + + # At chunk boundary, or cached invalidated by env reset, we need a new chunk + need_new_chunk = use_action_chunk and ( + step_in_chunk == 0 or cached_chunk is None + ) + + if need_new_chunk: + if use_raw_obs and raw_obs_list is not None: + step_td = TensorDict( + {"obs": raw_obs_list[step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + step_td = TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + cached_chunk = step_td["action_chunk"] + action = step_td["action"] + effective_step_in_chunk = 0 + elif use_action_chunk and cached_chunk is not None: + action = cached_chunk[:, step_in_chunk] + effective_step_in_chunk = step_in_chunk + step_td = TensorDict( + { + "action": action, + "sample_log_prob": torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ), + "value": torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ), + }, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + if use_raw_obs and raw_obs_list is not None: + step_td = TensorDict( + {"obs": raw_obs_list[step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + step_td = 
TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + action = step_td["action"] next_obs, reward, terminated, truncated, env_info = self.env.step( - self._to_action_dict(step_td["action"]) + self._to_action_dict(action) ) next_obs_td = dict_to_tensordict(next_obs, self.device) + if use_action_chunk: + rollout.chunk_step[:, step_idx] = effective_step_in_chunk + # Invalidate cached_chunk on any env reset to avoid using old chunk for new episode + if (terminated | truncated).any(): + cached_chunk = None self._write_step( rollout=rollout, step_idx=step_idx, @@ -95,7 +194,11 @@ def collect( terminated=terminated, truncated=truncated, ) - rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + if use_raw_obs and raw_obs_list is not None: + raw_obs_list[step_idx + 1] = next_obs_td + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + else: + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) @@ -107,7 +210,12 @@ def collect( def _attach_final_value(self, rollout: TensorDict) -> None: """Populate the bootstrap value for the final observed state.""" - final_obs = rollout["obs"][:, -1] + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + if use_raw_obs and raw_obs_list is not None: + final_obs = raw_obs_list[-1] + else: + final_obs = rollout["obs"][:, -1] last_next_td = TensorDict( {"obs": final_obs}, batch_size=[rollout.batch_size[0]], @@ -155,8 +263,9 @@ def _write_env_step( def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" + obs_dim = rollout["obs"].shape[-1] expected_shapes = { - "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim), + "obs": 
(self.env.num_envs, num_steps + 1, obs_dim), "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), "sample_log_prob": (self.env.num_envs, num_steps + 1), "value": (self.env.num_envs, num_steps + 1), diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 51cf7653..46231005 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -17,7 +17,7 @@ from __future__ import annotations import inspect -from typing import Dict, Type +from typing import Any, Dict, Optional, Type from gymnasium import spaces import torch @@ -26,6 +26,7 @@ from .actor_only import ActorOnly from .policy import Policy from .mlp import MLP +from .vla_policy import VLAPolicy # In-module policy registry _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} @@ -63,13 +64,16 @@ def build_policy( device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, + env: Optional[Any] = None, ) -> Policy: """Build a policy from config using spaces for extensibility. Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / `action_space` inputs. + For vla_policy, pass env to enable set_env and _load_vla initialization. """ name = policy_block["name"].lower() + if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) raise ValueError( @@ -119,7 +123,18 @@ def build_policy( build_kwargs["actor"] = actor if "critic" in init_params and critic is not None: build_kwargs["critic"] = critic - return policy_cls(**build_kwargs) + if "policy_cfg" in init_params: + build_kwargs["policy_cfg"] = policy_block + policy = policy_cls(**build_kwargs) + if name == "vla_policy": + if env is None: + raise ValueError( + "VLAPolicy requires an 'env' argument to be passed to build_policy " + "so that set_env and _load_vla can be called before use." 
+ ) + policy.set_env(env) + policy._load_vla() + return policy def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: @@ -143,10 +158,12 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) register_policy("actor_only", ActorOnly) +register_policy("vla_policy", VLAPolicy) __all__ = [ "ActorCritic", "ActorOnly", + "VLAPolicy", "register_policy", "get_registered_policy_names", "build_policy", diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 32caf0e3..8016ddcd 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -86,7 +86,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1) return tensordict - def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index 3d6d1f78..0f93ce8f 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -77,7 +77,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: ) return tensordict - def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..a3c1648f --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,248 @@ +# 
class VLAPolicy(Policy):
    """Wraps a DexForce VLA model as a ``Policy`` for GRPO fine-tuning.

    The policy consumes raw (un-flattened) observations, predicts an
    action chunk of length ``action_horizon`` per query, and at update
    time scores actions with a Gaussian log-prob proxy centred on the
    VLA's re-predicted action for the same chunk position.
    """

    def __init__(
        self,
        device: torch.device,
        policy_cfg: dict[str, Any],
        obs_space=None,
        action_space=None,
    ) -> None:
        """Build the policy from ``policy_cfg['vla']``.

        Raises:
            ValueError: if ``policy.vla.model_path`` is missing.
        """
        super().__init__()
        self.device = device
        self.policy_cfg = dict(policy_cfg)
        self.vla_cfg = dict(self.policy_cfg.get("vla", {}))
        self.model_path = str(self.vla_cfg.get("model_path", ""))
        self.action_horizon = int(self.vla_cfg.get("action_horizon", 32))
        self.gaussian_sigma = float(self.vla_cfg.get("gaussian_sigma", 0.1))

        if not self.model_path:
            raise ValueError("VLAPolicy requires 'policy.vla.model_path'.")

        self._vla_model: nn.Module | None = None
        self._action_indices: list[int] | None = None

        # Resolve the env-side action dimension; 14 is the default used
        # when no usable action-space information is provided.
        if action_space is None:
            self.action_dim = 14
        elif isinstance(action_space, int):
            self.action_dim = action_space
        elif hasattr(action_space, "shape") and len(action_space.shape) > 0:
            self.action_dim = int(action_space.shape[-1])
        else:
            self.action_dim = 14
        self.obs_dim = 0  # VLA consumes raw observations, not flat vectors
        self.use_raw_obs = True  # tells the collector to pass raw observations

        self.use_action_chunk = True
        self.action_chunk_size = self.action_horizon
        self._env = None

    def set_env(self, env) -> None:
        """Store the env reference used by ``forward``/``evaluate_actions``."""
        self._env = env

    def _load_vla(self) -> None:
        """Lazily instantiate the VLA backend (idempotent)."""
        if self._vla_model is not None:
            return
        backend = create_vla_backend(
            "dexforce_vla",
            model_path=self.model_path,
            device=self.device,
            action_horizon=self.action_horizon,
            **{
                k: v
                for k, v in self.vla_cfg.items()
                if k not in ("backend", "model_path", "action_horizon")
            },
        )
        self._vla_model, self._action_indices, self._prepare_batch_fn = backend

    def _vla_chunk_to_env_chunk(
        self, action_chunk: torch.Tensor, env=None
    ) -> torch.Tensor:
        """Convert a VLA output chunk ``(N, T, vla_dim)`` to env format ``(N, T, env_dim)``.

        Selects the configured action indices (if any), then pads with
        zeros or truncates the last dim to match the env action dim.
        """
        if self._action_indices is not None:
            step = action_chunk[:, :, self._action_indices]
        else:
            step = action_chunk

        if env is not None:
            # BUGFIX: the previous code left env_dim as a tuple when
            # env.action_space.shape was missing, empty, or (None,), and
            # then compared `step.shape[-1] > env_dim` (int vs tuple),
            # raising TypeError. Only pad/truncate when env_dim resolves
            # to a concrete int.
            shape = getattr(env.action_space, "shape", None)
            env_dim = None
            if shape is not None and len(shape) > 0 and shape[-1] is not None:
                env_dim = int(shape[-1])
            if env_dim is not None:
                if step.shape[-1] > env_dim:
                    step = step[..., :env_dim]
                elif step.shape[-1] < env_dim:
                    pad = torch.zeros(
                        step.shape[0],
                        step.shape[1],
                        env_dim - step.shape[-1],
                        device=step.device,
                        dtype=step.dtype,
                    )
                    step = torch.cat([step, pad], dim=-1)
        return step

    def forward(
        self, tensordict: TensorDict, deterministic: bool = False
    ) -> TensorDict:
        """Predict an action chunk from raw observations.

        Writes ``action`` (first chunk step), zero placeholders for
        ``sample_log_prob`` / ``value``, and ``action_chunk`` into
        *tensordict*. ``deterministic`` is accepted for interface parity;
        the VLA backend prediction itself is used as-is.
        """
        obs = tensordict["obs"]
        env = getattr(tensordict, "env", None)
        if env is None:
            env = getattr(self, "_env", None)
        if env is None:
            raise ValueError(
                "VLAPolicy needs env. Set policy._env or pass env in tensordict."
            )

        self._load_vla()
        self._vla_model.eval()
        # Infer batch size from the raw observation container.
        if hasattr(obs, "batch_size") and len(obs.batch_size) > 0:
            batch_size = int(obs.batch_size[0])
        elif isinstance(obs, dict) and "robot" in obs and "qpos" in obs["robot"]:
            q = obs["robot"]["qpos"]
            batch_size = q.shape[0] if hasattr(q, "shape") and len(q.shape) > 0 else 1
        else:
            batch_size = 1
        if batch_size == 1:
            batch = self._prepare_batch_fn(obs, env)
            vla_chunk = self._vla_model.predict_action(
                batch,
                action_only=True,
                inference_horizon=self.action_horizon,
                allow_grad=False,
                use_fix_aug=False,
            )
            action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env)
        else:
            # The backend batch-prep handles one env at a time; query each
            # env independently and stack along the batch dim.
            chunks_env = []
            for i in range(batch_size):
                obs_i = obs[i] if hasattr(obs, "__getitem__") else obs
                batch_i = self._prepare_batch_fn(obs_i, env)
                vla_chunk = self._vla_model.predict_action(
                    batch_i,
                    action_only=True,
                    inference_horizon=self.action_horizon,
                    allow_grad=False,
                    use_fix_aug=False,
                )
                chunk_i = self._vla_chunk_to_env_chunk(vla_chunk, env=env)
                chunks_env.append(chunk_i)
            action_chunk_env = torch.cat(chunks_env, dim=0)

        action_chunk_env = action_chunk_env.to(self.device, dtype=torch.float32)
        action = action_chunk_env[:, 0]

        tensordict["action"] = action
        # Log-prob and value are placeholders: the real log-prob is the
        # Gaussian proxy computed in evaluate_actions at update time.
        tensordict["sample_log_prob"] = torch.zeros(
            action.shape[0], device=self.device, dtype=torch.float32
        )
        tensordict["value"] = torch.zeros(
            action.shape[0], device=self.device, dtype=torch.float32
        )
        if self.use_action_chunk:
            tensordict["action_chunk"] = action_chunk_env
        return tensordict

    def get_value(self, tensordict: TensorDict) -> TensorDict:
        """Return a zero critic value (GRPO needs no learned critic)."""
        b = tensordict.batch_size[0]
        tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32)
        return tensordict

    def evaluate_actions(
        self, tensordict: TensorDict, rollout=None, num_envs=None, **kwargs
    ) -> TensorDict:
        """Score stored actions with a Gaussian log-prob proxy.

        For each sample, re-predicts the action chunk from the raw
        observation at the chunk start and treats the stored action as a
        draw from N(pred, gaussian_sigma^2) — constant terms omitted, so
        only ratios of these log-probs are meaningful.

        Raises:
            ValueError: if env is unset, or the raw_obs / chunk_step /
                _indices plumbing is missing from the batch.
        """
        b = tensordict.batch_size[0]
        env = getattr(self, "_env", None)
        if env is None:
            raise ValueError(
                "VLAPolicy.evaluate_actions requires env. Call policy.set_env(env)."
            )

        raw_obs = getattr(rollout, "raw_obs", None)
        chunk_step = tensordict.get("chunk_step", None)
        indices = tensordict.get("_indices", None)
        if raw_obs is None or chunk_step is None or indices is None or num_envs is None:
            raise ValueError(
                "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, _indices, num_envs. "
                "Ensure collector uses use_raw_obs and use_action_chunk, and GRPO passes rollout and num_envs."
            )

        time_dim = len(raw_obs) - 1
        sigma = self.gaussian_sigma
        log_probs = []
        self._load_vla()
        self._vla_model.eval()

        for i in range(b):
            # Map flat minibatch index back to (env, timestep).
            idx = int(indices[i].item())
            env_idx = idx // time_dim
            step_idx = idx % time_dim
            step_in_chunk = int(chunk_step[i].item())
            # The stored action came from the chunk predicted at the
            # chunk's starting observation.
            chunk_start_idx = max(0, step_idx - step_in_chunk)
            obs_i = raw_obs[chunk_start_idx][env_idx]
            action_gt = tensordict["action"][i]

            batch_i = self._prepare_batch_fn(obs_i, env)
            vla_chunk = self._vla_model.predict_action(
                batch_i,
                action_only=True,
                inference_horizon=self.action_horizon,
                allow_grad=True,
                use_fix_aug=False,
            )
            pred_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env)
            pred = pred_chunk_env[0, step_in_chunk]
            if pred.shape[-1] != action_gt.shape[-1]:
                pred = pred[: action_gt.shape[-1]]
            mse = ((action_gt - pred).pow(2)).sum(-1)
            log_prob = -0.5 * mse / (sigma * sigma + 1e-8)
            log_probs.append(log_prob)

        log_probs = torch.stack(log_probs)
        # Closed-form entropy of an isotropic Gaussian with std sigma.
        entropy = (
            0.5 * self.action_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8))
        )
        entropy = torch.full((b,), entropy, device=self.device, dtype=torch.float32)

        return TensorDict(
            {
                "sample_log_prob": log_probs,
                "entropy": entropy,
                "value": torch.zeros(b, device=self.device, dtype=torch.float32),
            },
            batch_size=tensordict.batch_size,
            device=self.device,
        )
b/embodichain/agents/rl/train.py index bf08c746..b4678d26 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -230,7 +230,11 @@ def train_from_config(config_path: str): ) else: policy = build_policy( - policy_block, env.observation_space, env.action_space, device + policy_block, + env.observation_space, + env.action_space, + device, + env=env, ) # Build Algorithm via factory diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 4f660232..1ec97d82 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -27,6 +27,7 @@ from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.agents.rl.collector import SyncCollector +from embodichain.agents.rl.utils import dict_to_tensordict from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -84,6 +85,29 @@ def __init__( action_dim = getattr(self.policy, "action_dim", None) if obs_dim is None or action_dim is None: raise RuntimeError("Policy must expose obs_dim and action_dim.") + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + if use_action_chunk and action_chunk_size > 0: + self.buffer_size = ( + (self.buffer_size + action_chunk_size - 1) + // action_chunk_size + * action_chunk_size + ) + + if use_raw_obs: + try: + reset_out = self.env.reset() + sample_obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out + obs_td = dict_to_tensordict(sample_obs, self.device) + flat_obs = flatten_dict_observation(obs_td) + obs_dim = int( + flat_obs.shape[-1] + if isinstance(flat_obs, torch.Tensor) + else np.asarray(flat_obs).shape[-1] + ) + except Exception: + obs_dim = max(1, obs_dim) self.buffer = RolloutBuffer( num_envs=num_envs, @@ -91,6 +115,7 @@ def __init__( obs_dim=obs_dim, 
action_dim=action_dim, device=self.device, + use_raw_obs=use_raw_obs, ) self.collector = SyncCollector( env=self.env, @@ -245,28 +270,60 @@ def _eval_once(self, num_episodes: int = 5): episode_returns = [] episode_lengths = [] + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + action_chunk_size = getattr(self.policy, "action_chunk_size", 1) + effective_use_action_chunk = use_action_chunk and action_chunk_size > 0 + for _ in range(num_episodes): - # Reset and initialize episode tracking obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + if use_raw_obs: + obs_td = dict_to_tensordict(obs, self.device) + else: + obs_td = flatten_dict_observation(obs) + num_envs = ( + obs_td.batch_size[0] + if hasattr(obs_td, "batch_size") + else (obs_td.shape[0] if hasattr(obs_td, "shape") else 1) + ) done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( num_envs, dtype=torch.float32, device=self.device ) step_count = torch.zeros(num_envs, dtype=torch.int32, device=self.device) + cached_chunk = None + step_in_chunk = 0 - # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - action_td = TensorDict( - {"obs": obs}, - batch_size=[num_envs], - device=self.device, + if effective_use_action_chunk and ( + cached_chunk is None or step_in_chunk == 0 + ): + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + cached_chunk = action_td.get("action_chunk") + actions = action_td["action"] + step_in_chunk = 0 + elif effective_use_action_chunk and cached_chunk is not None: + actions = cached_chunk[:, step_in_chunk] + else: + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = 
self.policy.get_action(action_td, deterministic=True) + actions = action_td["action"] + + step_in_chunk = ( + (step_in_chunk + 1) % action_chunk_size + if effective_use_action_chunk + else 0 ) - action_td = self.policy.get_action(action_td, deterministic=True) - actions = action_td["action"] am = getattr(self.eval_env, "action_manager", None) action_type = ( am.action_type @@ -275,15 +332,17 @@ def _eval_once(self, num_episodes: int = 5): ) action_dict = {action_type: actions} - # Environment step obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = ( - flatten_dict_observation(obs) - if isinstance(obs, TensorDict) - else obs - ) + if use_raw_obs: + obs_td = dict_to_tensordict(obs, self.device) + else: + obs_td = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) # Update statistics only for still-running environments done = terminated | truncated @@ -292,6 +351,10 @@ def _eval_once(self, num_episodes: int = 5): step_count[still_running] += 1 done_mask |= done + # Invalidate cached_chunk on any env reset + if effective_use_action_chunk and done.any(): + cached_chunk = None + # Trigger evaluation events (e.g., video recording) if hasattr(self, "eval_event_manager"): if "interval" in self.eval_event_manager.available_modes: diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py new file mode 100644 index 00000000..e9f13546 --- /dev/null +++ b/embodichain/agents/rl/vla_registry.py @@ -0,0 +1,75 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from importlib.metadata import entry_points +from typing import Any, Callable + +__all__ = [ + "get_vla_backend", + "get_registered_vla_backend_names", + "create_vla_backend", +] + + +_VLA_BACKENDS: dict[str, Callable[..., Any]] = {} +_ENTRY_POINTS_DISCOVERED = False + + +def _discover_entry_points() -> None: + """Discover and register VLA backends from entry_points.""" + global _ENTRY_POINTS_DISCOVERED + if _ENTRY_POINTS_DISCOVERED: + return + _ENTRY_POINTS_DISCOVERED = True + try: + eps = entry_points(group="embodichain.vla_backends") + for ep in eps: + try: + factory = ep.load() + name = str(ep.name).lower() + if name not in _VLA_BACKENDS: + _VLA_BACKENDS[name] = factory + except Exception: + pass + except Exception: + pass + + +def get_vla_backend(name: str) -> Callable[..., Any] | None: + name = str(name).lower() + if name in _VLA_BACKENDS: + return _VLA_BACKENDS[name] + _discover_entry_points() + return _VLA_BACKENDS.get(name) + + +def get_registered_vla_backend_names() -> list[str]: + _discover_entry_points() + return list(_VLA_BACKENDS.keys()) + + +def create_vla_backend(name: str, **kwargs) -> Any: + factory = get_vla_backend(name) + if factory is None: + available = get_registered_vla_backend_names() + raise ValueError( + f"Unknown VLA backend '{name}'. Available: {available}. " + "Ensure a package providing the 'embodichain.vla_backends' entry point " + "group is installed." + ) + return factory(**kwargs)