Skip to content

Commit ed02c1c

Browse files
authored
Merge pull request #186 from utn-mi/juelg/multi-robot
feat(env): multi robot control support
2 parents a4673bb + 7221b01 commit ed02c1c

File tree

3 files changed

+135
-5
lines changed

3 files changed

+135
-5
lines changed

python/rcs/envs/base.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,54 @@ def close(self):
280280
super().close()
281281

282282

283+
class MultiRobotWrapper(gym.Env):
    """Wraps a dictionary of environments to allow for multi robot control.

    Actions, observations and infos are dictionaries that use the same keys
    as the ``envs`` mapping passed to the constructor.
    """

    def __init__(self, envs: dict[str, gym.Env] | dict[str, gym.Wrapper]):
        """Store the sub-environments and their unwrapped base environments.

        Args:
            envs: Mapping from robot key to its (possibly wrapped) environment.
        """
        self.envs = envs
        # ``cast`` only informs the type checker; nothing is verified at runtime.
        self.unwrapped_multi = cast(dict[str, RobotEnv], {key: env.unwrapped for key, env in envs.items()})

    def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Step every sub-environment with its own entry of ``action``.

        Args:
            action: Mapping from robot key to the action for that robot. Every
                key of ``self.envs`` must be present.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` and
            ``info`` are keyed per robot, ``reward`` is the sum of all
            per-robot rewards, and ``terminated`` / ``truncated`` are true as
            soon as any robot reports them. The per-robot flags are stored in
            ``info[key]["terminated"]`` and ``info[key]["truncated"]``.

        Raises:
            KeyError: If ``action`` has no entry for one of the robots.
        """
        # follows the gym env API by combining a dict of envs into a single env
        obs = {}
        reward = 0.0
        terminated = False
        truncated = False
        info = {}
        for key, env in self.envs.items():
            obs[key], r, t, tr, info[key] = env.step(action[key])
            reward += float(r)
            terminated = terminated or t
            truncated = truncated or tr
            # keep the individual flags, the aggregated ones lose that detail
            info[key]["terminated"] = t
            info[key]["truncated"] = tr
        return obs, reward, terminated, truncated, info

    def reset(
        self, seed: dict[str, int] | None = None, options: dict[str, dict[str, Any]] | None = None  # type: ignore
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset every sub-environment.

        Args:
            seed: Optional mapping from robot key to seed. Robots without an
                entry are reset with ``seed=None``.
            options: Optional mapping from robot key to reset options. Robots
                without an entry are reset with ``options=None``.

        Returns:
            ``(obs, info)``, both keyed per robot.
        """
        obs = {}
        info = {}

        # ``.get`` lets callers seed or configure only a subset of the robots
        # instead of failing with a KeyError on the first missing key.
        seed_: dict[str, Any] = seed if seed is not None else {}
        options_: dict[str, Any] = options if options is not None else {}
        for key, env in self.envs.items():
            obs[key], info[key] = env.reset(seed=seed_.get(key), options=options_.get(key))
        return obs, info

    def get_wrapper_attr(self, name: str) -> Any:
        """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.

        If lower environments have the same attribute, it returns a dictionary of the attribute values.
        """
        if name in self.__dir__():
            return getattr(self, name)
        return {key: env.get_wrapper_attr(name) for key, env in self.envs.items()}

    def close(self):
        """Close every sub-environment."""
        for env in self.envs.values():
            env.close()
329+
330+
283331
class RelativeTo(Enum):
284332
LAST_STEP = auto()
285333
CONFIGURED_ORIGIN = auto()

python/rcs/envs/creators.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ControlMode,
1717
GripperWrapper,
1818
HandWrapper,
19+
MultiRobotWrapper,
1920
RelativeActionSpace,
2021
RelativeTo,
2122
RobotEnv,
@@ -24,10 +25,10 @@
2425
from rcs.envs.sim import (
2526
CamRobot,
2627
CollisionGuard,
27-
FR3Sim,
2828
GripperWrapperSim,
2929
PickCubeSuccessWrapper,
3030
RandomCubePos,
31+
RobotSimWrapper,
3132
SimWrapper,
3233
)
3334
from rcs.envs.space_utils import VecType
@@ -124,6 +125,46 @@ def __call__( # type: ignore
124125
return env
125126

126127

128+
class RCSFR3MultiEnvCreator(RCSHardwareEnvCreator):
    """Builds a multi-robot hardware environment with one FR3 per IP address."""

    def __call__(  # type: ignore
        self,
        ips: list[str],
        control_mode: ControlMode,
        robot_cfg: rcs.hw.FR3Config,
        gripper_cfg: rcs.hw.FHConfig | None = None,
        camera_set: BaseHardwareCameraSet | None = None,
        max_relative_movement: float | tuple[float, float] | None = None,
        relative_to: RelativeTo = RelativeTo.LAST_STEP,
        urdf_path: str | PathLike | None = None,
    ) -> gym.Env:
        """Create the wrapped environment for all robots in ``ips``.

        Args:
            ips: IP addresses of the robots; each IP is also used as the key
                of that robot in the resulting ``MultiRobotWrapper``.
            control_mode: Control mode passed to every ``RobotEnv``.
            robot_cfg: Parameters applied to every robot via ``set_parameters``.
            gripper_cfg: If given, a ``FrankaHand`` is created for every IP and
                added through a binary ``GripperWrapper``.
            camera_set: If given, it is started, waited on for frames and
                wrapped around the multi-robot environment.
            max_relative_movement: If given, every robot environment is wrapped
                in a ``RelativeActionSpace`` with this limit.
            relative_to: Reference passed to ``RelativeActionSpace``.
            urdf_path: Path handed to ``get_urdf_path`` to locate the URDF used
                for the IK solver.

        Returns:
            The ``MultiRobotWrapper`` (optionally wrapped in a
            ``CameraSetWrapper``) containing one environment per IP.
        """
        urdf_path = get_urdf_path(urdf_path, allow_none_if_not_found=False)
        # a single IK instance is shared by all robots
        ik = rcs.common.IK(str(urdf_path)) if urdf_path is not None else None

        # create and configure all robots before any environment is built
        robots: dict[str, rcs.hw.FR3] = {}
        for ip in ips:
            robots[ip] = rcs.hw.FR3(ip, ik)
            robots[ip].set_parameters(robot_cfg)

        envs = {}
        for ip in ips:
            env: gym.Env = RobotEnv(robots[ip], control_mode)
            env = FR3HW(env)
            if gripper_cfg is not None:
                gripper = rcs.hw.FrankaHand(ip, gripper_cfg)
                env = GripperWrapper(env, gripper, binary=True)

            if max_relative_movement is not None:
                env = RelativeActionSpace(env, max_mov=max_relative_movement, relative_to=relative_to)
            envs[ip] = env

        env = MultiRobotWrapper(envs)
        if camera_set is not None:
            camera_set.start()
            camera_set.wait_for_frames()
            logger.info("CameraSet started")
            env = CameraSetWrapper(env, camera_set)
        return env
166+
167+
127168
class RCSFR3DefaultEnvCreator(RCSHardwareEnvCreator):
128169
def __call__( # type: ignore
129170
self,
@@ -192,7 +233,7 @@ def __call__( # type: ignore
192233
ik = rcs.common.IK(urdf_path)
193234
robot = rcs.sim.SimRobot(simulation, ik, robot_cfg)
194235
env: gym.Env = RobotEnv(robot, control_mode)
195-
env = FR3Sim(env, simulation, sim_wrapper)
236+
env = RobotSimWrapper(env, simulation, sim_wrapper)
196237

197238
if camera_set_cfg is not None:
198239
camera_set = SimCameraSet(simulation, camera_set_cfg)

python/rcs/envs/sim.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import rcs
77
from rcs import sim
8-
from rcs.envs.base import ControlMode, GripperWrapper, RobotEnv
8+
from rcs.envs.base import ControlMode, GripperWrapper, MultiRobotWrapper, RobotEnv
99
from rcs.envs.space_utils import ActObsInfoWrapper, VecType
1010
from rcs.envs.utils import default_fr3_sim_robot_cfg
1111

@@ -25,7 +25,7 @@ def __init__(self, env: gym.Env, simulation: sim.Sim):
2525
self.sim = simulation
2626

2727

28-
class FR3Sim(gym.Wrapper):
28+
class RobotSimWrapper(gym.Wrapper):
2929
def __init__(self, env, simulation: sim.Sim, sim_wrapper: Type[SimWrapper] | None = None):
3030
self.sim_wrapper = sim_wrapper
3131
if sim_wrapper is not None:
@@ -58,6 +58,47 @@ def reset(
5858
return obs, info
5959

6060

61+
class MultiSimRobotWrapper(gym.Wrapper):
    """Steps a shared simulation for a ``MultiRobotWrapper`` environment.

    All robots of the wrapped environment live in the same ``sim.Sim``
    instance, which is advanced once per ``step`` after every robot has
    received its action.
    """

    def __init__(self, env: MultiRobotWrapper, simulation: sim.Sim):
        """Store the simulation and the robot object of every sub-environment.

        Args:
            env: Multi-robot environment to wrap.
            simulation: Simulation instance that is stepped and reset by this wrapper.
        """
        super().__init__(env)
        self.env: MultiRobotWrapper
        self.sim = simulation
        # ``cast`` only informs the type checker; nothing is verified at runtime.
        self.sim_robots = cast(dict[str, sim.SimRobot], {key: e.robot for key, e in self.env.unwrapped_multi.items()})

    def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]:
        """Apply the per-robot actions, then advance the simulation.

        Args:
            action: Mapping from robot key to the action for that robot.

        Returns:
            ``(obs, 0.0, False, truncated, info)``. ``obs`` is read from the
            unwrapped environments after the simulation step. ``info`` holds
            the per-robot infos extended by ``"collision"`` and
            ``"ik_success"``, plus the top-level ``"is_sim_converged"`` flag.
        """
        # only the infos are kept; observations are re-read after the sim step
        _, _, _, _, info = super().step(action)

        self.sim.step_until_convergence()
        info["is_sim_converged"] = self.sim.is_converged()
        for key, robot in self.sim_robots.items():
            state = robot.get_state()
            info[key]["collision"] = state.collision
            info[key]["ik_success"] = state.ik_success

        obs = {key: env.get_obs() for key, env in self.env.unwrapped_multi.items()}
        # Iterate the robot keys only: ``info`` also contains the boolean
        # "is_sim_converged" entry, which cannot be indexed like a robot info.
        # NOTE(review): truncates as soon as any robot collides or fails IK,
        # mirroring the OR aggregation in MultiRobotWrapper.step - confirm this
        # is the intended policy.
        truncated = any(info[key]["collision"] or not info[key]["ik_success"] for key in self.sim_robots)
        return obs, 0.0, False, bool(truncated), info

    def reset(
        self, seed: dict[str, int | None] | None = None, options: dict[str, Any] | None = None  # type: ignore
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the simulation and every sub-environment.

        Args:
            seed: Optional mapping from robot key to seed; defaults to ``None``
                for every robot.
            options: Optional mapping from robot key to reset options; defaults
                to an empty dict for every robot.

        Returns:
            ``(obs, info)``, both keyed per robot. Observations are read after
            one simulation step following the resets.
        """
        if seed is None:
            seed = {key: None for key in self.env.envs}
        if options is None:
            options = {key: {} for key in self.env.envs}
        obs = {}
        info = {}
        self.sim.reset()
        for key, env in self.env.envs.items():
            # observations returned here are discarded and re-read below
            _, info[key] = env.reset(seed=seed[key], options=options[key])
        self.sim.step(1)
        for key, env in self.env.unwrapped_multi.items():
            obs[key] = cast(dict, env.get_obs())
        return obs, info
100+
101+
61102
class GripperWrapperSim(ActObsInfoWrapper):
62103
def __init__(self, env, gripper: sim.SimGripper):
63104
super().__init__(env)
@@ -178,7 +219,7 @@ def env_from_xml_paths(
178219
else:
179220
control_mode = env.unwrapped.get_control_mode()
180221
c_env: gym.Env = RobotEnv(robot, control_mode)
181-
c_env = FR3Sim(c_env, simulation)
222+
c_env = RobotSimWrapper(c_env, simulation)
182223
if gripper:
183224
gripper_cfg = sim.SimGripperConfig()
184225
gripper_cfg.add_id(id)

0 commit comments

Comments
 (0)