Skip to content

Commit 5ad9690

Browse files
committed
Improved progress bar
1 parent d40dce8 commit 5ad9690

File tree

8 files changed

+43
-40
lines changed

8 files changed

+43
-40
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
# 💡 Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization 🌳
1+
# <center> 💡🧠🤔 Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization 🌳🌲🌴 </center>
2+
3+
## 📝 Description:
24

35
This GitHub repository contains code from the the paper "Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization". <br>
46
The code provided, is a deep-learning based skeletonization method for point clouds.
@@ -51,7 +53,7 @@ We supply two different models with weights:
5153
* `noble-elevator-58` contains branch / foliage segmentation. <br>
5254
* `peach-forest-65` is only trained on points from branching structure. <br>
5355

54-
If you wish to run smart-tree using your own weights you will need to update the model paths in the tree-dataset.yaml. <br>
56+
If you wish to run smart-tree using your own weights you will need to update the model paths in the `tree-dataset.yaml`. <br>
5557

5658
To run smart-tree use: <br>
5759
`run-smart-tree +path=cloud_path` <br>

smart_tree/conf/tree-dataset.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ model_inference:
6767

6868
skeletonizer:
6969
K: 16
70-
min_connection_length: 0.02
70+
min_connection_length: 0.00 #.02
7171
minimum_graph_vertices: 32
72-
max_number_components: 1000
72+
max_number_components: 30
7373
voxel_downsample: False
7474
edge_non_linear: None
7575

@@ -98,8 +98,8 @@ pipeline:
9898
repair_skeletons : True
9999
smooth_skeletons : True
100100
prune_skeletons : True
101-
min_skeleton_radius : 0 #0.005
102-
min_skeleton_length : 0
101+
min_skeleton_radius : 0.001 #0.005
102+
min_skeleton_length : 0.02
103103
view_model_output : False
104104
view_skeletons : True
105105
save_outputs : False

smart_tree/data_types/cloud.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ def cat(self):
153153

154154
def medial_voxel_down_sample(self, voxel_size):
155155
idx = voxel_downsample_idxs(self.medial_pts, voxel_size)
156-
print(idx)
157-
quit()
156+
158157
return self.filter(idx)
159158

160159
def to_torch(self):

smart_tree/model/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def main(cfg: DictConfig):
110110

111111
model_output = model.forward(sparse_input)
112112

113-
# print(model_output)
114113
loss = model.compute_loss(model_output, targets, loss_mask)
115114
total_loss = loss["radius"] + loss["direction"] + loss["class"]
116115

smart_tree/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def post_process(self, skeleton: DisjointTreeSkeleton):
110110
skeleton.repair()
111111

112112
if self.smooth_skeletons:
113-
print("Smoothing.")
113+
print("Smoothing...")
114114
skeleton.smooth()
115115

116116
@staticmethod

smart_tree/skeleton/path.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
import open3d as o3d
44
import torch
5+
import sys
6+
57
from tqdm import tqdm
68

79
from ..data_types.branch import BranchSkeleton
@@ -39,7 +41,9 @@ def select_path_points(
3941
"""
4042

4143
point_path, dists, _ = nn(
42-
points, path_verts, r=radii.max().item()
44+
points,
45+
path_verts,
46+
r=radii.max().item(),
4347
) # nearest path idx for each point
4448
valid = dists[point_path >= 0] < radii[point_path[point_path >= 0]].squeeze(
4549
1
@@ -72,6 +76,7 @@ def sample_tree(
7276
all_points,
7377
root_idx=0,
7478
visualize=False,
79+
pbar=None,
7580
):
7681
"""
7782
Medial Points: NN estimated medial points
@@ -91,17 +96,18 @@ def sample_tree(
9196
idx_lookup = {}
9297
branches = {}
9398

94-
tubes = []
99+
# tubes = []
95100

96101
while True:
97-
os.system("clear")
98-
print(f"{(distances > 0).sum().item() / medial_pts.shape[0]:.4f}")
99-
100102
farthest = distances.argmax().item() # Get fartherest away medial point
101103

102104
if distances[farthest] <= 0:
103105
break
104106

107+
if pbar:
108+
pts_sampled = f"{100 * (1.0 - ((distances > 0).sum().item() / medial_pts.shape[0])):.2f}"
109+
pbar.set_postfix_str(f"Sampling Graph: {pts_sampled} %")
110+
105111
path_vertices_idx, first_idx = trace_route(
106112
preds, farthest, allocated=allocated_path_points
107113
) # Gets IDXs along a path and the first IDX of that path

smart_tree/skeleton/skeletonize.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import open3d.visualization.rendering as rendering
44
import torch
55
from cugraph import sssp
6-
from tqdm import tqdm
6+
from tqdm.auto import tqdm
77

88
from ..data_types.tree import TreeSkeleton, DisjointTreeSkeleton
99
from ..util.mesh.geometries import (
@@ -69,36 +69,25 @@ def forward(self, cloud: LabelledCloud):
6969
max_components=self.max_number_components,
7070
)
7171

72-
return DisjointTreeSkeleton(
73-
skeletons=[
74-
self.process_subgraph(
75-
subgraph,
76-
cloud.medial_pts,
77-
cloud.xyz,
78-
cloud.radii.unsqueeze(1),
79-
self.min_connection_length,
80-
self.device,
81-
skeleton_id,
82-
)
83-
for skeleton_id, subgraph in enumerate(
84-
tqdm(subgraphs, desc="Processing Skeleton Fragment", leave=False)
85-
)
86-
]
87-
)
72+
skeletons = []
73+
pbar = tqdm(subgraphs, position=0, leave=False)
74+
75+
for skeleton_id, subgraph in enumerate(pbar):
76+
pbar.set_description(f"Processing Subgraph {skeleton_id}")
77+
skeletons.append(self.process_subgraph(subgraph, cloud, skeleton_id, pbar))
78+
79+
return DisjointTreeSkeleton(skeletons)
8880

8981
def process_subgraph(
9082
self,
9183
subgraph,
92-
medial_points,
93-
points,
94-
radii,
95-
min_edge,
96-
device=torch.device("cuda:0"),
84+
cloud,
9785
skeleton_id=0,
86+
pbar=None,
9887
) -> TreeSkeleton:
9988
edges, edge_weights = decompose_cuda_graph(subgraph, self.device)
10089

101-
root_idx = edges[torch.argmin(medial_points[edges[:, 0]][:, 1])][0].item()
90+
root_idx = edges[torch.argmin(cloud.medial_pts[edges[:, 0]][:, 1])][0].item()
10291

10392
verts, preds, distance = shortest_paths(
10493
root_idx,
@@ -107,13 +96,20 @@ def process_subgraph(
10796
renumber=False,
10897
)
10998

110-
predecessor_graph = pred_graph(verts, preds, medial_points)
99+
predecessor_graph = pred_graph(verts, preds, cloud.medial_pts)
111100

112101
distances = torch.as_tensor(
113102
sssp(predecessor_graph, source=root_idx)["distance"]
114103
).to(self.device)
115104

116-
branches = sample_tree(medial_points, radii, preds, distances, points)
105+
branches = sample_tree(
106+
cloud.medial_pts,
107+
cloud.radii.unsqueeze(1),
108+
preds,
109+
distances,
110+
cloud.xyz,
111+
pbar=pbar,
112+
)
117113

118114
return TreeSkeleton(skeleton_id, branches)
119115

smart_tree/util/misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
from typing import List, Union
33

4+
import sys
45
import cmapy
56
import numpy as np
67
import open3d as o3d

0 commit comments

Comments
 (0)