Skip to content

Commit 426ddcd

Browse files
committed
Added colour training / colour dropout augmentation
1 parent 0410e55 commit 426ddcd

File tree

6 files changed

+77
-27
lines changed

6 files changed

+77
-27
lines changed

smart_tree/conf/tree-dataset.yaml

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ training:
77
lr_decay: True
88
early_stop_epoch: 20
99
early_stop: True
10+
use_colour: False
1011

1112
dataset:
1213
_target_: smart_tree.dataset.dataset.TreeDataset
@@ -17,14 +18,16 @@ training:
1718
block_size: 4
1819
buffer_size: 0.4
1920

20-
augmentation:
21-
# Scale:
22-
# _target_: smart_tree.dataset.augmentations.Scale
23-
# min_scale: 0.1
24-
# max_scale: 0.1
25-
# Dropout:
21+
# augmentation:
22+
# _target_: smart_tree.dataset.augmentations.Scale
23+
# min_scale: 0.9
24+
# max_scale: 1.1
25+
2626
# _target_: smart_tree.dataset.augmentations.RandomDropout
2727
# max_drop_out: 0.1
28+
29+
# _target_: smart_tree.dataset.augmentations.RandomColourDropout
30+
# max_drop_out: 0.1
2831

2932
data_loader:
3033
_target_: torch.utils.data.DataLoader
@@ -64,6 +67,7 @@ model_inference:
6467
buffer_size: 0.4
6568
num_workers : 8
6669
batch_size : 4
70+
use_colour: False
6771

6872
skeletonizer:
6973
K: 16
@@ -88,19 +92,20 @@ pipeline:
8892
# VoxelDownsample:
8993
# _target_: smart_tree.dataset.augmentations.VoxelDownsample
9094
# voxel_size : 0.01
91-
#FixedRotate:
92-
# _target_: smart_tree.dataset.augmentations.FixedRotate
93-
# xyz: [0, 0, 90]
95+
96+
# FixedRotate:
97+
# _target_: smart_tree.dataset.augmentations.FixedRotate
98+
# xyz: [0, 0, 90]
9499
# CentreCloud:
95-
# _target_: smart_tree.dataset.augmentations.CentreCloud
100+
# _target_: smart_tree.dataset.augmentations.CentreCloud
96101

97102

98103
repair_skeletons : True
99104
smooth_skeletons : True
100105
prune_skeletons : True
101106
min_skeleton_radius : 0.001 #0.005
102107
min_skeleton_length : 0.02
103-
view_model_output : False
108+
view_model_output : True
104109
view_skeletons : True
105110
save_outputs : False
106111
branch_classes: [0]

smart_tree/dataset/augmentations.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, xyz):
3333

3434
def __call__(self, cloud):
3535
self.rot_mat = euler_angles_to_rotation(
36-
torch.tensor(self.xyz), device=cloud.device
36+
torch.tensor(self.xyz), device=cloud.xyz.device
3737
).float()
3838
return cloud.rotate(self.rot_mat)
3939

@@ -60,6 +60,19 @@ def __call__(self, cloud):
6060
return cloud.translate(self.xyz)
6161

6262

63+
class RandomTranslate(Augmentation):
64+
def __init__(self, std):
65+
self.std = torch.tensor(std)
66+
67+
def __call__(self, cloud):
68+
return cloud.translate(
69+
torch.normal(
70+
torch.zeros(3, device=cloud.xyz.device),
71+
std=self.std.to(cloud.xyz.device),
72+
)
73+
)
74+
75+
6376
class RandomDropout(Augmentation):
6477
def __init__(self, max_drop_out):
6578
self.max_drop_out = max_drop_out
@@ -76,6 +89,25 @@ def __call__(self, cloud):
7689
return cloud.filter(indices)
7790

7891

92+
class RandomColourDropout(Augmentation):
93+
def __init__(self, max_drop_out):
94+
self.max_drop_out = max_drop_out
95+
96+
def __call__(self, cloud):
97+
num_indices = int(
98+
(1.0 - (self.max_drop_out * torch.rand(1, device=cloud.rgb.device)))
99+
* cloud.xyz.shape[0]
100+
)
101+
102+
indices = torch.randint(
103+
high=cloud.rgb.shape[0], size=(num_indices, 1), device=cloud.rgb.device
104+
).squeeze(1)
105+
106+
cloud.rgb[indices] = torch.ones_like(cloud.rgb[indices])
107+
108+
return cloud
109+
110+
79111
class AugmentationPipeline:
80112
def __init__(self, augmentation_fns: List[Augmentation]):
81113
# config is a dict

smart_tree/model/model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ def __init__(
2222
direction_fc_planes,
2323
class_fc_planes,
2424
bias=False,
25+
branch_classes=[0],
2526
algo=spconv.ConvAlgo.Native,
27+
device=torch.device("cuda"),
2628
):
2729
super().__init__()
2830

