Skip to content

Commit 35eb93e

Browse files
committed
Path skeleton optimization
1 parent 0410e55 commit 35eb93e

File tree

12 files changed

+212
-94
lines changed

12 files changed

+212
-94
lines changed

pyproject.toml

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,28 @@ requires = ["setuptools", "setuptools-scm"]
33
build-backend = "setuptools.build_meta"
44

55
[project]
6-
name = "Smart-Tree"
6+
name = "smart-tree"
7+
version = "1.0.0"
78
authors = [
89
{name = "Harry Dobbs", email = "harrydobbs87@gmail.com"},
910
]
1011
description = "Neural Network Point Cloud Tree Skeletonization"
11-
readme = "README.rst"
12+
readme = "README.md"
1213
requires-python = ">=3.8"
1314
license = {text = "BSD-3"}
1415
dependencies = [
1516
'numpy',
1617
'open3d',
1718
'hydra-core>=1.2.0',
1819
'click',
19-
# 'opencv-python',
20-
# 'pytorch3d@https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/pytorch3d-0.7.2-cp38-cp38-linux_x86_64.whl',
21-
# 'jaxtyping',
22-
# 'scikit-image',
23-
# 'scikit-learn',
24-
# 'scipy',
2520
'oauthlib',
26-
# 'type_enforced',
2721
'spconv-cu117',
2822
'wandb',
2923
'cmapy',
3024
'pykeops',
31-
# 'seaborn',
3225
'plyfile',
3326
'py_structs'
3427
]
35-
dynamic = ["version"]
3628

3729
[tool.setuptools.packages]
3830
find = {} # Scan the project directory with the default parameters

smart_tree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "1.0.0"

smart_tree/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@hydra.main(
1515
version_base=None,
16-
config_path="conf",
16+
config_path=".conf",
1717
config_name="tree-dataset",
1818
)
1919
def main(cfg: DictConfig):

smart_tree/conf/tree-dataset.yaml

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,43 +67,47 @@ model_inference:
6767

6868
skeletonizer:
6969
K: 16
70-
min_connection_length: 0.00 #.02
70+
min_connection_length: 0.02
7171
minimum_graph_vertices: 32
72-
max_number_components: 30
72+
max_number_components: 100
7373
voxel_downsample: False
7474
edge_non_linear: None
7575

7676

7777
pipeline:
7878

7979
preprocessing:
80-
# Scale:
81-
# _target_: smart_tree.dataset.augmentations.Scale
82-
# min_scale: 1
83-
# max_scale: 1
80+
# Scale:
81+
# _target_: smart_tree.dataset.augmentations.Scale
82+
# min_scale: 0.1
83+
# max_scale: 0.1
8484
# Scale:
8585
# _target_: smart_tree.dataset.augmentations.Scale
8686
# min_scale: 1
8787
# max_scale: 1
88-
# VoxelDownsample:
89-
# _target_: smart_tree.dataset.augmentations.VoxelDownsample
90-
# voxel_size : 0.01
88+
VoxelDownsample:
89+
_target_: smart_tree.dataset.augmentations.VoxelDownsample
90+
voxel_size : 0.01
9191
#FixedRotate:
9292
# _target_: smart_tree.dataset.augmentations.FixedRotate
9393
# xyz: [0, 0, 90]
9494
# CentreCloud:
9595
# _target_: smart_tree.dataset.augmentations.CentreCloud
9696

9797

98-
repair_skeletons : True
99-
smooth_skeletons : True
100-
prune_skeletons : True
101-
min_skeleton_radius : 0.001 #0.005
102-
min_skeleton_length : 0.02
103-
view_model_output : False
98+
view_model_output : True
10499
view_skeletons : True
105100
save_outputs : False
106101
branch_classes: [0]
107102
cmap:
108-
- [1, 0, 0] # Trunk
109-
- [0, 1, 0] # Foliage
103+
- [0.325, 0.207, 0.039] # Trunk
104+
- [0.290, 0.703, 0.254] # Foliage
105+
106+
107+
108+
109+
repair_skeletons : True
110+
smooth_skeletons : True
111+
prune_skeletons : False # need to fix as some skeletons don't know their start I guess?
112+
min_skeleton_radius : 0.05 #0.005
113+
min_skeleton_length : 0.02

smart_tree/data_types/cloud.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,13 @@ class LabelledCloud(Cloud):
9999
vector: torch.Tensor
100100
class_l: torch.Tensor
101101

