Skip to content

Commit 8e2fe91

Browse files
committed
Speed up skeleton repair
1 parent 4fa9995 commit 8e2fe91

File tree

7 files changed

+55
-43
lines changed

7 files changed

+55
-43
lines changed

smart_tree/conf/tree-dataset.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ model_inference:
6666
batch_size : 4
6767

6868
skeletonizer:
69-
K: 8
69+
K: 16
7070
min_connection_length: 0.02
7171
minimum_graph_vertices: 32
72-
max_number_components: 8
72+
max_number_components: 256
7373
voxel_downsample: False
7474
edge_non_linear: None
7575

@@ -107,7 +107,7 @@ pipeline:
107107

108108

109109
repair_skeletons : True
110-
smooth_skeletons : True
110+
smooth_skeletons : False
111111
prune_skeletons : True
112-
min_skeleton_radius : 0.1 #0.005
113-
min_skeleton_length : 0.02
112+
min_skeleton_radius : 0.001
113+
min_skeleton_length : 0.002

smart_tree/data_types/branch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def to_o3d_lineset(self, colour=(0, 0, 0)):
3333
return o3d_path(self.xyz, colour)
3434

3535
def to_o3d_tube(self):
36-
return o3d_tube_mesh(self.xyz, self.radii, self.colour)
36+
return o3d_tube_mesh(self.xyz.numpy(), self.radii.numpy(), self.colour)
3737

