Skip to content

Commit 4fa9995

Browse files
committed
V1.0.0 beta
1 parent 22f240b commit 4fa9995

File tree

4 files changed

+29
-18
lines changed

4 files changed

+29
-18
lines changed

smart_tree/conf/tree-dataset.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ model_inference:
6666
batch_size : 4
6767

6868
skeletonizer:
69-
K: 16
69+
K: 8
7070
min_connection_length: 0.02
7171
minimum_graph_vertices: 32
7272
max_number_components: 8
@@ -108,7 +108,6 @@ pipeline:
108108

109109
repair_skeletons : True
110110
smooth_skeletons : True
111-
kernel_size: 10
112-
prune_skeletons : False
113-
min_skeleton_radius : 0.05 #0.005
111+
prune_skeletons : True
112+
min_skeleton_radius : 0.1 #0.005
114113
min_skeleton_length : 0.02

smart_tree/data_types/tree.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def to_o3d_linesets(self) -> List:
3434
return [b.to_o3d_lineset() for b in self.branches.values()]
3535

3636
def to_o3d_lineset(self, colour=(0, 0, 0)):
37-
return o3d_merge_linesets(self.to_o3d_linesets())
37+
return o3d_merge_linesets(self.to_o3d_linesets(), colour=colour)
3838

3939
def to_o3d_tubes(self) -> List:
4040
return [b.to_o3d_tube() for b in self.branches.values()]
@@ -94,7 +94,8 @@ def prune(self, min_radius, min_length, root_id=0):
9494

9595
if branch.parent_id in branches_to_keep:
9696
if branch.length > min_length and (
97-
(branch.radii[0] > min_radius) or branch.radii[-1] > min_radius
97+
max(branch.radii[0], branch.radii[-1])
98+
> min_radius # We don't know which end of the branch is the start for skeletons that aren't connected...
9899
):
99100
branches_to_keep[branch_id] = branch
100101

@@ -120,8 +121,10 @@ class DisjointTreeSkeleton:
120121
skeletons: List[TreeSkeleton]
121122

122123
def prune(self, min_radius, min_length):
123-
for skeleton in self.skeletons:
124-
skeleton.prune(min_radius=min_radius, min_length=min_length)
124+
self.skeletons[0].prune(
125+
min_radius=min_radius,
126+
min_length=min_length,
127+
) # Can only prune the first skeleton as we don't know the root points for all the other skeletons...
125128

126129
def repair(self):
127130
for skeleton in self.skeletons:

smart_tree/skeleton/path.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,20 @@ def sample_tree(
8080
Surface Points: The point the medial pts got projected from..
8181
"""
8282

83+
branch_id = 0
84+
85+
branches = {}
86+
8387
selection_mask = preds > 0
8488
distances[~selection_mask] = -1
8589

8690
termination_pts = torch.tensor([], device=torch.device("cuda"))
87-
88-
branch_id = 0
89-
90-
idx_lookup = {}
91-
branches = {}
91+
branch_ids = torch.full(
92+
(medial_pts.shape[0],),
93+
-1,
94+
device=torch.device("cuda"),
95+
dtype=int,
96+
)
9297

9398
while True:
9499
farthest = distances.argmax().item()
@@ -100,23 +105,25 @@ def sample_tree(
100105
pts_sampled = f"{100 * (1.0 - ((distances > 0).sum().item() / medial_pts.shape[0])):.2f}"
101106
pbar.set_postfix_str(f"Sampling Graph: {pts_sampled} %")
102107

108+
""" Traces the path of the futhrest point until it converges with allocated points """
103109
path_vertices_idx, termination_idx = trace_route(
104110
preds,
105111
farthest,
106112
termination_pts,
107113
)
108114

115+
""" Gets the points around that path """
109116
idx_points, idx_path = select_path_points(
110117
medial_pts,
111118
medial_pts[path_vertices_idx],
112119
medial_radii[path_vertices_idx],
113120
)
114121

115122
distances[idx_points] = -1
116-
distances[idx_path] = -1
123+
distances[path_vertices_idx] = -1
117124

118125
termination_pts = torch.unique(
119-
torch.cat((termination_pts, idx_points, idx_path))
126+
torch.cat((termination_pts, idx_points, path_vertices_idx))
120127
)
121128

122129
if len(path_vertices_idx) < 2:
@@ -126,11 +133,13 @@ def sample_tree(
126133
branch_id,
127134
xyz=medial_pts[path_vertices_idx].cpu().numpy(),
128135
radii=medial_radii[path_vertices_idx].cpu().numpy(),
129-
parent_id=-1, # find_branch_parent(termination_idx, idx_lookup),
136+
parent_id=int(branch_ids[termination_idx].item()),
130137
child_id=-1,
131138
)
132139

133-
idx_lookup[branch_id] = idx_path.cpu().tolist() + idx_points.cpu().tolist()
140+
branch_ids[path_vertices_idx] = branch_id
141+
branch_ids[idx_points] = branch_id
142+
134143
branch_id += 1
135144

136145
return branches

smart_tree/util/mesh/geometries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def o3d_merge_linesets(line_sets, colour=[0, 0, 0]):
4141

4242
return o3d.geometry.LineSet(
4343
o3d.utility.Vector3dVector(points), o3d.utility.Vector2iVector(idxs)
44-
)
44+
).paint_uniform_color(colour)
4545

4646

4747
def points_to_edge_idx(points):

0 commit comments

Comments
 (0)