From 41706f0722e0cc3735d29832eb9a9a81acc3a470 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sun, 15 Mar 2026 14:15:14 +0000 Subject: [PATCH] wip --- AGENTS.md | 1 - embodichain/lab/gym/envs/base_env.py | 20 + embodichain/lab/gym/envs/embodied_env.py | 6 + embodichain/lab/gym/envs/managers/__init__.py | 13 +- .../lab/gym/envs/managers/action_manager.py | 265 ++++-------- embodichain/lab/gym/envs/managers/actions.py | 383 ++++++++++++++++++ embodichain/lab/gym/envs/managers/cfg.py | 10 +- embodichain/lab/gym/utils/gym_utils.py | 2 +- .../gym/envs/managers/test_action_manager.py | 168 +++++++- 9 files changed, 669 insertions(+), 199 deletions(-) create mode 100644 embodichain/lab/gym/envs/managers/actions.py diff --git a/AGENTS.md b/AGENTS.md index 8749b172..c52497c6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,7 +5,6 @@ **IMPORTANT**: The Python package name is `embodichain` (all lowercase, one word). - Repository folder: `EmbodiChain` (PascalCase) - Python package: `embodichain` (lowercase) -- NEVER use: `embodiedchain`, `embodyichain`, or any other variant ## Project Structure diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index ab945b45..fcd89c98 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -550,6 +550,21 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: """ return action + def _postprocess_action(self, action: EnvAction) -> EnvAction: + """Postprocess action after applying to robot. + + Post processing is usually used to modify the action after it has been applied to the robot, + performing normalization, noise addition, or any other modifications that need to be applied + for policy learning or evaluation purposes. + + Args: + action: Action after preprocessing and robot control command generation + + Returns: + Final action to be applied in the simulation + """ + return action + def _step_action(self, action: EnvAction) -> EnvAction: """Set action control command into simulation. @@ -608,8 +623,10 @@ def step( Returns: A tuple contraining the observation, reward, terminated, truncated, and info dictionary. """ + action = self._preprocess_action(action=action) action = self._step_action(action=action) + self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) self._update_sim_state(**kwargs) @@ -620,6 +637,9 @@ def step( rewards=rewards, obs=obs, action=action, info=info ) + # Apply postprocessing to the action after all computations are done. + action = self._postprocess_action(action=action) + self._elapsed_steps += 1 terminateds = torch.logical_or( diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 5e40d6fd..6fc7d332 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -707,6 +707,12 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: return self.action_manager.process_action(action) return super()._preprocess_action(action) + def _postprocess_action(self, action): + if self.action_manager is not None: + pass + # return self.action_manager.postprocess_action(action) + return super()._postprocess_action(action) + def _setup_robot(self, **kwargs) -> Robot: """Setup the robot in the environment. diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py index 5c38352b..939f190c 100644 --- a/embodichain/lab/gym/envs/managers/__init__.py +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -27,15 +27,6 @@ from .event_manager import EventManager from .observation_manager import ObservationManager from .reward_manager import RewardManager -from .action_manager import ( - ActionManager, - ActionTerm, - DeltaQposTerm, - QposTerm, - QposNormalizedTerm, - EefPoseTerm, - QvelTerm, - QfTerm, -) +from .action_manager import * +from .actions import * from .dataset_manager import DatasetManager -from .datasets import * diff --git a/embodichain/lab/gym/envs/managers/action_manager.py b/embodichain/lab/gym/envs/managers/action_manager.py index adcd1ec4..b4dc6812 100644 --- a/embodichain/lab/gym/envs/managers/action_manager.py +++ b/embodichain/lab/gym/envs/managers/action_manager.py @@ -14,20 +14,28 @@ # limitations under the License. # ---------------------------------------------------------------------------- +"""Action manager for processing policy actions into robot control commands. + +This module provides the :class:`ActionManager` class which handles the interpretation +and preprocessing of raw actions from the policy into the format expected by the robot. + +The concrete action term implementations (e.g., :class:`QposTerm`, :class:`DeltaQposTerm`) +are available in :mod:`actions` module. +""" + from __future__ import annotations import inspect from abc import abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import torch from prettytable import PrettyTable from tensordict import TensorDict from embodichain.lab.sim.types import EnvAction -from embodichain.utils.math import matrix_from_euler, matrix_from_quat - from embodichain.utils.string import string_to_callable +from embodichain.utils import logger from .cfg import ActionTermCfg from .manager_base import Functor, ManagerBase @@ -35,6 +43,8 @@ if TYPE_CHECKING: from embodichain.lab.gym.envs import EmbodiedEnv +__all__ = ["ActionTerm", "ActionManager"] + class ActionTerm(Functor): """Base class for action terms. @@ -92,6 +102,11 @@ def __init__(self, cfg: object, env: EmbodiedEnv): """ self._term_names: list[str] = [] self._terms: dict[str, ActionTerm] = {} + self._term_modes: dict[str, Literal["pre", "post"]] = {} + self._mode_term_names: dict[Literal["pre", "post"], list[str]] = { + "pre": [], + "post": [], + } super().__init__(cfg, env) def __str__(self) -> str: @@ -99,12 +114,14 @@ def __str__(self) -> str: msg = f" contains {len(self._term_names)} active term(s).\n" table = PrettyTable() table.title = "Active Action Terms" - table.field_names = ["Index", "Name", "Dimension"] + table.field_names = ["Index", "Name", "Mode", "Dimension"] table.align["Name"] = "l" + table.align["Mode"] = "c" table.align["Dimension"] = "r" for index, name in enumerate(self._term_names): term = self._terms[name] - table.add_row([index, name, term.action_dim]) + mode = self._term_modes.get(name, "pre") + table.add_row([index, name, mode, term.action_dim]) msg += table.get_string() msg += "\n" return msg @@ -114,6 +131,19 @@ def active_functors(self) -> list[str]: """Name of active action terms.""" return self._term_names + def get_functors_by_mode( + self, mode: Literal["pre", "post"] + ) -> list[tuple[str, ActionTerm]]: + """Get action terms filtered by mode. + + Args: + mode: The mode to filter by ("pre" or "post"). + + Returns: + List of (name, term) tuples for terms with the specified mode. + """ + return [(name, self._terms[name]) for name in self._mode_term_names[mode]] + @property def action_type(self) -> str: """The active action type (term name) for backward compatibility.""" @@ -124,27 +154,58 @@ def total_action_dim(self) -> int: """Total dimension of actions (sum of all term dimensions).""" return sum(t.action_dim for t in self._terms.values()) - def process_action(self, action: EnvAction) -> EnvAction: + def get_action_dim_by_mode(self, mode: Literal["pre", "post"]) -> int: + """Get total action dimension for terms of a specific mode. + + Args: + mode: The mode to filter by ("pre" or "post"). + + Returns: + Sum of action dimensions for terms with the specified mode. + """ + mode_terms = self.get_functors_by_mode(mode) + return sum(term.action_dim for _, term in mode_terms) + + def process_action( + self, action: EnvAction, mode: Literal["pre", "post"] = "pre" + ) -> EnvAction: """Process raw action from policy into robot control format. Supports: - 1. Tensor input: Passed to the active (first) term. + 1. Tensor input: Passed to the active (first) term of the specified mode. 2. Dict/TensorDict input: Uses key matching term name; raises an error if no match. Args: action: Raw action from policy (tensor or dict). + mode: The processing mode - "pre" for preprocessing (default) or "post" + for postprocessing. When "post", only terms with mode="post" are applied. Returns: TensorDict action ready for robot control. """ + # Filter terms by mode + mode_terms = self._mode_term_names[mode] + + if not mode_terms: + logger.log_error( + f"No action terms found for mode '{mode}'. " + f"Available terms: {self._term_names}", + error_type=ValueError, + ) + + # TODO: We should refactor the action manager to support multiple active terms. if not isinstance(action, (dict, TensorDict)): - return self._terms[self._term_names[0]].process_action(action) + return self._terms[mode_terms[0]].process_action(action) # Dict input: find matching term - for term_name in self._term_names: + for term_name in mode_terms: if term_name in action: return self._terms[term_name].process_action(action[term_name]) - raise ValueError(f"No valid action keys. Expected one of: {self._term_names}") + + logger.log_error( + f"No valid action keys. Expected one of: {mode_terms}", + error_type=ValueError, + ) def get_term(self, name: str) -> ActionTerm: """Get action term by name.""" @@ -166,192 +227,30 @@ def _prepare_functors(self) -> None: if term_cfg is None: continue if not isinstance(term_cfg, ActionTermCfg): - raise TypeError( + logger.log_error( f"Configuration for the term '{term_name}' is not of type ActionTermCfg. " - f"Received: '{type(term_cfg)}'." + f"Received: '{type(term_cfg)}'.", + error_type=TypeError, ) # Resolve string to callable (skip base class params check for ActionTerm) if isinstance(term_cfg.func, str): term_cfg.func = string_to_callable(term_cfg.func) if not callable(term_cfg.func): - raise AttributeError( + logger.log_error( f"The action term '{term_name}' is not callable. " - f"Received: '{term_cfg.func}'" + f"Received: '{term_cfg.func}'", + error_type=TypeError, ) if inspect.isclass(term_cfg.func) and not issubclass( term_cfg.func, ActionTerm ): - raise TypeError( + logger.log_error( f"Configuration for the term '{term_name}' must be a subclass of " - f"ActionTerm. Received: '{type(term_cfg.func)}'." + f"ActionTerm. Received: '{type(term_cfg.func)}'.", + error_type=TypeError, ) self._process_functor_cfg_at_play(term_name, term_cfg) self._term_names.append(term_name) self._terms[term_name] = term_cfg.func - - -# ---------------------------------------------------------------------------- -# Concrete ActionTerm implementations -# ---------------------------------------------------------------------------- - - -class DeltaQposTerm(ActionTerm): - """Delta joint position action: current_qpos + scale * action -> qpos.""" - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - - @property - def action_dim(self) -> int: - return len(self._env.active_joint_ids) - - def process_action(self, action: torch.Tensor) -> EnvAction: - scaled = action * self._scale - current_qpos = self._env.robot.get_qpos() - qpos = current_qpos + scaled - batch_size = qpos.shape[0] - return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) - - -class QposTerm(ActionTerm): - """Absolute joint position action: scale * action -> qpos.""" - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - - @property - def action_dim(self) -> int: - return len(self._env.active_joint_ids) - - def process_action(self, action: torch.Tensor) -> EnvAction: - qpos = action * self._scale - batch_size = qpos.shape[0] - return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) - - -class QposNormalizedTerm(ActionTerm): - """Normalized action in [-1, 1] -> denormalize to joint limits -> qpos. - - The policy output is scaled by ``params.scale`` before denormalization. - With scale=1.0 (default), action in [-1, 1] maps to [low, high]. - With scale<1.0, the effective range shrinks toward the center (e.g. scale=0.5 - maps to 25%-75% of joint range). Use scale=1.0 for standard normalized control. - """ - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - - @property - def action_dim(self) -> int: - return len(self._env.active_joint_ids) - - def process_action(self, action: torch.Tensor) -> EnvAction: - scaled = action * self._scale - qpos_limits = self._env.robot.body_data.qpos_limits[ - 0, self._env.active_joint_ids - ] - low = qpos_limits[:, 0] - high = qpos_limits[:, 1] - qpos = low + (scaled + 1.0) * 0.5 * (high - low) - batch_size = qpos.shape[0] - return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) - - -class EefPoseTerm(ActionTerm): - """End-effector pose (6D or 7D) -> IK -> qpos. - - On IK failure, falls back to current_qpos for that env. - Returns ``ik_success`` in the TensorDict so reward/observation - can penalize or condition on IK failures. - """ - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - self._pose_dim = cfg.params.get("pose_dim", 7) # 6 for euler, 7 for quat - - @property - def action_dim(self) -> int: - return self._pose_dim - - def process_action(self, action: torch.Tensor) -> EnvAction: - scaled = action * self._scale - current_qpos = self._env.robot.get_qpos() - batch_size = scaled.shape[0] - target_pose = ( - torch.eye(4, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1) - ) - if scaled.shape[-1] == 6: - target_pose[:, :3, 3] = scaled[:, :3] - target_pose[:, :3, :3] = matrix_from_euler(scaled[:, 3:6]) - elif scaled.shape[-1] == 7: - target_pose[:, :3, 3] = scaled[:, :3] - target_pose[:, :3, :3] = matrix_from_quat(scaled[:, 3:7]) - else: - raise ValueError( - f"EEF pose action must be 6D or 7D, got {scaled.shape[-1]}D" - ) - # Batch IK: robot.compute_ik supports (n_envs, 4, 4) pose and (n_envs, dof) seed - ret, qpos_ik = self._env.robot.compute_ik( - pose=target_pose, - joint_seed=current_qpos, - ) - # Fallback to current_qpos where IK failed - result_qpos = torch.where( - ret.unsqueeze(-1).expand_as(qpos_ik), qpos_ik, current_qpos - ) - return TensorDict( - {"qpos": result_qpos, "ik_success": ret}, - batch_size=[batch_size], - device=self.device, - ) - - -class QvelTerm(ActionTerm): - """Joint velocity action: scale * action -> qvel.""" - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - - @property - def action_dim(self) -> int: - return len(self._env.active_joint_ids) - - def process_action(self, action: torch.Tensor) -> EnvAction: - qvel = action * self._scale - batch_size = qvel.shape[0] - return TensorDict({"qvel": qvel}, batch_size=[batch_size], device=self.device) - - -class QfTerm(ActionTerm): - """Joint force/torque action: scale * action -> qf.""" - - def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): - super().__init__(cfg, env) - self._scale = cfg.params.get("scale", 1.0) - - @property - def action_dim(self) -> int: - return len(self._env.active_joint_ids) - - def process_action(self, action: torch.Tensor) -> EnvAction: - qf = action * self._scale - batch_size = qf.shape[0] - return TensorDict({"qf": qf}, batch_size=[batch_size], device=self.device) - - -__all__ = [ - "ActionTerm", - "ActionManager", - "ActionTermCfg", - "DeltaQposTerm", - "QposTerm", - "QposNormalizedTerm", - "EefPoseTerm", - "QvelTerm", - "QfTerm", -] + self._term_modes[term_name] = term_cfg.mode + self._mode_term_names[term_cfg.mode].append(term_name) diff --git a/embodichain/lab/gym/envs/managers/actions.py b/embodichain/lab/gym/envs/managers/actions.py new file mode 100644 index 00000000..0202eee6 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/actions.py @@ -0,0 +1,383 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Action terms for processing policy actions into robot control commands. + +This module provides concrete implementations of :class:`ActionTerm` that convert +raw policy actions into different control formats (e.g., joint positions, velocities, +forces, or end-effector poses). + +The action terms are typically used in conjunction with :class:`ActionManager` which +handles calling the appropriate term based on configuration. + +Example usage in environment config:: + + action_terms: + # Pre-processing: raw action -> joint position + joint_pos: + func: QposTerm + mode: pre + params: + scale: 1.0 + # Post-processing: clamp the output + clamp: + func: ActionClampTerm + mode: post + params: + min: -1.0 + max: 1.0 + +Available action terms: + +- :class:`DeltaQposTerm`: Delta joint position (current + scale * action) +- :class:`QposTerm`: Absolute joint position (scale * action) +- :class:`QposNormalizedTerm`: Normalized action [-1,1] -> joint limits +- :class:`EefPoseTerm`: End-effector pose -> IK -> joint position +- :class:`QvelTerm`: Joint velocity (scale * action) +- :class:`QfTerm`: Joint force/torque (scale * action) +- :class:`ActionClampTerm`: Post-processing clamp to min/max limits +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from tensordict import TensorDict + +from embodichain.lab.sim.types import EnvAction +from embodichain.utils.math import matrix_from_euler, matrix_from_quat +from .action_manager import ActionTerm +from .cfg import ActionTermCfg + +# Import ActionTerm from action_manager after it's defined +# This is a late import to avoid circular dependency +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +__all__ = [ + "DeltaQposTerm", + "QposTerm", + "QposNormalizedTerm", + "EefPoseTerm", + "QvelTerm", + "QfTerm", + "ActionClampTerm", +] + + +# ---------------------------------------------------------------------------- +# Concrete ActionTerm implementations +# ---------------------------------------------------------------------------- + + +class DeltaQposTerm(ActionTerm): + """Delta joint position action: current_qpos + scale * action -> qpos. + + This action term adds a scaled delta to the current joint positions. + Useful for relative position control where the policy outputs position offsets. + + Args: + scale: Scaling factor for the action. Defaults to 1.0. + + Example: + >>> cfg = ActionTermCfg(func=DeltaQposTerm, params={"scale": 0.1}) + >>> term = DeltaQposTerm(cfg, env) + >>> action = torch.ones(num_envs, dof) * 2.0 + >>> result = term.process_action(action) + >>> # result["qpos"] = current_qpos + 0.1 * action + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + current_qpos = self._env.robot.get_qpos() + qpos = current_qpos + scaled + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class QposTerm(ActionTerm): + """Absolute joint position action: scale * action -> qpos. + + This action term directly uses the scaled action as target joint positions. + Useful for absolute position control. + + Args: + scale: Scaling factor for the action. Defaults to 1.0. + + Example: + >>> cfg = ActionTermCfg(func=QposTerm, params={"scale": 1.0}) + >>> term = QposTerm(cfg, env) + >>> action = torch.ones(num_envs, dof) * 0.5 + >>> result = term.process_action(action) + >>> # result["qpos"] = 0.5 * action + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + qpos = action * self._scale + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class QposNormalizedTerm(ActionTerm): + """Normalized action in [-1, 1] -> denormalize to joint limits -> qpos. + + The policy outputs normalized actions in the range [-1, 1] which are then + mapped to the joint's position limits. + + The policy output is scaled by ``params.scale`` before denormalization. + With scale=1.0 (default), action in [-1, 1] maps to [low, high]. + With scale<1.0, the effective range shrinks toward the center (e.g. scale=0.5 + maps to 25%-75% of joint range). Use scale=1.0 for standard normalized control. + + Args: + scale: Scaling factor applied before denormalization. Defaults to 1.0. + + Example: + >>> cfg = ActionTermCfg(func=QposNormalizedTerm, params={"scale": 1.0}) + >>> term = QposNormalizedTerm(cfg, env) + >>> action = torch.tensor([[-1.0, 1.0], [0.0, 0.0]]) # min/max per joint + >>> result = term.process_action(action) + >>> # Maps [-1, 1] to [qpos_limits_low, qpos_limits_high] + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + qpos_limits = self._env.robot.body_data.qpos_limits[ + 0, self._env.active_joint_ids + ] + low = qpos_limits[:, 0] + high = qpos_limits[:, 1] + qpos = low + (scaled + 1.0) * 0.5 * (high - low) + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class EefPoseTerm(ActionTerm): + """End-effector pose (6D or 7D) -> IK -> qpos. + + The policy outputs a target end-effector pose which is converted to joint + positions using inverse kinematics. + + Supports two pose representations: + - 6D: position (3) + Euler angles (3) + - 7D: position (3) + quaternion (4) + + On IK failure, falls back to current_qpos for that env. + Returns ``ik_success`` in the TensorDict so reward/observation + can penalize or condition on IK failures. + + Args: + scale: Scaling factor for the pose. Defaults to 1.0. + pose_dim: Dimension of the pose (6 for Euler, 7 for quaternion). Defaults to 7. + + Example: + >>> cfg = ActionTermCfg(func=EefPoseTerm, params={"scale": 1.0, "pose_dim": 7}) + >>> term = EefPoseTerm(cfg, env) + >>> # 7D: position (3) + quaternion (4) + >>> action = torch.zeros(num_envs, 7) + >>> action[:, :3] = 0.1 # target position + >>> action[:, 3] = 1.0 # quaternion w + >>> result = term.process_action(action) + >>> # result["qpos"] = IK solution + >>> # result["ik_success"] = bool tensor indicating IK success + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + self._pose_dim = cfg.params.get("pose_dim", 7) # 6 for euler, 7 for quat + + @property + def action_dim(self) -> int: + return self._pose_dim + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + current_qpos = self._env.robot.get_qpos() + batch_size = scaled.shape[0] + target_pose = ( + torch.eye(4, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1) + ) + if scaled.shape[-1] == 6: + target_pose[:, :3, 3] = scaled[:, :3] + target_pose[:, :3, :3] = matrix_from_euler(scaled[:, 3:6]) + elif scaled.shape[-1] == 7: + target_pose[:, :3, 3] = scaled[:, :3] + target_pose[:, :3, :3] = matrix_from_quat(scaled[:, 3:7]) + else: + raise ValueError( + f"EEF pose action must be 6D or 7D, got {scaled.shape[-1]}D" + ) + # Batch IK: robot.compute_ik supports (n_envs, 4, 4) pose and (n_envs, dof) seed + ret, qpos_ik = self._env.robot.compute_ik( + pose=target_pose, + joint_seed=current_qpos, + ) + # Fallback to current_qpos where IK failed + result_qpos = torch.where( + ret.unsqueeze(-1).expand_as(qpos_ik), qpos_ik, current_qpos + ) + return TensorDict( + {"qpos": result_qpos, "ik_success": ret}, + batch_size=[batch_size], + device=self.device, + ) + + +class QvelTerm(ActionTerm): + """Joint velocity action: scale * action -> qvel. + + This action term outputs target joint velocities. + Useful for velocity control tasks. + + Args: + scale: Scaling factor for the action. Defaults to 1.0. + + Example: + >>> cfg = ActionTermCfg(func=QvelTerm, params={"scale": 0.2}) + >>> term = QvelTerm(cfg, env) + >>> action = torch.ones(num_envs, dof) + >>> result = term.process_action(action) + >>> # result["qvel"] = 0.2 * action + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + qvel = action * self._scale + batch_size = qvel.shape[0] + return TensorDict({"qvel": qvel}, batch_size=[batch_size], device=self.device) + + +class QfTerm(ActionTerm): + """Joint force/torque action: scale * action -> qf. + + This action term outputs target joint forces/torques. + Useful for impedance control or force-based tasks. + + Args: + scale: Scaling factor for the action. Defaults to 1.0. + + Example: + >>> cfg = ActionTermCfg(func=QfTerm, params={"scale": 10.0}) + >>> term = QfTerm(cfg, env) + >>> action = torch.ones(num_envs, dof) + >>> result = term.process_action(action) + >>> # result["qf"] = 10.0 * action + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + qf = action * self._scale + batch_size = qf.shape[0] + return TensorDict({"qf": qf}, batch_size=[batch_size], device=self.device) + + +class ActionClampTerm(ActionTerm): + """Post-processing term that clamps action values to specified limits. + + This term is typically used in "post" mode to clamp the output of another + action term (e.g., QposTerm) to valid ranges. + + Args: + min: Minimum value for clamping. If None, no lower bound. Defaults to None. + max: Maximum value for clamping. If None, no upper bound. Defaults to None. + + Example: + >>> # Config with both pre and post terms + >>> cfg = { + ... "qpos": ActionTermCfg(func=QposTerm, params={"scale": 1.0}, mode="pre"), + ... "clamp": ActionTermCfg( + ... func=ActionClampTerm, params={"min": -1.0, "max": 1.0}, mode="post" + ... ), + ... } + + Example config (YAML): + .. code-block:: yaml + + action_terms: + qpos: + func: QposTerm + mode: pre + params: + scale: 1.0 + clamp: + func: ActionClampTerm + mode: post + params: + min: -1.0 + max: 1.0 + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._min = cfg.params.get("min", None) + self._max = cfg.params.get("max", None) + + @property + def action_dim(self) -> int: + # Post-processing term inherits dimension from input action + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + clamped = action + if self._min is not None: + clamped = torch.clamp(clamped, min=self._min) + if self._max is not None: + clamped = torch.clamp(clamped, max=self._max) + batch_size = clamped.shape[0] + return TensorDict( + {"qpos": clamped}, batch_size=[batch_size], device=self.device + ) diff --git a/embodichain/lab/gym/envs/managers/cfg.py b/embodichain/lab/gym/envs/managers/cfg.py index 208e5f4f..39abe86a 100644 --- a/embodichain/lab/gym/envs/managers/cfg.py +++ b/embodichain/lab/gym/envs/managers/cfg.py @@ -342,7 +342,15 @@ class ActionTermCfg(FunctorCfg): the format expected by the robot (e.g., qpos, qvel, qf). """ - pass + mode: Literal["pre", "post"] = "pre" + """The mode for the action term. + + - ``pre``: Preprocess raw action from policy (default). This is applied before + the action is sent to the robot control. + - ``post``: Postprocess the action after it has been processed by another term. + This is useful for applying additional transformations like noise, clipping, + or filtering to the output actions. + """ @configclass diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 20ac316a..1318e5e9 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -33,7 +33,7 @@ # Default manager modules for config parsing DEFAULT_MANAGER_MODULES = [ - "embodichain.lab.gym.envs.managers.action_manager", + "embodichain.lab.gym.envs.managers.actions", "embodichain.lab.gym.envs.managers.datasets", "embodichain.lab.gym.envs.managers.randomization", "embodichain.lab.gym.envs.managers.record", diff --git a/tests/gym/envs/managers/test_action_manager.py b/tests/gym/envs/managers/test_action_manager.py index efaa926d..0af09e03 100644 --- a/tests/gym/envs/managers/test_action_manager.py +++ b/tests/gym/envs/managers/test_action_manager.py @@ -19,8 +19,9 @@ import pytest import torch -from embodichain.lab.gym.envs.managers import ( - ActionManager, +from embodichain.lab.gym.envs.managers import ActionManager +from embodichain.lab.gym.envs.managers.actions import ( + ActionClampTerm, DeltaQposTerm, EefPoseTerm, QposTerm, @@ -252,3 +253,166 @@ def test_action_manager_invalid_dict_raises(): with torch.no_grad(): with pytest.raises(ValueError, match="No valid action keys"): manager.process_action({"unknown_key": torch.ones(2, 3)}) + + +# Tests for action term mode (pre/post) + + +def test_action_term_cfg_default_mode(): + """ActionTermCfg defaults to mode='pre'.""" + cfg = ActionTermCfg(func=DeltaQposTerm, params={}) + assert cfg.mode == "pre" + + +def test_action_term_cfg_post_mode(): + """ActionTermCfg supports mode='post'.""" + cfg = ActionTermCfg(func=ActionClampTerm, params={}, mode="post") + assert cfg.mode == "post" + + +def test_action_manager_process_action_pre_mode(): + """ActionManager.process_action defaults to pre mode.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "delta_qpos": ActionTermCfg(func=DeltaQposTerm, params={"scale": 0.1}), + } + manager = ActionManager(cfg, env) + + action = torch.ones(2, 3) + result = manager.process_action(action, mode="pre") + + assert "qpos" in result + expected = env.get_qpos() + 0.1 * action + torch.testing.assert_close(result["qpos"], expected) + + +def test_action_manager_process_action_post_mode(): + """ActionManager.process_action with post mode uses post terms only.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "clamp": ActionTermCfg( + func=ActionClampTerm, params={"min": -0.5, "max": 0.5}, mode="post" + ), + } + manager = ActionManager(cfg, env) + + # Action values exceed clamp limits + action = torch.ones(2, 3) * 2.0 + result = manager.process_action(action, mode="post") + + assert "qpos" in result + # Values should be clamped to [-0.5, 0.5] + torch.testing.assert_close(result["qpos"], torch.ones(2, 3) * 0.5) + + +def test_action_manager_mixed_pre_post_terms(): + """ActionManager with both pre and post terms works correctly.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "qpos": ActionTermCfg(func=QposTerm, params={"scale": 1.0}, mode="pre"), + "clamp": ActionTermCfg( + func=ActionClampTerm, params={"min": 0.0, "max": 1.0}, mode="post" + ), + } + manager = ActionManager(cfg, env) + + # Pre mode: should return qpos term output + action = torch.ones(2, 3) * 0.5 + result_pre = manager.process_action(action, mode="pre") + assert "qpos" in result_pre + torch.testing.assert_close(result_pre["qpos"], torch.ones(2, 3) * 0.5) + + # Post mode: should return clamped output + result_post = manager.process_action(action, mode="post") + assert "qpos" in result_post + # 0.5 is within [0, 1], so no clamping needed + torch.testing.assert_close(result_post["qpos"], torch.ones(2, 3) * 0.5) + + +def test_action_manager_get_functors_by_mode(): + """ActionManager.get_functors_by_mode returns correct terms.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "qpos": ActionTermCfg(func=QposTerm, params={}, mode="pre"), + "clamp": ActionTermCfg(func=ActionClampTerm, params={}, mode="post"), + } + manager = ActionManager(cfg, env) + + pre_terms = manager.get_functors_by_mode("pre") + assert len(pre_terms) == 1 + assert pre_terms[0][0] == "qpos" + + post_terms = manager.get_functors_by_mode("post") + assert len(post_terms) == 1 + assert post_terms[0][0] == "clamp" + + +def test_action_manager_get_action_dim_by_mode(): + """ActionManager.get_action_dim_by_mode returns correct dimensions.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "qpos": ActionTermCfg(func=QposTerm, params={}, mode="pre"), + "clamp": ActionTermCfg(func=ActionClampTerm, params={}, mode="post"), + } + manager = ActionManager(cfg, env) + + assert manager.get_action_dim_by_mode("pre") == 3 + assert manager.get_action_dim_by_mode("post") == 3 + + +def test_action_manager_no_terms_for_mode_raises(): + """ActionManager raises when no terms exist for specified mode.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "qpos": ActionTermCfg(func=QposTerm, params={}, mode="pre"), + } + manager = ActionManager(cfg, env) + + with pytest.raises(ValueError, match="No action terms found for mode 'post'"): + manager.process_action(torch.ones(2, 3), mode="post") + + +def test_action_clamp_term_process_action(): + """ActionClampTerm clamps action values to specified range.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg( + func=ActionClampTerm, params={"min": -1.0, "max": 1.0}, mode="post" + ) + term = ActionClampTerm(cfg, env) + + # Test clamping from above + action = torch.ones(2, 3) * 2.0 + result = term.process_action(action) + torch.testing.assert_close(result["qpos"], torch.ones(2, 3)) + + # Test clamping from below + action = torch.ones(2, 3) * -2.0 + result = term.process_action(action) + torch.testing.assert_close(result["qpos"], torch.ones(2, 3) * -1.0) + + # Test no clamping needed + action = torch.ones(2, 3) * 0.5 + result = term.process_action(action) + torch.testing.assert_close(result["qpos"], torch.ones(2, 3) * 0.5) + + +def test_action_clamp_term_only_min(): + """ActionClampTerm with only min specified.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg(func=ActionClampTerm, params={"min": 0.0}, mode="post") + term = ActionClampTerm(cfg, env) + + action = torch.ones(2, 3) * -1.0 + result = term.process_action(action) + torch.testing.assert_close(result["qpos"], torch.zeros(2, 3)) + + +def test_action_clamp_term_only_max(): + """ActionClampTerm with only max specified.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg(func=ActionClampTerm, params={"max": 1.0}, mode="post") + term = ActionClampTerm(cfg, env) + + action = torch.ones(2, 3) * 2.0 + result = term.process_action(action) + torch.testing.assert_close(result["qpos"], torch.ones(2, 3))