55import numpy as np
66import rcs
77from rcs import sim
8- from rcs .envs .base import ControlMode , GripperWrapper , RobotEnv
8+ from rcs .envs .base import ControlMode , GripperWrapper , MultiRobotWrapper , RobotEnv
99from rcs .envs .space_utils import ActObsInfoWrapper , VecType
1010from rcs .envs .utils import default_fr3_sim_robot_cfg
1111
@@ -25,7 +25,7 @@ def __init__(self, env: gym.Env, simulation: sim.Sim):
2525 self .sim = simulation
2626
2727
28- class FR3Sim (gym .Wrapper ):
28+ class RobotSimWrapper (gym .Wrapper ):
2929 def __init__ (self , env , simulation : sim .Sim , sim_wrapper : Type [SimWrapper ] | None = None ):
3030 self .sim_wrapper = sim_wrapper
3131 if sim_wrapper is not None :
@@ -58,6 +58,47 @@ def reset(
5858 return obs , info
5959
6060
61+ class MultiSimRobotWrapper (gym .Wrapper ):
62+ """Wraps a dictionary of environments to allow for multi robot control."""
63+
64+ def __init__ (self , env : MultiRobotWrapper , simulation : sim .Sim ):
65+ super ().__init__ (env )
66+ self .env : MultiRobotWrapper
67+ self .sim = simulation
68+ self .sim_robots = cast (dict [str , sim .SimRobot ], {key : e .robot for key , e in self .env .unwrapped_multi .items ()})
69+
70+ def step (self , action : dict [str , Any ]) -> tuple [dict [str , Any ], float , bool , bool , dict ]:
71+ _ , _ , _ , _ , info = super ().step (action )
72+
73+ self .sim .step_until_convergence ()
74+ info ["is_sim_converged" ] = self .sim .is_converged ()
75+ for key in self .envs .envs .items ():
76+ state = self .sim_robots [key ].get_state ()
77+ info [key ]["collision" ] = state .collision
78+ info [key ]["ik_success" ] = state .ik_success
79+
80+ obs = {key : env .get_obs () for key , env in self .env .unwrapped_multi .items ()}
81+ truncated = np .all ([info [key ]["collision" ] or info [key ]["ik_success" ] for key in info ])
82+ return obs , 0.0 , False , bool (truncated ), info
83+
84+ def reset (
85+ self , seed : dict [str , int | None ] | None = None , options : dict [str , Any ] | None = None # type: ignore
86+ ) -> tuple [dict [str , Any ], dict [str , Any ]]:
87+ if seed is None :
88+ seed = {key : None for key in self .env .envs }
89+ if options is None :
90+ options = {key : {} for key in self .env .envs }
91+ obs = {}
92+ info = {}
93+ self .sim .reset ()
94+ for key , env in self .env .envs .items ():
95+ _ , info [key ] = env .reset (seed = seed [key ], options = options [key ])
96+ self .sim .step (1 )
97+ for key , env in self .env .unwrapped_multi .items ():
98+ obs [key ] = cast (dict , env .get_obs ())
99+ return obs , info
100+
101+
61102class GripperWrapperSim (ActObsInfoWrapper ):
62103 def __init__ (self , env , gripper : sim .SimGripper ):
63104 super ().__init__ (env )
@@ -178,7 +219,7 @@ def env_from_xml_paths(
178219 else :
179220 control_mode = env .unwrapped .get_control_mode ()
180221 c_env : gym .Env = RobotEnv (robot , control_mode )
181- c_env = FR3Sim (c_env , simulation )
222+ c_env = RobotSimWrapper (c_env , simulation )
182223 if gripper :
183224 gripper_cfg = sim .SimGripperConfig ()
184225 gripper_cfg .add_id (id )
0 commit comments