102-
def __post_init__(self):
103-
num_classes = int(torch.max(self.class_l, 0)[0].item())
104-
self.cmap = torch.rand(num_classes + 1, 3)
102+
@property
103+
def number_classes(self):
104+
return int(torch.max(self.class_l, 0)[0].item()) + 1
105+
106+
@property
107+
def cmap(self):
108+
return torch.rand(self.number_classes, 3)
105109

106110
def filter(self, mask):
107111
return LabelledCloud(
@@ -120,11 +124,7 @@ def filter_by_class(self, classes: List):
120124
return self.filter(mask)
121125

122126
def view(self, cmap=[]):
123-
if len(cmap) != 0:
124-
cmap = cmap
125-
else:
126-
cmap = self.cmap
127-
127+
cmap = cmap if cmap != [] else self.cmap
128128
cpu_cld = self.to_device("cpu")
129129
input_cld = cpu_cld.to_o3d_cld()
130130
segmented_cld = o3d_cloud(cpu_cld.xyz, colours=cmap[cpu_cld.class_l])

smart_tree/data_types/tree.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def prune(self, min_radius, min_length, root_id=0):
9393
continue
9494

9595
if branch.parent_id in branches_to_keep:
96-
if branch.length > min_length and branch.radii[0] > min_radius:
96+
if branch.length > min_length and (
97+
(branch.radii[0] > min_radius) or branch.radii[-1] > min_radius
98+
):
9799
branches_to_keep[branch_id] = branch
98100

99101
self.branches = branches_to_keep

smart_tree/dataset/dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def compute_blocks(self):
155155

156156
self.block_centres = self.block_centres.to(torch.device("cpu"))
157157

158-
print("Blocks Computed")
159-
160158
def __getitem__(self, idx):
161159
block_centre = self.block_centres[idx]
162160
cloud: Cloud = self.clouds[idx]

smart_tree/model/model_inference.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,22 @@ def __init__(
4343
num_workers=8,
4444
batch_size=4,
4545
device=torch.device("cuda:0"),
46+
verbose=False,
4647
):
47-
print("Initalizing Model Inference")
4848
self.device = device
49-
self.model = load_model(model_path, weights_path, self.device)
50-
print("Model Loaded")
49+
self.verbose = verbose
5150
self.voxel_size = voxel_size
5251
self.block_size = block_size
5352
self.buffer_size = buffer_size
5453

5554
self.num_workers = num_workers
5655
self.batch_size = batch_size
5756

57+
self.model = load_model(model_path, weights_path, self.device)
58+
59+
if self.verbose:
60+
print("Model Loaded Succesfully")
61+
5862
def forward(self, cloud: Cloud, return_masked=True):
5963
outputs, inputs, masks = [], [], []
6064

smart_tree/pipeline.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __init__(
4141
cmap=[[1, 0, 0], [0, 1, 0]],
4242
device=torch.device("cuda:0"),
4343
):
44-
print("Setting up pipeline...")
4544
self.inferer = inferer
4645
self.skeletonizer = skeletonizer
4746

@@ -73,7 +72,7 @@ def process_cloud(self, path: Path):
7372
# Run point cloud through model to predict class, radius, direction
7473
lc: LabelledCloud = self.inferer.forward(cloud).to_device("cuda")
7574
if self.view_model_output:
76-
lc.view(cmap=self.cmap)
75+
lc.view(self.cmap)
7776

7877
# Filter only the branch points for skeletonizaiton
7978
branch_cloud: LabelledCloud = lc.filter_by_class(self.branch_classes)
@@ -88,16 +87,16 @@ def process_cloud(self, path: Path):
8887
[
8988
skeleton.to_o3d_tube(),
9089
skeleton.to_o3d_lineset(),
91-
cloud.to_o3d_cld(),
9290
skeleton.to_o3d_tube(colour=False),
91+
cloud.to_o3d_cld(),
9392
],
9493
line_width=5,
9594
)
9695

9796
if self.save_outputs:
9897
save_o3d_mesh("skeleton.ply", skeleton.to_o3d_lineset())
9998
save_o3d_lineset("mesh.ply", skeleton.to_o3d_tube())
100-
save_o3d_cloud("mesh.ply", cloud.to_o3d_cld())
99+
save_o3d_cloud("cloud.ply", cloud.to_o3d_cld())
101100

102101
def post_process(self, skeleton: DisjointTreeSkeleton):
103102
if self.prune_skeletons:
@@ -110,7 +109,6 @@ def post_process(self, skeleton: DisjointTreeSkeleton):
110109
skeleton.repair()
111110

112111
if self.smooth_skeletons:
113-
print("Smoothing...")
114112
skeleton.smooth()
115113

116114
@staticmethod

