@@ -402,52 +402,58 @@ def reset(
402402
403403
class PickCubeSuccessWrapper(gym.Wrapper):
    """Check whether the cube has been picked up in the FR3SimplePickUpSim environment.

    Success requires the cube to be lifted 10 cm above the robot base while
    grasped. Each step computes a dense reward in [0, 1] from three staged,
    ManiSkill-style terms:
      - TCP-to-object distance (reaching),
      - cube z height relative to the 10 cm goal (placing, only while grasped),
      - whether the arm is standing still once the task is solved (static).
    """

    # End-effector home position, in robot coordinates.
    EE_HOME = np.array([3.06890567e-01, 3.76703856e-23, 4.40282052e-01])

    def __init__(self, env, cube_joint_name="box_joint"):
        """Wrap *env*.

        Args:
            env: Environment exposing a ``sim`` wrapper attribute and whose
                unwrapped robot is a ``sim.SimRobot``.
            cube_joint_name: Unused; kept for backward compatibility. The cube
                is now located via its geom name instead of its joint.
        """
        super().__init__(env)
        self.unwrapped: RobotEnv
        assert isinstance(self.unwrapped.robot, sim.SimRobot), "Robot must be a sim.SimRobot instance."
        self.sim = env.get_wrapper_attr("sim")
        self.cube_geom_name = "box_geom"
        # Number of consecutive steps the gripper has been actively closing.
        self._gripper_closing = 0

    def step(self, action: dict[str, Any]):
        """Step the wrapped env, then overwrite reward/terminated with the
        pick-up shaping described on the class.

        Returns:
            ``(obs, reward, success, truncated, info)`` where ``reward`` is in
            [0, 1] and ``info`` gains ``"is_grasped"`` and ``"success"``.
        """
        obs, reward, _, truncated, info = super().step(action)

        # Track how long the gripper has been stalled mid-travel while a close
        # command is active: a gripper that stays partially closed for several
        # steps is squeezing an object rather than free-moving.
        # NOTE(review): self._gripper is assumed to be supplied elsewhere in
        # the wrapper stack -- confirm.
        width = self._gripper.get_normalized_width()
        if 0.01 < width < 0.99 and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED:
            self._gripper_closing += 1
        else:
            self._gripper_closing = 0

        # Cube pose, converted into robot coordinates.
        cube_pose = rcs.common.Pose(translation=self.sim.data.geom(self.cube_geom_name).xpos)
        cube_pose = self.unwrapped.robot.to_pose_in_robot_coordinates(cube_pose)
        tcp_to_obj_dist = np.linalg.norm(
            cube_pose.translation() - self.unwrapped.robot.get_cartesian_position().translation()
        )
        # Remaining height to the 10 cm lift goal (clamped to 0 once reached).
        obj_to_goal_dist = 0.10 - min(cube_pose.translation()[-1], 0.10)

        # NOTE: 4 depends on the time passing between each step.
        is_grasped = (
            self._gripper_closing >= 4  # gripper is closing since more than 4 steps
            and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED  # command is still close
            and tcp_to_obj_dist <= 0.007  # tcp to cube center is max 7mm
        )
        # BUGFIX: use the freshly computed local flag. info["is_grasped"] is
        # only written further below, so reading it here saw a stale or
        # missing value.
        success = obj_to_goal_dist == 0 and is_grasped
        movement = np.linalg.norm(self.sim.data.qvel)

        # BUGFIX: parenthesize before multiplying by the stage flag. The
        # previous `1 - np.tanh(...) * flag` evaluated as `1 - (tanh * flag)`,
        # granting the FULL place/static reward whenever the flag was False
        # and thereby inverting the shaping.
        reaching_reward = 1 - np.tanh(5 * tcp_to_obj_dist)
        place_reward = (1 - np.tanh(5 * obj_to_goal_dist)) * is_grasped
        static_reward = (1 - np.tanh(5 * movement)) * success

        info["is_grasped"] = is_grasped
        info["success"] = success
        # Normalize the three [0, 1] terms back into [0, 1].
        reward = (reaching_reward + place_reward + static_reward) / 3  # type: ignore
        return obs, reward, success, truncated, info
452458
453459
0 commit comments