From 5fd8ddaf43e82ef9b14ecf06a0649d7a458dd77b Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 23 Mar 2026 18:59:45 +0800 Subject: [PATCH 1/6] draft --- embodichain/lab/sim/objects/rigid_object.py | 67 +++ .../graspkit/pg_grasp/antipodal_annotator.py | 489 ++++++++++++++++++ .../graspkit/pg_grasp/antipodal_sampler.py | 231 +++++++++ examples/sim/demo/grasp_mug.py | 257 +++++++++ 4 files changed, 1044 insertions(+) create mode 100644 embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py create mode 100644 embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py create mode 100644 examples/sim/demo/grasp_mug.py diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 565c5bf4..62207baa 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -34,6 +34,11 @@ from embodichain.utils.math import convert_quat from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler from embodichain.utils import logger +from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( + GraspAnnotator, + GraspAnnotatorCfg, +) +import torch.nn.functional as F @dataclass @@ -1108,3 +1113,65 @@ def destroy(self) -> None: arenas = [env] for i, entity in enumerate(self._entities): arenas[i].remove_actor(entity) + + def get_grasp_pose( + self, + cfg: GraspAnnotatorCfg, + approach_direction: torch.Tensor = None, + is_visual: bool = False, + ) -> torch.Tensor: + if approach_direction is None: + approach_direction = torch.tensor( + [0, 0, -1], dtype=torch.float32, device=self.device + ) + approach_direction = F.normalize(approach_direction, dim=-1) + if hasattr(self, "_grasp_annotator") is False: + self._grasp_annotator = GraspAnnotator(cfg=cfg) + if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate: + vertices = torch.tensor( + self._entities[0].get_vertices(), + dtype=torch.float32, + device=self.device, + ) + triangles = torch.tensor( + self._entities[0].get_triangles(), dtype=torch.int32, device=self.device + ) + scale = torch.tensor( + self._entities[0].get_body_scale(), + dtype=torch.float32, + device=self.device, + ) + vertices = vertices * scale + self._hit_point_pairs = self._grasp_annotator.annotate(vertices, triangles) + + poses = self.get_local_pose(to_matrix=True) + poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device) + grasp_poses = [] + open_lengths = [] + for pose in poses: + grasp_pose, open_length = self._grasp_annotator.get_approach_grasp_poses( + self._hit_point_pairs, pose, approach_direction + ) + grasp_poses.append(grasp_pose) + open_lengths.append(open_length) + grasp_poses = torch.cat( + [grasp_pose.unsqueeze(0) for grasp_pose in grasp_poses], dim=0 + ) + + if is_visual: + vertices = self._entities[0].get_vertices() + triangles = self._entities[0].get_triangles() + scale = self._entities[0].get_body_scale() + vertices = vertices * scale + GraspAnnotator.visualize_grasp_pose( + vertices=torch.tensor( + vertices, dtype=torch.float32, device=self.device + ), + triangles=torch.tensor( + triangles, dtype=torch.int32, device=self.device + ), + obj_pose=poses[0], + grasp_pose=grasp_poses[0], + open_length=open_lengths[0], + ) + return grasp_poses diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py new file mode 100644 index 00000000..4852879e --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -0,0 +1,489 @@ +import os +import argparse +import open3d as o3d +import time +from pathlib import Path +from typing import Any, cast +import torch +import numpy as np +import trimesh + +import viser +import viser.transforms as tf +from embodichain.utils import logger +from dataclasses import dataclass +from embodichain.toolkits.graspkit.pg_grasp.antipodal_sampler import ( + AntipodalSampler, + AntipodalSamplerCfg, +) +import hashlib +import torch.nn.functional as F +import tempfile + + +@dataclass +class GraspAnnotatorCfg: + viser_port: int = 15531 + use_largest_connected_component: bool = False + antipodal_sampler_cfg: AntipodalSamplerCfg = AntipodalSamplerCfg() + force_regenerate: bool = False + max_deviation_angle: float = np.pi / 12 + + +@dataclass +class SelectResult: + vertex_indices: np.ndarray | None = None + face_indices: np.ndarray | None = None + vertices: np.ndarray | None = None + faces: np.ndarray | None = None + + +class GraspAnnotator: + def __init__(self, cfg: GraspAnnotatorCfg = GraspAnnotatorCfg()) -> None: + self.cfg = cfg + self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) + + def annotate(self, vertices: torch.Tensor, triangles: torch.Tensor): + cache_path = self._get_cache_dir(vertices, triangles) + if os.path.exists(cache_path) and not self.cfg.force_regenerate: + logger.log_info( + f"Found existing antipodal retult. Loading cached antipodal pairs from {cache_path}" + ) + hit_point_pairs = torch.tensor( + np.load(cache_path), dtype=torch.float32, device=vertices.device + ) + return hit_point_pairs + else: + logger.log_info( + f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" + ) + + self.mesh = trimesh.Trimesh( + vertices=vertices.to("cpu").numpy(), + faces=triangles.to("cpu").numpy(), + process=False, + force="mesh", + ) + self.device = vertices.device + + server = viser.ViserServer(port=self.cfg.viser_port) + server.gui.configure_theme(brand_color=(130, 0, 150)) + server.scene.set_up_direction("+z") + + mesh_handle = server.scene.add_mesh_trimesh(name="/mesh", mesh=self.mesh) + selected_overlay: viser.GlbHandle | None = None + selection: SelectResult = SelectResult() + + hit_point_pairs = None + return_flag = False + + @server.on_client_connect + def _(client: viser.ClientHandle) -> None: + nonlocal mesh_handle + nonlocal selected_overlay + nonlocal selection + + # client.camera.position = np.array([0.0, 0.0, -0.5]) + # client.camera.wxyz = np.array([1.0, 0.0, 0.0, 0.0]) + + select_button = client.gui.add_button( + "Rect Select Region", icon=viser.Icon.PAINT + ) + confirm_button = client.gui.add_button("Confirm Selection") + + @select_button.on_click + def _(_evt: viser.GuiEvent) -> None: + select_button.disabled = True + + @client.scene.on_pointer_event(event_type="rect-select") + def _(event: viser.ScenePointerEvent) -> None: + nonlocal mesh_handle + nonlocal selected_overlay + nonlocal selection + nonlocal hit_point_pairs + client.scene.remove_pointer_callback() + + proj, depth = GraspAnnotator._project_vertices_to_screen( + cast(np.ndarray, self.mesh.vertices), + mesh_handle, + event.client.camera, + ) + + lower = np.minimum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + upper = np.maximum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + vertex_mask = ((proj >= lower) & (proj <= upper)).all(axis=1) & ( + depth > 1e-6 + ) + + selection = GraspAnnotator._extract_selection( + self.mesh, vertex_mask, self.cfg.use_largest_connected_component + ) + if selection.vertices is None: + logger.log_warning("[Selection] No vertices selected.") + return + + color_mesh = self.mesh.copy() + used_vertex_indices = selection.vertex_indices + vertex_colors = np.tile( + np.array([[0.85, 0.85, 0.85, 1.0]]), + (self.mesh.vertices.shape[0], 1), + ) + vertex_colors[used_vertex_indices] = np.array( + [0.56, 0.17, 0.92, 1.0] + ) + color_mesh.visual.vertex_colors = vertex_colors # type: ignore + mesh_handle = server.scene.add_mesh_trimesh( + name="/mesh", mesh=color_mesh + ) + + if selected_overlay is not None: + selected_overlay.remove() + selected_mesh = trimesh.Trimesh( + vertices=selection.vertices, + faces=selection.faces, + process=False, + ) + selected_mesh.visual.face_colors = (0.9, 0.2, 0.2, 0.65) # type: ignore + selected_overlay = server.scene.add_mesh_trimesh( + name="/selected", mesh=selected_mesh + ) + logger.log_info( + f"[Selection] Selected {selection.vertex_indices.size} vertices and {selection.face_indices.size} faces." + ) + + hit_point_pairs = self.antipodal_sampler.sample( + torch.tensor(selection.vertices, device=self.device), + torch.tensor(selection.faces, device=self.device), + ) + extended_hit_point_pairs = GraspAnnotator._extend_hit_point_pairs( + hit_point_pairs + ) + server.scene.add_line_segments( + name="/antipodal_pairs", + points=extended_hit_point_pairs.to("cpu").numpy(), + colors=(20, 200, 200), + line_width=1.5, + ) + + @client.scene.on_pointer_callback_removed + def _() -> None: + select_button.disabled = False + + @confirm_button.on_click + def _(_evt: viser.GuiEvent) -> None: + nonlocal return_flag + if selection.vertices is None: + logger.log_warning("[Selection] No vertex selected.") + return + else: + logger.log_info( + f"[Selection] {selection.vertices.shape[0]}vertices selected. Generating antipodal point pairs." + ) + return_flag = True + + while True: + if return_flag: + # save result to cache + if hit_point_pairs is not None: + self._save_cache(cache_path, hit_point_pairs) + break + time.sleep(0.5) + return hit_point_pairs + + def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): + vert_bytes = vertices.to("cpu").numpy().tobytes() + face_bytes = triangles.to("cpu").numpy().tobytes() + md5_hash = hashlib.md5(vert_bytes + face_bytes).hexdigest() + cache_path = os.path.join( + tempfile.gettempdir(), f"antipodal_cache_{md5_hash}.npy" + ) + return cache_path + + def _save_cache(self, cache_path: str, hit_point_pairs: torch.Tensor): + np.save(cache_path, hit_point_pairs.cpu().numpy().astype(np.float32)) + + @staticmethod + def _extend_hit_point_pairs(hit_point_pairs: torch.Tensor): + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + mid_points = (origin_points + hit_points) / 2 + point_diff = hit_points - origin_points + extended_origin = mid_points - 0.8 * point_diff + extended_hit = mid_points + 0.8 * point_diff + extended_point_pairs = torch.cat( + [extended_origin[:, None, :], extended_hit[:, None, :]], dim=1 + ) + return extended_point_pairs + + @staticmethod + def _project_vertices_to_screen( + vertices_mesh: np.ndarray, + mesh_handle: viser.GlbHandle, + camera: Any, + ) -> tuple[np.ndarray, np.ndarray]: + T_world_mesh = tf.SE3.from_rotation_and_translation( + tf.SO3(np.asarray(mesh_handle.wxyz)), + np.asarray(mesh_handle.position), + ) + vertices_world_h = ( + T_world_mesh.as_matrix() + @ np.hstack([vertices_mesh, np.ones((vertices_mesh.shape[0], 1))]).T + ).T + vertices_world = vertices_world_h[:, :3] + + T_camera_world = tf.SE3.from_rotation_and_translation( + tf.SO3(np.asarray(camera.wxyz)), + np.asarray(camera.position), + ).inverse() + vertices_camera_h = ( + T_camera_world.as_matrix() + @ np.hstack([vertices_world, np.ones((vertices_world.shape[0], 1))]).T + ).T + vertices_camera = vertices_camera_h[:, :3] + + fov = float(camera.fov) + aspect = float(camera.aspect) + projected = vertices_camera[:, :2] / np.maximum(vertices_camera[:, 2:3], 1e-8) + projected /= np.tan(fov / 2.0) + projected[:, 0] /= aspect + projected = (1.0 + projected) / 2.0 + return projected, vertices_camera[:, 2] + + def _extract_selection( + mesh: trimesh.Trimesh, + vertex_mask: np.ndarray, + largest_component: bool, + ) -> SelectResult: + def _largest_connected_face_component(face_ids: np.ndarray) -> np.ndarray: + if face_ids.size <= 1: + return face_ids + + face_id_set = set(face_ids.tolist()) + parent: dict[int, int] = { + int(face_id): int(face_id) for face_id in face_ids + } + + def find(x: int) -> int: + root = x + while parent[root] != root: + root = parent[root] + while parent[x] != x: + x_parent = parent[x] + parent[x] = root + x = x_parent + return root + + def union(a: int, b: int) -> None: + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + face_adjacency = cast(np.ndarray, mesh.face_adjacency) + for face_a, face_b in face_adjacency: + if int(face_a) in face_id_set and int(face_b) in face_id_set: + union(int(face_a), int(face_b)) + + groups: dict[int, list[int]] = {} + for face_id in face_ids: + root = find(int(face_id)) + groups.setdefault(root, []).append(int(face_id)) + + largest_group = max(groups.values(), key=len) + return np.array(largest_group, dtype=np.int32) + + faces = cast(np.ndarray, mesh.faces) + face_mask = np.all(vertex_mask[faces], axis=1) + + face_indices = np.flatnonzero(face_mask) + if face_indices.size == 0: + return SelectResult() + if largest_component: + face_indices = _largest_connected_face_component(face_indices) + if face_indices.size == 0: + return SelectResult() + + selected_face_vertices = faces[face_indices] + vertex_indices = np.unique(selected_face_vertices.reshape(-1)) + + old_to_new = np.full(mesh.vertices.shape[0], -1, dtype=np.int32) + old_to_new[vertex_indices] = np.arange(vertex_indices.size, dtype=np.int32) + + sub_vertices = np.asarray(mesh.vertices)[vertex_indices] + sub_faces = np.asarray(old_to_new)[selected_face_vertices] + + return SelectResult( + vertex_indices=vertex_indices, + face_indices=face_indices, + vertices=sub_vertices, + faces=sub_faces, + ) + + @staticmethod + def _apply_transform(points: torch.Tensor, transform: torch.Tensor) -> torch.Tensor: + r = transform[:3, :3] + t = transform[:3, 3] + return points @ r.T + t + + def get_approach_grasp_poses( + self, + hit_point_pairs: torch.Tensor, + object_pose: torch.Tensor, + approach_direction: torch.Tensor, + ) -> torch.Tensor: + """Get grasp pose given approach direction + + Args: + hit_point_pairs (torch.Tensor): (N, 2, 3) tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. + object_pose (torch.Tensor): (4, 4) homogeneous transformation matrix representing the pose of the object in the world frame. + approach_direction (torch.Tensor): (3,) unit vector representing the desired approach direction of the gripper in the world frame. + + Returns: + torch.Tensor: (4, 4) homogeneous transformation matrix representing the grasp pose in the world frame that aligns the gripper's approach direction with the given approach_direction. Returns None if no valid grasp pose can be found. + """ + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + print("origin_points dtype:", origin_points.dtype) + print("object_pose dtype:", object_pose.dtype) + origin_points_ = self._apply_transform(origin_points, object_pose) + hit_points_ = self._apply_transform(hit_points, object_pose) + centers = (origin_points_ + hit_points_) / 2 + center = centers.mean(dim=0) + + # get best grasp pose + grasp_x = F.normalize(hit_points_ - origin_points_, dim=-1) + cos_angle = torch.clamp((grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0) + positive_angle = torch.abs(torch.acos(cos_angle)) + antipodal_length = torch.norm(hit_points_ - origin_points_, dim=-1) + length_cost = 1 - antipodal_length / antipodal_length.max() + angle_cost = torch.abs(positive_angle - 0.5 * torch.pi) / (0.5 * torch.pi) + center_distance = torch.norm(centers - center, dim=-1) + center_cost = center_distance / center_distance.max() + total_cost = 0.4 * angle_cost + 0.3 * length_cost + 0.3 * center_cost + best_idx = torch.argmin(total_cost) + + best_open_length = torch.norm(hit_points_[best_idx] - origin_points_[best_idx]) + best_grasp_x = grasp_x[best_idx] + best_grasp_center = centers[best_idx] + best_grasp_y = torch.cross(approach_direction, best_grasp_x, dim=0) + best_grasp_y = F.normalize(best_grasp_y, dim=-1) + best_grasp_z = torch.cross(best_grasp_x, best_grasp_y, dim=0) + best_grasp_z = F.normalize(best_grasp_z, dim=-1) + grasp_pose = torch.eye(4, device=hit_point_pairs.device, dtype=torch.float32) + grasp_pose[:3, 0] = best_grasp_x + grasp_pose[:3, 1] = best_grasp_y + grasp_pose[:3, 2] = best_grasp_z + grasp_pose[:3, 3] = best_grasp_center + return grasp_pose, best_open_length + + @staticmethod + def visualize_grasp_pose( + vertices: torch.Tensor, + triangles: torch.Tensor, + obj_pose: torch.Tensor, + grasp_pose: torch.Tensor, + open_length: float, + ): + mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(vertices.to("cpu").numpy()), + triangles=o3d.utility.Vector3iVector(triangles.to("cpu").numpy()), + ) + mesh.compute_vertex_normals() + mesh.paint_uniform_color([0.3, 0.6, 0.3]) + mesh.transform(obj_pose.to("cpu").numpy()) + vertices_ = torch.tensor( + np.asarray(mesh.vertices), device=vertices.device, dtype=vertices.dtype + ) + mesh_scale = (vertices_.max(dim=0)[0] - vertices_.min(dim=0)[0]).max().item() + groud_plane = o3d.geometry.TriangleMesh.create_cylinder( + radius=mesh_scale, height=0.01 * mesh_scale + ) + groud_plane.compute_vertex_normals() + center = vertices_.mean(dim=0) + z_sim = vertices_.min(dim=0)[0][2].item() + groud_plane.translate( + (center[0].item(), center[1].item(), z_sim - 0.005 * mesh_scale) + ) + + draw_thickness = 0.02 * mesh_scale + draw_length = 0.3 * mesh_scale + grasp_finger1 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger1.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger2.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger1.translate((-open_length / 2, 0, -0.25 * draw_length)) + grasp_finger2.translate((open_length / 2, 0, -0.25 * draw_length)) + grasp_root1 = o3d.geometry.TriangleMesh.create_box( + open_length, draw_thickness, draw_thickness + ) + grasp_root1.translate( + (-open_length / 2, -0.5 * draw_thickness, -0.5 * draw_thickness) + ) + grasp_root1.translate((0, 0, -0.75 * draw_length)) + grasp_root2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_root2.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_root2.translate((0, 0, -1.25 * draw_length)) + + grasp_visual = grasp_finger1 + grasp_finger2 + grasp_root1 + grasp_root2 + grasp_visual.paint_uniform_color([0.8, 0.2, 0.8]) + grasp_visual.transform(grasp_pose.to("cpu").numpy()) + o3d.visualization.draw_geometries( + [grasp_visual, mesh, groud_plane], + window_name="Grasp Pose Visualization", + mesh_show_back_face=True, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Viser mesh 标注工具:框选并导出对应顶点与三角面" + ) + parser.add_argument( + "--mesh", type=Path, required=True, help="输入 mesh 文件路径,例如 mug.obj" + ) + parser.add_argument("--scale", type=float, default=1.0, help="加载后整体缩放系数") + parser.add_argument("--port", type=int, default=12151, help="viser 服务端口") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("outputs/mesh_annotations"), + help="标注结果导出目录", + ) + parser.add_argument( + "--largest-component", + action="store_true", + help="只保留框选结果中的最大连通块(常用于稳定提取把手等局部)", + ) + args = parser.parse_args() + + mesh = trimesh.load(args.mesh, process=False, force="mesh") + vertices = mesh.vertices * args.scale + triangles = mesh.faces + cfg = GraspAnnotatorCfg( + force_regenerate=True, + ) + tool = GraspAnnotator(cfg=cfg) + hit_point_pairs = tool.annotate( + vertices=torch.from_numpy(vertices).float(), + triangles=torch.from_numpy(triangles).long(), + ) + logger.log_info(f"Sample {hit_point_pairs.shape[0]} antipodal point pairs.") + + +if __name__ == "__main__": + main() diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py new file mode 100644 index 00000000..1eb3ec61 --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -0,0 +1,231 @@ +import torch +import torch.nn.functional as F +import numpy as np +import open3d as o3d +import open3d.core as o3c +from dataclasses import dataclass +from embodichain.utils import logger + + +@dataclass +class AntipodalSamplerCfg: + n_sample: int = 10000 + """surface point sample number""" + max_angle: float = np.pi / 12 + """maximum angle (in radians) to randomly disturb the ray direction for antipodal point sampling, used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will disable the random disturbance and sample antipodal points strictly along the surface normals, which may result in less diverse antipodal points and may not be ideal for all objects or grasping scenarios.""" + max_length: float = 0.1 + """maximum gripper open width, used to filter out antipodal points that are too far apart to be grasped""" + min_length: float = 0.001 + """minimum gripper open width, used to filter out antipodal points that are too close to be grasped""" + + +class AntipodalSampler: + def __init__( + self, + cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), + ): + self.mesh: o3d.t.geometry.TriangleMesh | None = None + self.cfg = cfg + + def sample(self, vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: + """Get sample Antipodal point pair + + Returns: + hit_point_pairs: [N, 2, 3] tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. + """ + # update mesh + self.mesh = o3d.t.geometry.TriangleMesh() + self.mesh.vertex.positions = o3c.Tensor( + vertices.to("cpu").numpy(), dtype=o3c.float32 + ) + self.mesh.triangle.indices = o3c.Tensor( + faces.to("cpu").numpy(), dtype=o3c.int32 + ) + self.mesh.compute_vertex_normals() + # sample points and normals + sample_pcd = self.mesh.sample_points_uniformly( + number_of_points=self.cfg.n_sample + ) + sample_points = torch.tensor( + sample_pcd.point.positions.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + sample_normals = torch.tensor( + sample_pcd.point.normals.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + # generate rays + ray_direc = -sample_normals + ray_origin = ( + sample_points + 1e-3 * ray_direc + ) # Offset ray origin slightly along the normal to avoid self-intersection + disturb_direc = AntipodalSampler._random_rotate_unit_vectors( + ray_direc, max_angle=self.cfg.max_angle + ) + ray_origin = torch.vstack([ray_origin, ray_origin]) + ray_direc = torch.vstack([ray_direc, disturb_direc]) + # casting + return self.get_raycast_result( + ray_origin, + ray_direc, + surface_origin=torch.vstack([sample_points, sample_points]), + ) + + def get_raycast_result( + self, + ray_origin: torch.Tensor, + ray_direc: torch.Tensor, + surface_origin: torch.Tensor, + ): + if ray_origin.ndim != 2 or ray_origin.shape[-1] != 3: + raise ValueError("ray_origin must have shape [N, 3]") + if ray_direc.ndim != 2 or ray_direc.shape[-1] != 3: + raise ValueError("ray_direc must have shape [N, 3]") + if ray_origin.shape[0] != ray_direc.shape[0]: + raise ValueError( + "ray_origin and ray_direc must have the same number of rays" + ) + if ray_origin.shape[0] != surface_origin.shape[0]: + raise ValueError( + "ray_origin and surface_origin must have the same number of rays" + ) + + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(self.mesh) + + rays = torch.cat([ray_origin, ray_direc], dim=-1) + rays_o3d = o3c.Tensor(rays.detach().to("cpu").numpy(), dtype=o3c.float32) + + ans = scene.cast_rays(rays_o3d) + t_hit = torch.from_numpy(ans["t_hit"].numpy()).to( + device=ray_origin.device, dtype=ray_origin.dtype + ) + hit_mask = torch.logical_and( + t_hit > self.cfg.min_length, t_hit < self.cfg.max_length + ) + hit_points = ray_origin[hit_mask] + t_hit[hit_mask, None] * ray_direc[hit_mask] + hit_origins = surface_origin[hit_mask] + hit_point_pairs = torch.cat( + [hit_points[:, None, :], hit_origins[:, None, :]], dim=1 + ) + hit_point_pairs = hit_point_pairs.to(dtype=torch.float32) + return hit_point_pairs + + @staticmethod + def _random_rotate_unit_vectors( + vectors: torch.Tensor, + max_angle: float, + degrees: bool = False, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Apply random small rotations to a batch of unit vectors [N, 3]. + + Args: + vectors: [N, 3], unit vectors + max_angle: Maximum rotation angle + degrees: If True, `max_angle` is given in degrees + eps: Numerical stability constant + + Returns: + rotated: [N, 3], rotated unit vectors + """ + assert vectors.ndim == 2 and vectors.shape[-1] == 3, "vectors must be [N, 3]" + + v = F.normalize(vectors, dim=-1) + + if degrees: + max_angle = torch.deg2rad( + torch.tensor(max_angle, dtype=v.dtype, device=v.device) + ).item() + + n = v.shape[0] + + # 1) Generate a random direction for each vector + # then project it onto the plane perpendicular to v to get the rotation axis k + rand_dir = torch.randn_like(v) + eps + proj = (rand_dir * v).sum(dim=-1, keepdim=True) * v + k = rand_dir - proj + k = F.normalize(k, dim=-1) + + # 2) Sample rotation angles in the range [eps, max_angle] + theta = ( + torch.rand(n, 1, device=v.device, dtype=v.dtype) * (max_angle - eps) + eps + ) + + # 3) Rodrigues' rotation formula + # R(v) = v*cosθ + (k×v)*sinθ + k*(k·v)*(1-cosθ) + # Since k ⟂ v, the last term is theoretically 0, but keeping the general formula is more robust + cos_t = torch.cos(theta) + sin_t = torch.sin(theta) + + kv = (k * v).sum(dim=-1, keepdim=True) + rotated = v * cos_t + torch.cross(k, v, dim=-1) * sin_t + k * kv * (1.0 - cos_t) + + return F.normalize(rotated, dim=-1) + + def visualize(self, hit_point_pairs: torch.Tensor): + if self.mesh is None: + logger.log_warning("Mesh is not initialized. Cannot visualize.") + return + + if hit_point_pairs.shape[0] == 0: + raise ValueError("No point pairs to visualize") + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + + origin_points_np = origin_points.to("cpu").numpy() + hit_points_np = hit_points.detach().to("cpu").numpy() + + n_pairs = hit_point_pairs.shape[0] + line_indices = np.stack( + [np.arange(n_pairs), np.arange(n_pairs) + n_pairs], axis=1 + ) + + mesh_legacy = self.mesh.to_legacy() + mesh_legacy.compute_vertex_normals() + mesh_legacy.paint_uniform_color([0.8, 0.8, 0.8]) + + origin_pcd = o3d.geometry.PointCloud() + origin_pcd.points = o3d.utility.Vector3dVector(origin_points_np) + origin_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.1, 0.4, 1.0]]), (n_pairs, 1)) + ) + + hit_pcd = o3d.geometry.PointCloud() + hit_pcd.points = o3d.utility.Vector3dVector(hit_points_np) + hit_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[1.0, 0.2, 0.2]]), (n_pairs, 1)) + ) + + line_set = o3d.geometry.LineSet() + mid_points = (origin_points_np + hit_points_np) / 2 + point_diff = hit_points_np - origin_points_np + draw_origin = mid_points - 0.6 * point_diff + draw_end = mid_points + 0.6 * point_diff + draw_pointpair = np.concatenate([draw_origin, draw_end], axis=0) + line_set.points = o3d.utility.Vector3dVector(draw_pointpair) + line_set.lines = o3d.utility.Vector2iVector(line_indices) + line_set.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.2, 0.9, 0.2]]), (n_pairs, 1)) + ) + + o3d.visualization.draw_geometries( + [mesh_legacy, origin_pcd, hit_pcd, line_set], + window_name="Antipodal Point Pairs", + mesh_show_back_face=True, + ) + + +if __name__ == "__main__": + mesh_path = "/media/chenjian/_abc/project/grasp_annotator/dustpan_saa.ply" + mesh = o3d.t.io.read_triangle_mesh(mesh_path) + vertices = torch.from_numpy(mesh.vertex.positions.cpu().numpy()) + faces = torch.from_numpy(mesh.triangle.indices.cpu().numpy()) + + sampler = AntipodalSampler() + hit_point_pairs = sampler.sample(vertices, faces) + sampler.visualize(hit_point_pairs) + print(f"Sampled {hit_point_pairs.shape[0]} antipodal points") diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py new file mode 100644 index 00000000..a0a138d0 --- /dev/null +++ b/examples/sim/demo/grasp_mug.py @@ -0,0 +1,257 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with a soft object, +and performs a pressing task in a simulated environment. +""" + +import argparse +import numpy as np +import time +import torch + +from dexsim.utility.path import get_resources_data_path + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + LightCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + URDFCfg, +) +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( + GraspAnnotatorCfg, + AntipodalSamplerCfg, +) + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including number of environments and rendering options. + """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + parser.add_argument("--headless", action="store_true", help="Enable headless mode") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + return parser.parse_args() + + +def initialize_simulation(args) -> SimulationManager: + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + num_envs=args.num_envs, + arena_space=2.5, + ) + sim = SimulationManager(config) + + if args.enable_rt: + light = sim.add_light( + cfg=LightCfg( + uid="main_light", + color=(0.6, 0.6, 0.6), + intensity=30.0, + init_pos=(1.0, 0, 3.0), + ) + ) + + return sim + + +def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]) -> Robot: + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf") + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="UR10", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + {"component_type": "hand", "urdf_path": gripper_urdf_path}, + ] + ), + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e3}, + damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e2}, + max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e4}, + drive_type="force", + ), + control_parts={ + "arm": ["JOINT[0-9]"], + "hand": ["FINGER[1-2]"], + }, + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=[ + [0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.12], + [0.0, 0.0, 0.0, 1.0], + ], + ) + }, + init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0], + init_pos=position, + ) + return sim.add_robot(cfg=cfg) + + +def create_mug(sim: SimulationManager): + mug_cfg = RigidObjectCfg( + uid="table", + shape=MeshCfg( + fpath=get_data_path("CoffeeCup/cup.ply"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.01, + dynamic_friction=0.97, + static_friction=0.99, + ), + max_convex_hull_num=16, + init_pos=[0.55, 0.0, 0.01], + init_rot=[0.0, 0.0, -90], + body_scale=(4, 4, 4), + ) + mug = sim.add_rigid_object(cfg=mug_cfg) + return mug + + +def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tensor): + n_envs = sim.num_envs + rest_arm_qpos = robot.get_qpos("arm") + + approach_xpos = grasp_xpos.clone() + approach_xpos[:, 2, 3] += 0.04 + + _, qpos_approach = robot.compute_ik( + pose=approach_xpos, joint_seed=rest_arm_qpos, name="arm" + ) + _, qpos_grasp = robot.compute_ik( + pose=grasp_xpos, joint_seed=qpos_approach, name="arm" + ) + hand_open_qpos = torch.tensor([0.00, 0.00], dtype=torch.float32, device=sim.device) + hand_close_qpos = torch.tensor( + [0.025, 0.025], dtype=torch.float32, device=sim.device + ) + + arm_trajectory = torch.cat( + [ + rest_arm_qpos[:, None, :], + qpos_approach[:, None, :], + qpos_grasp[:, None, :], + qpos_grasp[:, None, :], + qpos_approach[:, None, :], + rest_arm_qpos[:, None, :], + ], + dim=1, + ) + hand_trajectory = torch.cat( + [ + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + ], + dim=1, + ) + all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) + interp_trajectory = interpolate_with_distance( + trajectory=all_trajectory, interp_num=300, device=sim.device + ) + return interp_trajectory + + +if __name__ == "__main__": + args = parse_arguments() + sim = initialize_simulation(args) + robot = create_robot(sim, position=[0.0, 0.0, 0.0]) + mug = create_mug(sim) + + # get mug grasp pose + grasp_cfg = GraspAnnotatorCfg( + viser_port=11801, + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=5000, max_length=0.088, min_length=0.003 + ), + force_regenerate=True, + ) + sim.open_window() + grasp_xpos = mug.get_grasp_pose( + approach_direction=torch.tensor( + [0, 0, -1], dtype=torch.float32, device=sim.device + ), + cfg=grasp_cfg, + is_visual=True, + ) + + grab_traj = get_grasp_traj(sim, robot, grasp_xpos) + input("Press Enter to start the grab mug demo...") + n_waypoint = grab_traj.shape[1] + for i in range(n_waypoint): + robot.set_qpos(grab_traj[:, i, :]) + sim.update(step=4) + time.sleep(1e-2) + input("Press Enter to exit the simulation...") From 1e15c77b29f491fd3c4e86bb189187b5f92bbc08 Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 23 Mar 2026 19:17:02 +0800 Subject: [PATCH 2/6] update --- examples/sim/demo/grasp_mug.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py index a0a138d0..18c5ff9c 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/examples/sim/demo/grasp_mug.py @@ -236,15 +236,19 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso antipodal_sampler_cfg=AntipodalSamplerCfg( n_sample=5000, max_length=0.088, min_length=0.003 ), - force_regenerate=True, + force_regenerate=True, # force user to annotate grasp region each time ) sim.open_window() + + # 1. View grasp object in browser (e.g http://localhost:11801) + # 2. press 'Rect Select Region', select grasp region + # 3. press 'Confirm Selection' to finish grasp region selection. grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device - ), + ), # gripper approach direction in the mug local frame cfg=grasp_cfg, - is_visual=True, + is_visual=True, # visualize selected grasp pose finally ) grab_traj = get_grasp_traj(sim, robot, grasp_xpos) From baf731a8ed19e65268dfead52afc3fff93f472bf Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 23 Mar 2026 19:22:12 +0800 Subject: [PATCH 3/6] update comment --- examples/sim/demo/grasp_mug.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sim/demo/grasp_mug.py b/examples/sim/demo/grasp_mug.py index 18c5ff9c..6ff56d69 100644 --- a/examples/sim/demo/grasp_mug.py +++ b/examples/sim/demo/grasp_mug.py @@ -246,7 +246,7 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso grasp_xpos = mug.get_grasp_pose( approach_direction=torch.tensor( [0, 0, -1], dtype=torch.float32, device=sim.device - ), # gripper approach direction in the mug local frame + ), # gripper approach direction in the world frame cfg=grasp_cfg, is_visual=True, # visualize selected grasp pose finally ) From f1f043b809c59ff4215ea1fa6d5db3bdd5466281 Mon Sep 17 00:00:00 2001 From: chenjian Date: Tue, 24 Mar 2026 11:23:13 +0800 Subject: [PATCH 4/6] add viser dependence --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 60a12496..25b15290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", - "tensordict" + "tensordict", + "viser==1.0.21" ] [project.optional-dependencies] From 63ec5e6667eb0c2419a2a8c47d7d556acd5eb3f0 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 25 Mar 2026 19:14:51 +0800 Subject: [PATCH 5/6] update --- .../pg_grasp/batch_collision_checker.py | 528 ++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py new file mode 100644 index 00000000..f50a12f9 --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -0,0 +1,528 @@ +import trimesh +import numpy as np +import torch +import time +from typing import List, Tuple, Union +from dexsim.kit.meshproc import convex_decomposition_coacd +import hashlib +from dataclasses import dataclass +import os +import pickle +import open3d as o3d +from embodichain.utils import logger + + +CONVEX_CACHE_DIR = os.path.join( + os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" +) + + +@dataclass +class BatchConvexCollisionCheckerCfg: + collsion_threshold: float = 0.0 + n_query_mesh_samples: int = 4096 + debug: bool = False + + +class BatchConvexCollisionChecker: + def __init__( + self, + base_mesh_verts: torch.Tensor, + base_mesh_faces: torch.Tensor, + max_decomposition_hulls: int = 32, + ): + if not os.path.isdir(CONVEX_CACHE_DIR): + os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) + base_mesh_verts_np = base_mesh_verts.cpu().numpy() + base_mesh_faces_np = base_mesh_faces.cpu().numpy() + mesh_hash = hashlib.md5( + (base_mesh_verts_np.tobytes() + base_mesh_faces_np.tobytes()) + ).hexdigest() + + # for visualization + self.mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(base_mesh_verts_np), + triangles=o3d.utility.Vector3iVector(base_mesh_faces_np), + ) + self.mesh.compute_vertex_normals() + self.cache_path = os.path.join( + CONVEX_CACHE_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" + ) + + if not os.path.isfile(self.cache_path): + # generate convex hulls and extract plane equations, then cache to disk + self.plane_equations = BatchConvexCollisionChecker._compute_plane_equations( + base_mesh_verts_np, base_mesh_faces_np, max_decomposition_hulls + ) + pickle.dump(self.plane_equations, open(self.cache_path, "wb")) + else: + # load precomputed plane equations from cache + self.plane_equations = pickle.load(open(self.cache_path, "rb")) + + def query( + self, + query_mesh_verts: torch.Tensor, + query_mesh_faces: torch.Tensor, + poses: torch.Tensor, + cfg: BatchConvexCollisionCheckerCfg = BatchConvexCollisionCheckerCfg(), + ) -> Tuple[torch.Tensor, torch.Tensor]: + query_mesh = trimesh.Trimesh( + vertices=query_mesh_verts.to("cpu").numpy(), + faces=query_mesh_faces.to("cpu").numpy(), + ) + n_query = cfg.n_query_mesh_samples + n_batch = poses.shape[0] + query_points_np = query_mesh.sample(n_query).astype(np.float32) + query_points = torch.tensor( + query_points_np, device=poses.device + ) # [n_query, 3] + penetration_result = torch.zeros(size=(n_batch, n_query), device=poses.device) + penetration_result.fill_(-float("inf")) + collision_result = torch.zeros( + size=(n_batch, n_query), dtype=torch.bool, device=poses.device + ) + collision_result.fill_(False) + for normals, offsets in self.plane_equations: + normals_torch = torch.tensor(normals, device=poses.device) + offsets_torch = torch.tensor(offsets, device=poses.device) + penetration, collides = check_collision_single_hull( + normals_torch, + offsets_torch, + transform_points_batch(query_points, poses), + cfg.collsion_threshold, + ) + penetration_result = torch.max(penetration_result, penetration) + collision_result = torch.logical_or(collision_result, collides) + is_colliding = collision_result.any(dim=-1) # [B] + max_penetration = penetration_result.max(dim=-1)[0] # [B] + + if cfg.debug: + # visualize result + query_points_o3d = o3d.geometry.PointCloud() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_o3d.transform(poses[-1].to("cpu").numpy()) + query_points_color = np.zeros_like(query_points_np) + query_points_color[collision_result[-1].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~collision_result[-1].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d], mesh_show_back_face=True + ) + return is_colliding, max_penetration + + @staticmethod + def _compute_plane_equations( + vertices: np.ndarray, faces: np.ndarray, max_decomposition_hulls: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convex decomposition and extract plane equations given mesh vertices and triangles. + Each convex hull is represented by its outward-facing face normals and offsets. + No padding is applied; each hull can have a different number of faces. + + Args: + vertices: [N, 3] vertex positions of the input mesh. + faces: [M, 3] triangle indices of the input mesh. + max_decomposition_hulls: maximum number of convex hulls to decompose into. + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + mesh = o3d.t.geometry.TriangleMesh() + mesh.vertex.positions = o3d.core.Tensor(vertices, dtype=o3d.core.Dtype.Float32) + mesh.triangle.indices = o3d.core.Tensor(faces, dtype=o3d.core.Dtype.Int32) + is_success, out_mesh_list = convex_decomposition_coacd( + mesh, max_convex_hull_num=max_decomposition_hulls + ) + convex_vert_face_list = [] + for out_mesh in out_mesh_list: + verts = out_mesh.vertex.positions.numpy() + faces = out_mesh.triangle.indices.numpy() + convex_vert_face_list.append((verts, faces)) + return extract_plane_equations(convex_vert_face_list) + + +def extract_plane_equations( + convex_meshes: List[Tuple[np.ndarray, np.ndarray]], +) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Extract plane equations from a list of convex hull meshes. + Each convex hull is represented by its outward-facing face normals and offsets. + No padding is applied; each hull can have a different number of faces. + + Args: + convex_meshes: List of convex hull data. + - tuple of (vertices [N,3], faces [M,3]) + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + convex_plane_data = [] + + for i, convex_mesh_data in enumerate(convex_meshes): + vertices, faces = convex_mesh_data + hull = trimesh.Trimesh( + vertices=vertices, + faces=faces, + ) + # Outward-facing face normals [Ki, 3] + face_normals = hull.face_normals + # One vertex per face to compute offset [Ki, 3] + face_origins = hull.triangles[:, 0, :] + # Plane equation: n · x + d = 0 => d = -(n · p) + offsets_i = -np.sum(face_normals * face_origins, axis=1) + + convex_plane_data.append( + (face_normals.astype(np.float32), offsets_i.astype(np.float32)) + ) + return convex_plane_data + + +def sample_surface_points(mesh_path: str, num_points: int = 4096) -> np.ndarray: + """ + Sample surface points from a mesh file. + + Args: + mesh_path: Path to the mesh file. + num_points: Number of surface points to sample. + + Returns: + points: [P, 3] numpy array of sampled surface points. + """ + mesh = trimesh.load(mesh_path, force="mesh") + points = mesh.sample(num_points) + return points.astype(np.float32) + + +def check_collision_single_hull( + normals: torch.Tensor, # [K, 3] + offsets: torch.Tensor, # [K] + transformed_points: torch.Tensor, # [B, P, 3] + threshold: float = 0.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Check collision between a batch of transformed point clouds and a single convex hull. + + A point p is inside the convex hull iff: + max_k (n_k · p + d_k) <= 0 + + Penetration depth for a point is defined as: + penetration = -(max_k (n_k · p + d_k)) + Positive penetration means the point is inside the hull. + + Args: + normals: [K, 3] outward face normals of the convex hull. + offsets: [K] plane offsets of the convex hull. + transformed_points: [B, P, 3] point cloud already transformed by batch poses. + threshold: collision threshold. A point is considered colliding if + its signed distance to the hull interior is <= threshold. + + Returns: + penetration: [B, P] penetration depth for each point. + Positive values indicate the point is inside the hull. + collides: [B, P] boolean mask, True if the point collides with this hull. + """ + # signed_dist: [B, P, K] = einsum([B,P,3], [K,3]) + [K] + signed_dist = torch.einsum("bpj, kj -> bpk", transformed_points, normals) + offsets + + # For each point, the maximum signed distance across all planes + # If max <= 0, the point satisfies all half-plane constraints => inside the hull + max_over_planes, _ = signed_dist.max(dim=-1) # [B, P] + + # Penetration depth: negate so that positive = inside + penetration = -max_over_planes # [B, P] + + # A point collides if its penetration exceeds the threshold + collides = penetration > threshold # [B, P] + + return penetration, collides + + +def merge_collision_results( + hull_results: List[Tuple[torch.Tensor, torch.Tensor]], device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge collision detection results from multiple convex hulls. + + A pose is considered colliding if ANY point penetrates ANY convex hull. + The reported penetration depth is the maximum across all points and all hulls. + + Args: + hull_results: List of (penetration [B, P], collides [B, P]) tuples, + one per convex hull, as returned by check_collision_single_hull. + device: torch device. + + Returns: + overall_collisions: [B] boolean, True if the pose collides with any hull. + overall_max_penetrations: [B] float, maximum penetration depth per pose. + """ + if not hull_results: + raise ValueError("hull_results is empty, nothing to merge.") + + B = hull_results[0][0].shape[0] + + overall_collisions = torch.zeros(B, dtype=torch.bool, device=device) + overall_max_penetrations = torch.full( + (B,), -float("inf"), dtype=torch.float32, device=device + ) + + for penetration, collides in hull_results: + # Update collision flag: OR across hulls + # A pose collides if any point collides with this hull + overall_collisions |= collides.any(dim=-1) # [B] + + # Update max penetration: take the per-pose maximum across all points for this hull, + # then compare with the running maximum + hull_max_pen = penetration.max(dim=-1)[0] # [B] + overall_max_penetrations = torch.max(overall_max_penetrations, hull_max_pen) + + return overall_collisions, overall_max_penetrations + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +def batch_collision_detection( + convex_planes: List[Tuple[torch.Tensor, torch.Tensor]], + points_B: torch.Tensor, # [P, 3] + poses: torch.Tensor, # [B, 4, 4] + threshold: float = 0.0, + chunk_size: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Full batch collision detection pipeline. + + Iterates over convex hulls sequentially and over pose chunks to control + GPU memory usage. Within each (hull, chunk) pair, the computation is + fully parallelized over B_chunk * P * K. + + Args: + convex_planes: List of (normals [Ki, 3], offsets [Ki]) tensors on device, + one per convex hull. Ki can differ across hulls. + points_B: [P, 3] sampled surface points of mesh B, on device. + poses: [B, 4, 4] batch of relative poses, on device. + threshold: collision threshold (positive = safety margin). + chunk_size: number of poses to process per chunk. + + Returns: + overall_collisions: [B] bool + overall_max_penetrations: [B] float + """ + device = points_B.device + B = poses.shape[0] + + all_hull_results: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Sequential iteration over convex hulls to save memory + for hull_idx, (normals, offsets) in enumerate(convex_planes): + hull_pen_chunks = [] + hull_col_chunks = [] + + # Chunk over batch dimension to control peak memory + for start in range(0, B, chunk_size): + end = min(start + chunk_size, B) + poses_chunk = poses[start:end] + + # Transform points for this chunk of poses + transformed_chunk = transform_points_batch( + points_B, poses_chunk + ) # [B_chunk, P, 3] + + # Check collision against this single hull + penetration, collides = check_collision_single_hull( + normals, offsets, transformed_chunk, threshold + ) + + hull_pen_chunks.append(penetration) + hull_col_chunks.append(collides) + + # Concatenate chunks for this hull + hull_penetration = torch.cat(hull_pen_chunks, dim=0) # [B, P] + hull_collides = torch.cat(hull_col_chunks, dim=0) # [B, P] + + all_hull_results.append((hull_penetration, hull_collides)) + + # Merge results across all hulls + overall_collisions, overall_max_penetrations = merge_collision_results( + all_hull_results, device + ) + + return overall_collisions, overall_max_penetrations + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # --- Create dummy mesh files for testing --- + box1 = trimesh.primitives.Box(extents=[0.5, 0.5, 0.5]) + box2 = trimesh.primitives.Box( + extents=[0.4, 0.4, 0.4], + transform=trimesh.transformations.translation_matrix([1, 0, 0]), + ) + box1.export("mesh_hull_1.obj") + box2.export("mesh_hull_2.obj") + + sphere_mesh = trimesh.primitives.Sphere(radius=0.3) + sphere_mesh.export("mesh_B.obj") + print("Created dummy mesh files.\n") + + # ==================== Preprocessing ==================== + + # Load externally decomposed convex hull meshes + convex_mesh_files = ["mesh_hull_1.obj", "mesh_hull_2.obj"] + convex_meshes = load_convex_meshes(convex_mesh_files) + if not convex_meshes: + print("No convex hulls loaded. Exiting.") + return + + # Extract plane equations (list of variable-length arrays) + convex_plane_data_np = extract_plane_equations(convex_meshes) + + # Convert to torch tensors on device + convex_planes_torch: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for normals_np, offsets_np in convex_plane_data_np: + convex_planes_torch.append( + ( + torch.tensor(normals_np, device=device), # [Ki, 3] + torch.tensor(offsets_np, device=device), # [Ki] + ) + ) + + # Sample surface points from mesh B + points_np = sample_surface_points("mesh_B.obj", num_points=2048) + points_B = torch.tensor(points_np, device=device) # [P, 3] + + # ==================== Generate test poses ==================== + B = 10000 + chunk_size = 1024 + + # Random rotation matrices via SVD + random_mat = torch.randn(B, 3, 3, device=device) + U, _, Vt = torch.linalg.svd(random_mat) + R = U @ Vt + # Fix reflections to ensure proper rotations (det = +1) + det = torch.det(R) + R[det < 0] *= -1 + + poses = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) + poses[:, :3, :3] = R + poses[:, :3, 3] = torch.randn(B, 3, device=device) * 0.5 + + # ==================== Run collision detection ==================== + print( + f"\nRunning collision detection: {B} poses, {points_B.shape[0]} points, " + f"{len(convex_planes_torch)} hulls..." + ) + + torch.cuda.synchronize() if device.type == "cuda" else None + start_time = time.time() + + with torch.no_grad(): + collisions, penetration_depths = batch_collision_detection( + convex_planes_torch, points_B, poses, threshold=0.001, chunk_size=chunk_size + ) + + torch.cuda.synchronize() if device.type == "cuda" else None + elapsed = time.time() - start_time + + # ==================== Report results ==================== + print(f"\n{'='*40}") + print(f"Total poses: {B}") + print(f"Collisions: {collisions.sum().item()} / {B}") + if collisions.any(): + print(f"Max penetration: {penetration_depths[collisions].max().item():.6f}") + else: + print(f"Max penetration: N/A (no collisions)") + print(f"Total time: {elapsed:.3f}s") + print(f"Per pose: {elapsed / B * 1e6:.2f} μs") + print(f"{'='*40}") + + # ==================== Benchmark ==================== + num_iters = 50 + torch.cuda.synchronize() if device.type == "cuda" else None + t0 = time.time() + for _ in range(num_iters): + with torch.no_grad(): + batch_collision_detection( + convex_planes_torch, + points_B, + poses, + threshold=0.001, + chunk_size=chunk_size, + ) + torch.cuda.synchronize() if device.type == "cuda" else None + t1 = time.time() + + avg_ms = (t1 - t0) / num_iters * 1000 + print( + f"\nBenchmark ({num_iters} iters): {avg_ms:.2f} ms/iter, " + f"{avg_ms / B * 1000:.2f} μs/pose" + ) + + +if __name__ == "__main__": + from embodichain.data import get_data_path + + bottle_a_path = get_data_path("ScannedBottle/moliwulong_processed.ply") + bottle_b_path = get_data_path("ScannedBottle/yibao_processed.ply") + + bottle_a_mesh = trimesh.load(bottle_a_path) + bottle_b_mesh = trimesh.load(bottle_b_path) + bottle_a_verts = torch.tensor(bottle_a_mesh.vertices, dtype=torch.float32) + bottle_a_faces = torch.tensor(bottle_a_mesh.faces, dtype=torch.int64) + bottle_b_verts = torch.tensor(bottle_b_mesh.vertices, dtype=torch.float32) + bottle_b_faces = torch.tensor(bottle_b_mesh.faces, dtype=torch.int64) + + collision_checker = BatchConvexCollisionChecker(bottle_a_verts, bottle_a_faces) + poses = torch.tensor( + [ + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1.0], + [0, 0, 0, 1], + ], + [ + [1, 0, 0, 0.05], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + ] + ) + check_cfg = BatchConvexCollisionCheckerCfg( + debug=False, + n_query_mesh_samples=32768, + collsion_threshold=-0.003, + ) + collisions, penetrations = collision_checker.query( + bottle_b_verts, bottle_b_faces, poses, cfg=check_cfg + ) + print("Collisions:", collisions) + print("Penetrations:", penetrations) From 73781d887eb82142224cc35daedc25ad505e964b Mon Sep 17 00:00:00 2001 From: chenjian Date: Thu, 26 Mar 2026 17:40:42 +0800 Subject: [PATCH 6/6] update --- .../graspkit/pg_grasp/antipodal_annotator.py | 16 ++ .../graspkit/pg_grasp/antipodal_sampler.py | 16 ++ .../pg_grasp/batch_collision_checker.py | 179 ------------------ .../pg_grasp/gripper_collision_checker.py | 0 4 files changed, 32 insertions(+), 179 deletions(-) create mode 100644 embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py index 4852879e..5ee3eda4 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -1,3 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + import os import argparse import open3d as o3d diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py index 1eb3ec61..09e4858e 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -1,3 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + import torch import torch.nn.functional as F import numpy as np diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py index f50a12f9..cf18b76e 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -307,185 +307,6 @@ def transform_points_batch( return transformed -def batch_collision_detection( - convex_planes: List[Tuple[torch.Tensor, torch.Tensor]], - points_B: torch.Tensor, # [P, 3] - poses: torch.Tensor, # [B, 4, 4] - threshold: float = 0.0, - chunk_size: int = 512, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Full batch collision detection pipeline. - - Iterates over convex hulls sequentially and over pose chunks to control - GPU memory usage. Within each (hull, chunk) pair, the computation is - fully parallelized over B_chunk * P * K. - - Args: - convex_planes: List of (normals [Ki, 3], offsets [Ki]) tensors on device, - one per convex hull. Ki can differ across hulls. - points_B: [P, 3] sampled surface points of mesh B, on device. - poses: [B, 4, 4] batch of relative poses, on device. - threshold: collision threshold (positive = safety margin). - chunk_size: number of poses to process per chunk. - - Returns: - overall_collisions: [B] bool - overall_max_penetrations: [B] float - """ - device = points_B.device - B = poses.shape[0] - - all_hull_results: List[Tuple[torch.Tensor, torch.Tensor]] = [] - - # Sequential iteration over convex hulls to save memory - for hull_idx, (normals, offsets) in enumerate(convex_planes): - hull_pen_chunks = [] - hull_col_chunks = [] - - # Chunk over batch dimension to control peak memory - for start in range(0, B, chunk_size): - end = min(start + chunk_size, B) - poses_chunk = poses[start:end] - - # Transform points for this chunk of poses - transformed_chunk = transform_points_batch( - points_B, poses_chunk - ) # [B_chunk, P, 3] - - # Check collision against this single hull - penetration, collides = check_collision_single_hull( - normals, offsets, transformed_chunk, threshold - ) - - hull_pen_chunks.append(penetration) - hull_col_chunks.append(collides) - - # Concatenate chunks for this hull - hull_penetration = torch.cat(hull_pen_chunks, dim=0) # [B, P] - hull_collides = torch.cat(hull_col_chunks, dim=0) # [B, P] - - all_hull_results.append((hull_penetration, hull_collides)) - - # Merge results across all hulls - overall_collisions, overall_max_penetrations = merge_collision_results( - all_hull_results, device - ) - - return overall_collisions, overall_max_penetrations - - -def main(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # --- Create dummy mesh files for testing --- - box1 = trimesh.primitives.Box(extents=[0.5, 0.5, 0.5]) - box2 = trimesh.primitives.Box( - extents=[0.4, 0.4, 0.4], - transform=trimesh.transformations.translation_matrix([1, 0, 0]), - ) - box1.export("mesh_hull_1.obj") - box2.export("mesh_hull_2.obj") - - sphere_mesh = trimesh.primitives.Sphere(radius=0.3) - sphere_mesh.export("mesh_B.obj") - print("Created dummy mesh files.\n") - - # ==================== Preprocessing ==================== - - # Load externally decomposed convex hull meshes - convex_mesh_files = ["mesh_hull_1.obj", "mesh_hull_2.obj"] - convex_meshes = load_convex_meshes(convex_mesh_files) - if not convex_meshes: - print("No convex hulls loaded. Exiting.") - return - - # Extract plane equations (list of variable-length arrays) - convex_plane_data_np = extract_plane_equations(convex_meshes) - - # Convert to torch tensors on device - convex_planes_torch: List[Tuple[torch.Tensor, torch.Tensor]] = [] - for normals_np, offsets_np in convex_plane_data_np: - convex_planes_torch.append( - ( - torch.tensor(normals_np, device=device), # [Ki, 3] - torch.tensor(offsets_np, device=device), # [Ki] - ) - ) - - # Sample surface points from mesh B - points_np = sample_surface_points("mesh_B.obj", num_points=2048) - points_B = torch.tensor(points_np, device=device) # [P, 3] - - # ==================== Generate test poses ==================== - B = 10000 - chunk_size = 1024 - - # Random rotation matrices via SVD - random_mat = torch.randn(B, 3, 3, device=device) - U, _, Vt = torch.linalg.svd(random_mat) - R = U @ Vt - # Fix reflections to ensure proper rotations (det = +1) - det = torch.det(R) - R[det < 0] *= -1 - - poses = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) - poses[:, :3, :3] = R - poses[:, :3, 3] = torch.randn(B, 3, device=device) * 0.5 - - # ==================== Run collision detection ==================== - print( - f"\nRunning collision detection: {B} poses, {points_B.shape[0]} points, " - f"{len(convex_planes_torch)} hulls..." - ) - - torch.cuda.synchronize() if device.type == "cuda" else None - start_time = time.time() - - with torch.no_grad(): - collisions, penetration_depths = batch_collision_detection( - convex_planes_torch, points_B, poses, threshold=0.001, chunk_size=chunk_size - ) - - torch.cuda.synchronize() if device.type == "cuda" else None - elapsed = time.time() - start_time - - # ==================== Report results ==================== - print(f"\n{'='*40}") - print(f"Total poses: {B}") - print(f"Collisions: {collisions.sum().item()} / {B}") - if collisions.any(): - print(f"Max penetration: {penetration_depths[collisions].max().item():.6f}") - else: - print(f"Max penetration: N/A (no collisions)") - print(f"Total time: {elapsed:.3f}s") - print(f"Per pose: {elapsed / B * 1e6:.2f} μs") - print(f"{'='*40}") - - # ==================== Benchmark ==================== - num_iters = 50 - torch.cuda.synchronize() if device.type == "cuda" else None - t0 = time.time() - for _ in range(num_iters): - with torch.no_grad(): - batch_collision_detection( - convex_planes_torch, - points_B, - poses, - threshold=0.001, - chunk_size=chunk_size, - ) - torch.cuda.synchronize() if device.type == "cuda" else None - t1 = time.time() - - avg_ms = (t1 - t0) / num_iters * 1000 - print( - f"\nBenchmark ({num_iters} iters): {avg_ms:.2f} ms/iter, " - f"{avg_ms / B * 1000:.2f} μs/pose" - ) - - if __name__ == "__main__": from embodichain.data import get_data_path diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py new file mode 100644 index 00000000..e69de29b