Skip to content

Commit 525869d

Browse files
committed
Minor cleanup , slight path optimization
1 parent 35eb93e commit 525869d

File tree

8 files changed

+45
-171
lines changed

8 files changed

+45
-171
lines changed

smart_tree/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def main(cfg: DictConfig):
2929
pipeline.process_cloud(Path(f"{cfg.directory}/{p}"))
3030

3131
else:
32-
print("Please Supply a path or Directory")
32+
print("Please supply a path or directory to point clouds.")
3333

3434

3535
if __name__ == "__main__":

smart_tree/conf/tree-dataset.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,25 @@ skeletonizer:
6969
K: 16
7070
min_connection_length: 0.02
7171
minimum_graph_vertices: 32
72-
max_number_components: 100
72+
max_number_components: 8
7373
voxel_downsample: False
7474
edge_non_linear: None
7575

7676

7777
pipeline:
7878

7979
preprocessing:
80-
# Scale:
81-
# _target_: smart_tree.dataset.augmentations.Scale
82-
# min_scale: 0.1
83-
# max_scale: 0.1
80+
# Scale:
81+
# _target_: smart_tree.dataset.augmentations.Scale
82+
# min_scale: 0.6
83+
# max_scale: 0.6
8484
# Scale:
8585
# _target_: smart_tree.dataset.augmentations.Scale
8686
# min_scale: 1
8787
# max_scale: 1
88-
VoxelDownsample:
89-
_target_: smart_tree.dataset.augmentations.VoxelDownsample
90-
voxel_size : 0.01
88+
# VoxelDownsample:
89+
# _target_: smart_tree.dataset.augmentations.VoxelDownsample
90+
# voxel_size : 0.01
9191
#FixedRotate:
9292
# _target_: smart_tree.dataset.augmentations.FixedRotate
9393
# xyz: [0, 0, 90]
@@ -108,6 +108,7 @@ pipeline:
108108

109109
repair_skeletons : True
110110
smooth_skeletons : True
111-
prune_skeletons : False # need to fix as some skeletons don't know their start I guess?
111+
kernel_size: 10
112+
prune_skeletons : False
112113
min_skeleton_radius : 0.05 #0.005
113114
min_skeleton_length : 0.02

smart_tree/data_types/cloud.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def to_labelled_cld(self, radii, direction, class_l) -> LabelledCloud:
2727
return LabelledCloud(self.xyz, self.rgb, radii * direction, class_l)
2828

2929
def to_o3d_cld(self):
30-
return o3d_cloud(self.xyz, colours=self.rgb)
30+
cpu_cld = self.to_device("cpu")
31+
return o3d_cloud(cpu_cld.xyz, colours=cpu_cld.rgb)
3132

3233
def filter(self, mask):
3334
return Cloud(self.xyz[mask], self.rgb[mask])
@@ -50,8 +51,7 @@ def cat(self):
5051
)
5152

5253
def view(self):
53-
cpu_cld = self.to_device("cpu")
54-
o3d_viewer([cpu_cld.to_o3d_cld()])
54+
o3d_viewer([self.to_o3d_cld()])
5555

5656
def voxel_down_sample(self, voxel_size):
5757
idx = voxel_downsample(self.xyz, voxel_size)
@@ -193,8 +193,8 @@ def medial_pts(self):
193193
@staticmethod
194194
def from_numpy(xyz, rgb, vector, class_l):
195195
return LabelledCloud(
196-
torch.from_numpy(xyz).float(), # float64 -> these data types are stupid...
197-
torch.from_numpy(rgb).float(), # float64
198-
torch.from_numpy(vector).float(), # float32
199-
torch.from_numpy(class_l).float(), # int64
196+
torch.from_numpy(xyz), # float64 -> these data types are stupid...
197+
torch.from_numpy(rgb), # float64
198+
torch.from_numpy(vector), # float32
199+
torch.from_numpy(class_l), # int64
200200
)

smart_tree/data_types/tree.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,10 @@ def smooth(self, kernel_size=10):
109109
for branch in self.branches.values():
110110
if branch.radii.shape[0] >= kernel_size:
111111
branch.radii = np.convolve(
112-
branch.radii.ravel(), kernel, mode="same"
113-
).reshape(-1, 1)
112+
branch.radii.reshape(-1),
113+
kernel,
114+
mode="same",
115+
)
114116

115117

116118
@dataclass
@@ -125,9 +127,9 @@ def repair(self):
125127
for skeleton in self.skeletons:
126128
skeleton.repair()
127129

128-
def smooth(self):
130+
def smooth(self, kernel_size=10):
129131
for skeleton in self.skeletons:
130-
skeleton.smooth()
132+
skeleton.smooth(kernel_size=kernel_size)
131133

