Skip to content

Commit 2021b74

Browse files
authored
Merge pull request #247 from RobotControlStack/juelg/fix-fr3-robot-state
fix(fr3,panda): robot state
2 parents ec308e4 + 0e521b6 commit 2021b74

2 files changed

Lines changed: 29 additions & 4 deletions

File tree

extensions/rcs_fr3/src/rcs_fr3/envs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, env):
1414
self.unwrapped: RobotEnv
1515
assert isinstance(self.unwrapped.robot, hw.Franka), "Robot must be a hw.Franka instance."
1616
self.hw_robot = cast(hw.Franka, self.unwrapped.robot)
17+
self._robot_state_keys: list[str] | None = None
1718

1819
def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict]:
1920
try:
@@ -30,10 +31,17 @@ def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool,
3031
def get_obs(self, obs: dict | None = None) -> dict[str, Any]:
3132
if obs is None:
3233
obs = dict(self.unwrapped.get_obs())
33-
# robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
34-
# obs["robot_state"] = vars(robot_state.robot_state)
34+
robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
35+
obs["robot_state"] = self._rs2dict(robot_state.robot_state)
3536
return obs
3637

38+
def _rs2dict(self, state: hw.RobotState):
39+
if self._robot_state_keys is None:
40+
self._robot_state_keys = [
41+
attr for attr in dir(state) if not attr.startswith("__") and not callable(getattr(state, attr))
42+
]
43+
return {key: getattr(state, key) for key in self._robot_state_keys}
44+
3745
def reset(
3846
self, seed: int | None = None, options: dict[str, Any] | None = None
3947
) -> tuple[dict[str, Any], dict[str, Any]]:

extensions/rcs_panda/src/rcs_panda/envs.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,33 @@ def __init__(self, env):
1414
self.unwrapped: RobotEnv
1515
assert isinstance(self.unwrapped.robot, hw.Franka), "Robot must be a hw.Franka instance."
1616
self.hw_robot = cast(hw.Franka, self.unwrapped.robot)
17+
self._robot_state_keys: list[str] | None = None
1718

1819
def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict]:
1920
try:
20-
return super().step(action)
21+
obs, reward, terminated, truncated, info = super().step(action)
22+
obs = self.get_obs(obs)
23+
return obs, reward, terminated, truncated, info
2124
except hw.exceptions.FrankaControlException as e:
2225
_logger.error("FrankaControlException: %s", e)
2326
self.hw_robot.automatic_error_recovery()
2427
# TODO: this does not work if some wrappers are in between
2528
# PandaHW and RobotEnv
26-
return dict(self.unwrapped.get_obs()), 0, False, True, {}
29+
return self.get_obs(), 0, False, True, {}
30+
31+
def get_obs(self, obs: dict | None = None) -> dict[str, Any]:
32+
if obs is None:
33+
obs = dict(self.unwrapped.get_obs())
34+
robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
35+
obs["robot_state"] = self._rs2dict(robot_state.robot_state)
36+
return obs
37+
38+
def _rs2dict(self, state: hw.RobotState):
39+
if self._robot_state_keys is None:
40+
self._robot_state_keys = [
41+
attr for attr in dir(state) if not attr.startswith("__") and not callable(getattr(state, attr))
42+
]
43+
return {key: getattr(state, key) for key in self._robot_state_keys}
2744

2845
def reset(
2946
self, seed: int | None = None, options: dict[str, Any] | None = None

0 commit comments

Comments
 (0)