Skip to content

Commit a804132

Browse files
authored
Merge pull request #243 from RobotControlStack/krack/pick_cube_reward
feat: more precise is_grasped and rewards for pick cube
2 parents a34af93 + 6930b14 commit a804132

1 file changed

Lines changed: 39 additions & 33 deletions

File tree

python/rcs/envs/sim.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -402,52 +402,58 @@ def reset(
402402

403403

404404
class PickCubeSuccessWrapper(gym.Wrapper):
    """
    Wrapper to check if the cube is successfully picked up in the FR3SimplePickUpSim environment.

    Cube must be lifted 10 cm above the robot base.
    Computes a reward between 0 and 1 as the mean of three terms:
    - reaching: TCP to object distance (always active)
    - place: cube z height (only counted while the cube is grasped)
    - static: whether the arm is standing still (only counted once the task is solved)

    The wrapper writes ``info["is_grasped"]`` and ``info["success"]`` on every step
    and returns ``success`` as the ``terminated`` flag.
    """

    # In robot coordinates
    EE_HOME = np.array([3.06890567e-01, 3.76703856e-23, 4.40282052e-01])

    # Height (m, robot coordinates) the cube has to reach for the task to count as solved.
    LIFT_HEIGHT = 0.10
    # Number of consecutive "closing" steps before a grasp is accepted.
    # NOTE: this value depends on the time passing between each step.
    GRASP_MIN_STEPS = 4
    # Maximum TCP to cube-center distance (m) for a grasp to be accepted.
    GRASP_MAX_TCP_DIST = 0.007
    # Normalized gripper widths outside of this open interval are treated as
    # fully closed / fully open, i.e. the fingers are not holding anything.
    GRIPPER_WIDTH_MIN = 0.01
    GRIPPER_WIDTH_MAX = 0.99

    def __init__(self, env, cube_joint_name="box_joint", cube_geom_name="box_geom"):
        """
        Args:
            env: environment to wrap; its unwrapped robot must be a ``sim.SimRobot``.
            cube_joint_name: kept for backward compatibility. The cube position is
                read from the geom (see ``cube_geom_name``), not from this joint.
            cube_geom_name: name of the cube geom whose position is tracked.
        """
        super().__init__(env)
        self.unwrapped: RobotEnv
        assert isinstance(self.unwrapped.robot, sim.SimRobot), "Robot must be a sim.SimRobot instance."
        self.sim = env.get_wrapper_attr("sim")
        self.cube_joint_name = cube_joint_name
        self.cube_geom_name = cube_geom_name
        # gym.Wrapper does not forward private attributes, so the gripper handle has
        # to be looked up explicitly before step() can use it.
        # NOTE(review): assumes a wrapper below (presumably GripperWrapper) exposes
        # `_gripper`; confirm against the wrapper stack.
        self._gripper = env.get_wrapper_attr("_gripper")
        # Consecutive steps in which the gripper was commanded closed while its
        # fingers were neither fully open nor fully closed.
        self._gripper_closing = 0

    def step(self, action: dict[str, Any]):
        """Step the wrapped env and replace its reward/terminated with the pick-cube ones.

        Returns:
            ``(obs, reward, success, truncated, info)`` where ``reward`` is a float in
            [0, 1] and ``success`` is True once the grasped cube is at least
            ``LIFT_HEIGHT`` above the robot base.
        """
        # Reward and terminated flag of the wrapped env are intentionally discarded.
        obs, _, _, truncated, info = super().step(action)

        close_commanded = obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED
        width = self._gripper.get_normalized_width()
        if self.GRIPPER_WIDTH_MIN < width < self.GRIPPER_WIDTH_MAX and close_commanded:
            self._gripper_closing += 1
        else:
            self._gripper_closing = 0

        # Cube position expressed in robot coordinates so it is comparable to the TCP.
        cube_pose = rcs.common.Pose(translation=self.sim.data.geom(self.cube_geom_name).xpos)
        cube_pose = self.unwrapped.robot.to_pose_in_robot_coordinates(cube_pose)
        cube_position = cube_pose.translation()
        tcp_position = self.unwrapped.robot.get_cartesian_position().translation()

        tcp_to_obj_dist = np.linalg.norm(cube_position - tcp_position)
        # Remaining height until the cube reaches LIFT_HEIGHT, clipped at zero.
        obj_to_goal_dist = self.LIFT_HEIGHT - min(cube_position[-1], self.LIFT_HEIGHT)

        is_grasped = bool(
            self._gripper_closing >= self.GRASP_MIN_STEPS  # gripper has been closing long enough
            and close_commanded  # command is still close
            and tcp_to_obj_dist <= self.GRASP_MAX_TCP_DIST  # tcp is at the cube center
        )
        # Use the grasp state computed above; info["is_grasped"] is only written below.
        success = bool(is_grasped and obj_to_goal_dist == 0)
        movement = np.linalg.norm(self.sim.data.qvel)

        reaching_reward = 1 - np.tanh(5 * tcp_to_obj_dist)
        # Gate the complete term: without the parentheses an ungrasped cube / unsolved
        # task would yield the maximum value of 1 instead of 0.
        place_reward = (1 - np.tanh(5 * obj_to_goal_dist)) * is_grasped
        static_reward = (1 - np.tanh(5 * movement)) * success

        info["is_grasped"] = is_grasped
        info["success"] = success
        reward = float(reaching_reward + place_reward + static_reward) / 3
        return obs, reward, success, truncated, info
452458

453459

0 commit comments

Comments
 (0)