smart_tree/skeleton/path copy.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import os
2+
import numpy as np
3+
import open3d as o3d
4+
import torch
5+
import sys
6+
7+
from tqdm import tqdm
8+
9+
from ..data_types.branch import BranchSkeleton
10+
from ..util.mesh.geometries import o3d_cloud, o3d_path, o3d_sphere, o3d_tube_mesh
11+
from ..util.misc import flatten_list
12+
from .graph import nn
13+
from ..util.visualizer.view import o3d_viewer
14+
15+
16+
def trace_route(preds, idx, end_points):
17+
# cpu_preds = preds.cpu().numpy()
18+
path = []
19+
20+
while idx >= 0 and idx not in end_points:
21+
path.append(idx)
22+
idx = preds[idx]
23+
24+
return preds.new_tensor(path, dtype=torch.long).flip(0)
25+
26+
27+
def select_path_points(
28+
points: torch.tensor, path_verts: torch.tensor, radii: torch.tensor
29+
):
30+
"""
31+
Finds points nearest to a path (specified by points with radii).
32+
points: (N, 3) 3d points of point cloud
33+
path_verts: (M, 3) 3d points of path
34+
radii: (M, 1) radii of path
35+
returns: (X, 2) index tuples of (path, point) for X points falling within the path, ordered by path index
36+
"""
37+
38+
point_path, dists, _ = nn(
39+
points,
40+
path_verts,
41+
r=radii.max().item(),
42+
) # nearest path idx for each point
43+
valid = dists[point_path >= 0] < radii[point_path[point_path >= 0]].squeeze(
44+
1
45+
) # where the path idx is less than the distance to the point
46+
47+
on_path = point_path.new_zeros(point_path.shape, dtype=torch.bool)
48+
on_path[point_path >= 0] = valid # points that are on the path.
49+
50+
idx_point = on_path.nonzero().squeeze(1)
51+
idx_path = point_path[idx_point]
52+
53+
order = torch.argsort(idx_path)
54+
return idx_point[order], idx_path[order]
55+
56+
57+
def find_branch_parent(idx_pt, idx_lookup):
58+
"""Finds which branch pt is from ...
59+
so we can work out the branch parent..."""
60+
for _id, _idxs in idx_lookup.items():
61+
if idx_pt in _idxs:
62+
return int(_id)
63+
return -1
64+
65+
66+
def sample_tree(
67+
medial_pts,
68+
medial_radii,
69+
preds,
70+
distances,
71+
all_points,
72+
root_idx=0,
73+
visualize=False,
74+
pbar=None,
75+
):
76+
"""
77+
Medial Points: NN estimated medial points
78+
Medial Radii: NN estimated radii of points
79+
Preds: Predecessor of each medial point (on path to root node)
80+
Distance: Distance from root node to medial points
81+
Surface Points: The point the medial pts got projected from..
82+
"""
83+
84+
selection_mask = preds > 0
85+
distances[~selection_mask] = -1
86+
87+
end_points = torch.tensor([], device=torch.device("cuda"))
88+
89+
branch_id = 0
90+
91+
idx_lookup = {}
92+
branches = {}
93+
94+
while True:
95+
farthest = distances.argmax().item()
96+
97+
if distances[farthest] <= 0:
98+
break
99+
100+
if pbar:
101+
pts_sampled = f"{100 * (1.0 - ((distances > 0).sum().item() / medial_pts.shape[0])):.2f}"
102+
pbar.set_postfix_str(f"Sampling Graph: {pts_sampled} %")
103+
104+
path_vertices_idx = trace_route(
105+
preds,
106+
farthest,
107+
end_points,
108+
)
109+
110+
idx_points, idx_path = select_path_points(
111+
medial_pts,
112+
medial_pts[path_vertices_idx],
113+
medial_radii[path_vertices_idx],
114+
)
115+
116+
distances[idx_points] = -1
117+
distances[idx_path] = -1
118+
119+
end_points = torch.unique(torch.cat((end_points, idx_points, idx_path)))
120+
121+
if len(path_vertices_idx) > 1:
122+
branches[branch_id] = BranchSkeleton(
123+
branch_id,
124+
xyz=medial_pts[path_vertices_idx].cpu().numpy(),
125+
radii=medial_radii[path_vertices_idx].cpu().numpy(),
126+
parent_id=find_branch_parent(int(path_vertices_idx[0]), idx_lookup),
127+
child_id=-1,
128+
)
129+
130+
idx_lookup[branch_id] = (
131+
path_vertices_idx.cpu().tolist() + idx_points.cpu().tolist()
132+
)
133+
branch_id += 1
134+
135+
return branches

0 commit comments

Comments
 (0)