@@ -14,16 +14,33 @@ def __init__(self, env):
1414 self .unwrapped : RobotEnv
1515 assert isinstance (self .unwrapped .robot , hw .Franka ), "Robot must be a hw.Franka instance."
1616 self .hw_robot = cast (hw .Franka , self .unwrapped .robot )
17+ self ._robot_state_keys : list [str ] | None = None
1718
1819 def step (self , action : Any ) -> tuple [dict [str , Any ], SupportsFloat , bool , bool , dict ]:
1920 try :
20- return super ().step (action )
21+ obs , reward , terminated , truncated , info = super ().step (action )
22+ obs = self .get_obs (obs )
23+ return obs , reward , terminated , truncated , info
2124 except hw .exceptions .FrankaControlException as e :
2225 _logger .error ("FrankaControlException: %s" , e )
2326 self .hw_robot .automatic_error_recovery ()
2427 # TODO: this does not work if some wrappers are in between
2528 # PandaHW and RobotEnv
26- return dict (self .unwrapped .get_obs ()), 0 , False , True , {}
29+ return self .get_obs (), 0 , False , True , {}
30+
31+ def get_obs (self , obs : dict | None = None ) -> dict [str , Any ]:
32+ if obs is None :
33+ obs = dict (self .unwrapped .get_obs ())
34+ robot_state = cast (hw .FrankaState , self .unwrapped .robot .get_state ())
35+ obs ["robot_state" ] = self ._rs2dict (robot_state .robot_state )
36+ return obs
37+
38+ def _rs2dict (self , state : hw .RobotState ):
39+ if self ._robot_state_keys is None :
40+ self ._robot_state_keys = [
41+ attr for attr in dir (state ) if not attr .startswith ("__" ) and not callable (getattr (state , attr ))
42+ ]
43+ return {key : getattr (state , key ) for key in self ._robot_state_keys }
2744
2845 def reset (
2946 self , seed : int | None = None , options : dict [str , Any ] | None = None
0 commit comments