22from dataclasses import dataclass
33from typing import Dict , List
44
5- import numpy as np
65import torch
6+ import numpy as np
7+ import torch .nn .functional as F
8+
79
8- from ..util .math .queries import pts_to_nearest_tube
10+ from ..util .math .queries import pts_to_nearest_tube_gpu
911from ..util .mesh .geometries import (
1012 o3d_merge_clouds ,
1113 o3d_merge_linesets ,
@@ -72,13 +74,12 @@ def repair(self):
7274 parent_branch = self .branches [branch .parent_id ]
7375 tubes = parent_branch .to_tubes ()
7476
75- v , idx , _ = pts_to_nearest_tube (branch .xyz [0 ].reshape (- 1 , 3 ), tubes )
77+ v , idx , _ = pts_to_nearest_tube_gpu (branch .xyz [0 ].reshape (- 1 , 3 ), tubes )
7678
77- # connection point...
78- connection_pt = branch .xyz [0 ].reshape (- 1 , 3 ) + v [0 ]
79+ connection_pt = branch .xyz [0 ].reshape (- 1 , 3 ).cpu () + v [0 ].cpu ()
7980
80- branch .xyz = np . insert ( branch . xyz , 0 , connection_pt , axis = 0 )
81- branch .radii = np . insert ( branch .radii , 0 , branch .radii [ 0 ], axis = 0 )
81+ branch .xyz = torch . cat (( connection_pt , branch . xyz ) )
82+ branch .radii = torch . cat (( branch .radii [[ 0 ]] , branch .radii ) )
8283
8384 def prune (self , min_radius , min_length , root_id = 0 ):
8485 """If a branch doesn't meet the initial radius threshold or length threshold we want to remove it and all
@@ -101,19 +102,17 @@ def prune(self, min_radius, min_length, root_id=0):
101102
102103 self .branches = branches_to_keep
103104
104- def smooth (self , kernel_size = 10 ):
105+ def smooth (self , kernel_size = 5 ):
105106 """
106107 Smooths the skeleton radius.
107108 """
108- kernel = np .ones (kernel_size ) / kernel_size
109-
110109 for branch in self .branches .values ():
111- if branch .radii .shape [0 ] >= kernel_size :
112- branch .radii = np . convolve (
113- branch .radii .reshape (- 1 ),
114- kernel ,
115- mode = "same" ,
116- )
110+ if branch .radii .shape [0 ] > kernel_size :
111+ branch .radii = F . conv1d (
112+ branch .radii .reshape (- 1 , 1 , 1 ),
113+ torch . ones ( 1 , 1 , kernel_size ) ,
114+ padding = "same" ,
115+ ). reshape ( - 1 )
117116
118117
119118@dataclass
0 commit comments