132134
def to_o3d_lineset(self):
133135
return o3d_merge_linesets(

smart_tree/pipeline.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import numpy as np
44
import torch
5+
import time
56

7+
from copy import deepcopy
68
from .data_types.cloud import LabelledCloud, Cloud
79
from .data_types.tree import TreeSkeleton, DisjointTreeSkeleton
810
from hydra.utils import instantiate
@@ -66,11 +68,11 @@ def __init__(
6668

6769
def process_cloud(self, path: Path):
6870
# Load point cloud
69-
cloud: Cloud = load_cloud(path)
71+
cloud: Cloud = load_cloud(path).to_device(self.device)
7072
cloud = self.preprocessing(cloud)
7173

7274
# Run point cloud through model to predict class, radius, direction
73-
lc: LabelledCloud = self.inferer.forward(cloud).to_device("cuda")
75+
lc: LabelledCloud = self.inferer.forward(cloud).to_device(self.device)
7476
if self.view_model_output:
7577
lc.view(self.cmap)
7678

@@ -79,6 +81,8 @@ def process_cloud(self, path: Path):
7981

8082
# Run the branch cloud through skeletonization algorithm, then post process
8183
skeleton: DisjointTreeSkeleton = self.skeletonizer.forward(branch_cloud)
84+
original_skeleton = deepcopy(skeleton)
85+
8286
self.post_process(skeleton)
8387

8488
# View skeletonization results
@@ -89,13 +93,14 @@ def process_cloud(self, path: Path):
8993
skeleton.to_o3d_lineset(),
9094
skeleton.to_o3d_tube(colour=False),
9195
cloud.to_o3d_cld(),
96+
original_skeleton.to_o3d_tube(),
9297
],
9398
line_width=5,
9499
)
95100

96101
if self.save_outputs:
97-
save_o3d_mesh("skeleton.ply", skeleton.to_o3d_lineset())
98-
save_o3d_lineset("mesh.ply", skeleton.to_o3d_tube())
102+
save_o3d_lineset("skeleton.ply", skeleton.to_o3d_lineset())
103+
save_o3d_mesh("mesh.ply", skeleton.to_o3d_tube())
99104
save_o3d_cloud("cloud.ply", cloud.to_o3d_cld())
100105

101106
def post_process(self, skeleton: DisjointTreeSkeleton):
@@ -109,7 +114,7 @@ def post_process(self, skeleton: DisjointTreeSkeleton):
109114
skeleton.repair()
110115

111116
if self.smooth_skeletons:
112-
skeleton.smooth()
117+
skeleton.smooth(kernel_size=30)
113118

114119
@staticmethod
115120
def from_cfg(inferer, skeletonizer, cfg):

smart_tree/skeleton/path copy.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

smart_tree/skeleton/path.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313
from ..util.visualizer.view import o3d_viewer
1414

1515

16-
def trace_route(preds, idx, end_points):
17-
# cpu_preds = preds.cpu().numpy()
16+
def trace_route(preds, idx, termination_pts):
1817
path = []
1918

20-
while idx >= 0 and idx not in end_points:
19+
while idx >= 0 and idx not in termination_pts:
2120
path.append(idx)
2221
idx = preds[idx]
2322

24-
return preds.new_tensor(path, dtype=torch.long).flip(0)
23+
return preds.new_tensor(path, dtype=torch.long).flip(0), idx
2524

2625

2726
def select_path_points(
@@ -84,7 +83,7 @@ def sample_tree(
8483
selection_mask = preds > 0
8584
distances[~selection_mask] = -1
8685

87-
end_points = torch.tensor([], device=torch.device("cuda"))
86+
termination_pts = torch.tensor([], device=torch.device("cuda"))
8887

8988
branch_id = 0
9089

@@ -101,10 +100,10 @@ def sample_tree(
101100
pts_sampled = f"{100 * (1.0 - ((distances > 0).sum().item() / medial_pts.shape[0])):.2f}"
102101
pbar.set_postfix_str(f"Sampling Graph: {pts_sampled} %")
103102

104-
path_vertices_idx = trace_route(
103+
path_vertices_idx, termination_idx = trace_route(
105104
preds,
106105
farthest,
107-
end_points,
106+
termination_pts,
108107
)
109108

110109
idx_points, idx_path = select_path_points(
@@ -116,7 +115,9 @@ def sample_tree(
116115
distances[idx_points] = -1
117116
distances[idx_path] = -1
118117

119-
end_points = torch.unique(torch.cat((end_points, idx_points, idx_path)))
118+
termination_pts = torch.unique(
119+
torch.cat((termination_pts, idx_point, idx_path))
120+
)
120121

121122
if len(path_vertices_idx) < 2:
122123
continue
@@ -125,7 +126,7 @@ def sample_tree(
125126
branch_id,
126127
xyz=medial_pts[path_vertices_idx].cpu().numpy(),
127128
radii=medial_radii[path_vertices_idx].cpu().numpy(),
128-
parent_id=find_branch_parent(int(path_vertices_idx[0]), idx_lookup),
129+
parent_id=-1, # find_branch_parent(termination_idx, idx_lookup),
129130
child_id=-1,
130131
)
131132

smart_tree/skeleton/skeletonize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def forward(self, cloud: LabelledCloud):
5050
cloud.to_device(self.device)
5151

5252
if self.voxel_downsample != False:
53-
cloud = cloud.medial_voxel_down_sample(self.voxel_downsample)
53+
cloud = cloud.voxel_down_sample(0.01)
5454

5555
mask = outlier_removal(cloud.medial_pts, cloud.radii.unsqueeze(1))
5656
cloud = cloud.filter(mask)

0 commit comments

Comments
 (0)