31+
self.branch_classes = torch.tensor(branch_classes, device=device)
32+
2933
norm_fn = functools.partial(
3034
nn.BatchNorm1d, eps=1e-4, momentum=0.1
3135
) # , momentum=0.99)
@@ -51,7 +55,6 @@ def __init__(
5155
algo=algo,
5256
)
5357

54-
# Three Heads...
5558
self.radius_head = SparseFC(
5659
radius_fc_planes,
5760
norm_fn,
@@ -108,16 +111,18 @@ def compute_loss(self, outputs, targets, mask=None):
108111
class_target = targets[:, [3]]
109112
direction_target, radius_target = torch_normalized(targets[:, :3])
110113

111-
vector_mask = (
112-
class_target == 0
113-
) # only compute vector loss on points that are meant to be branches
114-
vector_mask = vector_mask.reshape(-1)
114+
mask = torch.isin(
115+
class_target,
116+
self.branch_classes,
117+
)
118+
119+
mask = mask.reshape(-1)
115120

116121
losses["radius"] = self.compute_radius_loss(
117-
radius_pred[vector_mask], radius_target[vector_mask]
122+
radius_pred[mask], radius_target[mask]
118123
)
119124
losses["direction"] = self.compute_direction_loss(
120-
direction_pred[vector_mask], direction_target[vector_mask]
125+
direction_pred[mask], direction_target[mask]
121126
)
122127
losses["class"] = self.compute_class_loss(class_pred, class_target)
123128

smart_tree/model/model_inference.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
buffer_size: float,
4343
num_workers=8,
4444
batch_size=4,
45+
use_colour=False,
4546
device=torch.device("cuda:0"),
4647
):
4748
print("Initalizing Model Inference")
@@ -55,6 +56,8 @@ def __init__(
5556
self.num_workers = num_workers
5657
self.batch_size = batch_size
5758

59+
self.num_input_feats = 6 if use_colour else 3
60+
5861
def forward(self, cloud: Cloud, return_masked=True):
5962
outputs, inputs, masks = [], [], []
6063

@@ -71,7 +74,7 @@ def forward(self, cloud: Cloud, return_masked=True):
7174
dataloader, desc="Inferring", leave=False
7275
):
7376
sparse_input = sparse_from_batch(
74-
features[:, :3],
77+
features[:, : self.num_input_feats],
7578
coordinates,
7679
device=self.device,
7780
)
@@ -111,6 +114,7 @@ def from_cfg(cfg):
111114
buffer_size=cfg.buffer_size,
112115
num_workers=cfg.num_workers,
113116
batch_size=cfg.batch_size,
117+
use_colour=cfg.use_colour,
114118
)
115119

116120

smart_tree/model/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def main(cfg: DictConfig):
8383

8484
epochs_no_improve = 0
8585
best_val_loss = torch.inf
86+
num_input_feats = 6 if cfg.use_colour else 3
8687

8788
# Training Epochs
8889
for epoch in tqdm(range(0, cfg.num_epoch), leave=False):
@@ -96,7 +97,7 @@ def main(cfg: DictConfig):
9697
):
9798
with amp_ctx:
9899
sparse_input = sparse_from_batch(
99-
features[:, :3],
100+
features[:, :num_input_feats],
100101
coordinates,
101102
device=device,
102103
)
@@ -143,7 +144,7 @@ def main(cfg: DictConfig):
143144
desc="Validating",
144145
):
145146
sparse_input = sparse_from_batch(
146-
features[:, :3],
147+
features[:, :num_input_feats],
147148
coordinates,
148149
device=device,
149150
)

smart_tree/util/math/maths.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,34 @@ def torch_normalized(v):
1616
return F.normalize(v), v.pow(2).sum(1).sqrt().unsqueeze(1)
1717

1818

19-
def euler_angles_to_rotation(xyz: List) -> torch.tensor:
20-
x, y, z = xyz
19+
def euler_angles_to_rotation(xyz: List, device=torch.device("cuda")) -> torch.tensor:
20+
x, y, z = xyz.to(device=device)
2121

2222
R_X = torch.tensor(
2323
[
2424
[1.0, 0.0, 0.0],
2525
[0.0, torch.cos(x), -torch.sin(x)],
2626
[0.0, torch.sin(x), torch.cos(x)],
27-
]
27+
],
28+
device=device,
2829
)
2930

3031
R_Y = torch.tensor(
3132
[
3233
[torch.cos(y), 0.0, torch.sin(y)],
3334
[0.0, 1.0, 0.0],
3435
[-torch.sin(y), 0.0, torch.cos(y)],
35-
]
36+
],
37+
device=device,
3638
)
3739

3840
R_Z = torch.tensor(
3941
[
4042
[torch.cos(z), -torch.sin(z), 0.0],
4143
[torch.sin(z), torch.cos(z), 0.0],
4244
[0.0, 0.0, 1.0],
43-
]
45+
],
46+
device=device,
4447
)
4548

4649
return torch.mm(R_Z, torch.mm(R_Y, R_X))

0 commit comments

Comments
 (0)