3838
def to_tubes(self, colour=(1, 0, 0)) -> List[Tube]:
3939
a_, b_, r1_, r2_ = (

smart_tree/data_types/tree.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from dataclasses import dataclass
33
from typing import Dict, List
44

5-
import numpy as np
65
import torch
6+
import numpy as np
7+
import torch.nn.functional as F
8+
79

8-
from ..util.math.queries import pts_to_nearest_tube
10+
from ..util.math.queries import pts_to_nearest_tube_gpu
911
from ..util.mesh.geometries import (
1012
o3d_merge_clouds,
1113
o3d_merge_linesets,
@@ -72,13 +74,12 @@ def repair(self):
7274
parent_branch = self.branches[branch.parent_id]
7375
tubes = parent_branch.to_tubes()
7476

75-
v, idx, _ = pts_to_nearest_tube(branch.xyz[0].reshape(-1, 3), tubes)
77+
v, idx, _ = pts_to_nearest_tube_gpu(branch.xyz[0].reshape(-1, 3), tubes)
7678

77-
# connection point...
78-
connection_pt = branch.xyz[0].reshape(-1, 3) + v[0]
79+
connection_pt = branch.xyz[0].reshape(-1, 3).cpu() + v[0].cpu()
7980

80-
branch.xyz = np.insert(branch.xyz, 0, connection_pt, axis=0)
81-
branch.radii = np.insert(branch.radii, 0, branch.radii[0], axis=0)
81+
branch.xyz = torch.cat((connection_pt, branch.xyz))
82+
branch.radii = torch.cat((branch.radii[[0]], branch.radii))
8283

8384
def prune(self, min_radius, min_length, root_id=0):
8485
"""If a branch doesn't meet the initial radius threshold or length threshold we want to remove it and all
@@ -101,19 +102,17 @@ def prune(self, min_radius, min_length, root_id=0):
101102

102103
self.branches = branches_to_keep
103104

104-
def smooth(self, kernel_size=10):
105+
def smooth(self, kernel_size=5):
105106
"""
106107
Smooths the skeleton radius.
107108
"""
108-
kernel = np.ones(kernel_size) / kernel_size
109-
110109
for branch in self.branches.values():
111-
if branch.radii.shape[0] >= kernel_size:
112-
branch.radii = np.convolve(
113-
branch.radii.reshape(-1),
114-
kernel,
115-
mode="same",
116-
)
110+
if branch.radii.shape[0] > kernel_size:
111+
branch.radii = F.conv1d(
112+
branch.radii.reshape(-1, 1, 1),
113+
torch.ones(1, 1, kernel_size),
114+
padding="same",
115+
).reshape(-1)
117116

118117

119118
@dataclass

smart_tree/data_types/tube.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,24 @@ def to_numpy(self):
2727

2828
@dataclass
2929
class CollatedTube:
30-
a: np.array # Nx3
31-
b: np.array # Nx3
32-
r1: np.array # N
33-
r2: np.array # N
30+
a: torch.tensor # Nx3
31+
b: torch.tensor # Nx3
32+
r1: torch.tensor # N
33+
r2: torch.tensor # N
34+
35+
def to_gpu(self, device=torch.device("cuda")):
36+
self.a = self.a.to(device)
37+
self.b = self.b.to(device)
38+
self.r1 = self.r1.to(device)
39+
self.r2 = self.r2.to(device)
3440

3541

3642
def collate_tubes(tubes: List[Tube]) -> CollatedTube:
37-
a = np.concatenate([tube.a for tube in tubes]).reshape(-1, 3)
38-
b = np.concatenate([tube.b for tube in tubes]).reshape(-1, 3)
43+
a = torch.cat([tube.a for tube in tubes]).reshape(-1, 3)
44+
b = torch.cat([tube.b for tube in tubes]).reshape(-1, 3)
3945

40-
r1 = np.asarray([tube.r1 for tube in tubes]).reshape(1, -1)
41-
r2 = np.asarray([tube.r2 for tube in tubes]).reshape(1, -1)
46+
r1 = torch.cat([tube.r1 for tube in tubes]).reshape(1, -1)
47+
r2 = torch.cat([tube.r2 for tube in tubes]).reshape(1, -1)
4248

4349
return CollatedTube(a, b, r1, r2)
4450

smart_tree/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def post_process(self, skeleton: DisjointTreeSkeleton):
111111
skeleton.repair()
112112

113113
if self.smooth_skeletons:
114-
skeleton.smooth(kernel_size=30)
114+
skeleton.smooth(kernel_size=27)
115115

116116
@staticmethod
117117
def from_cfg(inferer, skeletonizer, cfg):

smart_tree/skeleton/path.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,27 +112,34 @@ def sample_tree(
112112
termination_pts,
113113
)
114114

115-
""" Gets the points around that path """
115+
""" Gets the points around that path (and which path indexs they are close to) """
116116
idx_points, idx_path = select_path_points(
117117
medial_pts,
118118
medial_pts[path_vertices_idx],
119119
medial_radii[path_vertices_idx],
120120
)
121121

122+
""" Mark this points as allocated and as termination points """
122123
distances[idx_points] = -1
123124
distances[path_vertices_idx] = -1
124-
125125
termination_pts = torch.unique(
126-
torch.cat((termination_pts, idx_points, path_vertices_idx))
126+
torch.cat(
127+
(
128+
termination_pts,
129+
idx_points,
130+
path_vertices_idx,
131+
)
132+
)
127133
)
128134

135+
""" If the path has at least two points, save it as a branch """
129136
if len(path_vertices_idx) < 2:
130137
continue
131138

132139
branches[branch_id] = BranchSkeleton(
133140
branch_id,
134-
xyz=medial_pts[path_vertices_idx].cpu().numpy(),
135-
radii=medial_radii[path_vertices_idx].cpu().numpy(),
141+
xyz=medial_pts[path_vertices_idx].cpu(),
142+
radii=medial_radii[path_vertices_idx].cpu(),
136143
parent_id=int(branch_ids[termination_idx].item()),
137144
child_id=-1,
138145
)

smart_tree/util/math/queries.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,14 @@ def projection_to_distance_matrix_gpu(projections, pts): # N x M x 3
106106

107107

108108
def pts_to_nearest_tube_gpu(
109-
pts: np.array, tubes: List[Tube], device=torch.device("cuda")
109+
pts: torch.tensor, tubes: List[Tube], device=torch.device("cuda")
110110
):
111111
"""Vectors from pt to the nearest tube"""
112112

113-
collated_tube = collate_tubes(tubes)
113+
collated_tube_gpu = collate_tubes(tubes)
114+
collated_tube_gpu.to_gpu()
114115

115-
collated_tube_gpu = collated_tube_to_gpu(collated_tube)
116-
pts = torch.from_numpy(pts).float().to(device)
116+
pts = pts.float().to(device)
117117

118118
projections, t = points_to_collated_tube_projections_gpu(
119119
pts, collated_tube_gpu, device=torch.device("cuda")
@@ -128,9 +128,9 @@ def pts_to_nearest_tube_gpu(
128128
assert idx.shape[0] == pts.shape[0]
129129

130130
return (
131-
to_numpy(projections[torch.arange(pts.shape[0]), idx] - pts),
132-
to_numpy(idx),
133-
to_numpy(r[torch.arange(pts.shape[0]), idx]),
131+
projections[torch.arange(pts.shape[0]), idx] - pts,
132+
idx,
133+
r[torch.arange(pts.shape[0]), idx],
134134
)
135135

136136

0 commit comments

Comments
 (0)