Skip to content

Commit 230f433

Browse files
committed
FP16
1 parent 426ddcd commit 230f433

File tree

9 files changed

+157
-181
lines changed

9 files changed

+157
-181
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,10 @@ wandb/
138138
*.out
139139
*.sl
140140
speed-tree-outputs/
141+
smart_tree/conf/tree-dataset-test.yaml
141142
smart_tree/conf/tree-split-test.json
142143
smart_tree/conf/apple-trellis-split-test.json
143144
smart_tree/conf/apple-trellis.yaml
144145
smart_tree/conf/apple-trellis-split.json
145-
FRNN/
146+
FRNN/
147+
debug/

smart_tree/conf/tree-dataset.yaml

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ training:
77
lr_decay: True
88
early_stop_epoch: 20
99
early_stop: True
10-
use_colour: False
1110

1211
dataset:
1312
_target_: smart_tree.dataset.dataset.TreeDataset
@@ -18,16 +17,14 @@ training:
1817
block_size: 4
1918
buffer_size: 0.4
2019

21-
# augmentation:
22-
# _target_: smart_tree.dataset.augmentations.Scale
23-
# min_scale: 0.9
24-
# max_scale: 1.1
25-
20+
augmentation:
21+
# Scale:
22+
# _target_: smart_tree.dataset.augmentations.Scale
23+
# min_scale: 0.1
24+
# max_scale: 0.1
25+
# Dropout:
2626
# _target_: smart_tree.dataset.augmentations.RandomDropout
2727
# max_drop_out: 0.1
28-
29-
# _target_: smart_tree.dataset.augmentations.RandomColourDropout
30-
# max_drop_out: 0.1
3128

3229
data_loader:
3330
_target_: torch.utils.data.DataLoader
@@ -67,7 +64,6 @@ model_inference:
6764
buffer_size: 0.4
6865
num_workers : 8
6966
batch_size : 4
70-
use_colour: False
7167

7268
skeletonizer:
7369
K: 16
@@ -92,20 +88,19 @@ pipeline:
9288
# VoxelDownsample:
9389
# _target_: smart_tree.dataset.augmentations.VoxelDownsample
9490
# voxel_size : 0.01
95-
96-
# FixedRotate:
97-
# _target_: smart_tree.dataset.augmentations.FixedRotate
98-
# xyz: [0, 0, 90]
91+
#FixedRotate:
92+
# _target_: smart_tree.dataset.augmentations.FixedRotate
93+
# xyz: [0, 0, 90]
9994
# CentreCloud:
100-
# _target_: smart_tree.dataset.augmentations.CentreCloud
95+
# _target_: smart_tree.dataset.augmentations.CentreCloud
10196

10297

10398
repair_skeletons : True
10499
smooth_skeletons : True
105100
prune_skeletons : True
106101
min_skeleton_radius : 0.001 #0.005
107102
min_skeleton_length : 0.02
108-
view_model_output : True
103+
view_model_output : False
109104
view_skeletons : True
110105
save_outputs : False
111106
branch_classes: [0]

smart_tree/data_types/cloud.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def medial_pts(self):
193193
@staticmethod
194194
def from_numpy(xyz, rgb, vector, class_l):
195195
return LabelledCloud(
196-
torch.from_numpy(xyz).float(), # float64 -> these data types are stupid...
196+
torch.from_numpy(xyz).float(), # -> these data types are stupid...
197197
torch.from_numpy(rgb).float(), # float64
198198
torch.from_numpy(vector).float(), # float32
199-
torch.from_numpy(class_l).float(), # int64
199+
torch.from_numpy(class_l).int(), # int64
200200
)

smart_tree/dataset/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __getitem__(self, idx):
103103
if self.blocking:
104104
loss_mask = cube_filter(feats[:, :3], block_center, self.block_size)
105105

106-
return feats.float(), coords.int(), loss_mask
106+
return feats, coords.int(), loss_mask
107107

108108
def __len__(self):
109109
return len(self.tree_paths)

smart_tree/model/model.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn as nn
66
import torch.nn.functional as F
77

8-
from smart_tree.model.model_blocks import SparseFC, UBlock, SubMConvBlock
8+
from smart_tree.model.model_blocks import MLP, UBlock, SubMConvBlock
99
from smart_tree.util.math.maths import torch_normalized
1010

1111
spconv.constants.SPCONV_ALLOW_TF32 = True
@@ -30,9 +30,7 @@ def __init__(
3030

3131
self.branch_classes = torch.tensor(branch_classes, device=device)
3232

33-
norm_fn = functools.partial(
34-
nn.BatchNorm1d, eps=1e-4, momentum=0.1
35-
) # , momentum=0.99)
33+
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4) # , momentum=0.99)
3634
activation_fn = nn.ReLU
3735

3836
self.radius_loss = nn.L1Loss()
@@ -55,24 +53,9 @@ def __init__(
5553
algo=algo,
5654
)
5755

58-
self.radius_head = SparseFC(
59-
radius_fc_planes,
60-
norm_fn,
61-
activation_fn,
62-
algo=algo,
63-
)
64-
self.direction_head = SparseFC(
65-
direction_fc_planes,
66-
norm_fn,
67-
activation_fn,
68-
algo=algo,
69-
)
70-
self.class_head = SparseFC(
71-
class_fc_planes,
72-
norm_fn,
73-
activation_fn,
74-
algo=algo,
75-
)
56+
self.radius_head = MLP(radius_fc_planes, norm_fn, activation_fn)
57+
self.direction_head = MLP(direction_fc_planes, norm_fn, activation_fn)
58+
self.class_head = MLP(class_fc_planes, norm_fn, activation_fn)
7659

7760
self.apply(self.set_bn_init)
7861

@@ -87,12 +70,12 @@ def forward(self, input):
8770
x = self.input_conv(input)
8871
unet_out = self.UNet(x)
8972

90-
radius = self.radius_head(unet_out)
91-
direction = self.direction_head(unet_out)
92-
class_l = self.class_head(unet_out)
73+
radius = self.radius_head(unet_out).features
74+
direction = self.direction_head(unet_out).features
75+
class_l = self.class_head(unet_out).features
9376

9477
return torch.cat(
95-
[radius.features, F.normalize(direction.features), class_l.features],
78+
[radius, direction, class_l],
9679
dim=1,
9780
)
9881

@@ -105,7 +88,7 @@ def compute_loss(self, outputs, targets, mask=None):
10588
targets = targets[mask]
10689

10790
radius_pred = outputs[:, [0]]
108-
direction_pred = outputs[:, 1:4]
91+
direction_pred = F.normalize(outputs[:, 1:4])
10992
class_pred = outputs[:, 4:]
11093

11194
class_target = targets[:, [3]]
@@ -128,26 +111,23 @@ def compute_loss(self, outputs, targets, mask=None):
128111

129112
return losses
130113

131-
# @force_fp32(apply_to=("outputs", "targets"))
114+
@force_fp32(apply_to=("outputs", "targets"))
132115
def compute_radius_loss(self, outputs, targets):
133116
return self.radius_loss(outputs, torch.log(targets))
134117

135-
# @force_fp32(apply_to=("outputs", "targets"))
118+
@force_fp32(apply_to=("outputs", "targets"))
136119
def compute_direction_loss(self, outputs, targets):
137120
return torch.mean(1 - self.direction_loss(outputs, targets))
138121

139-
# @force_fp32(apply_to=("outputs", "targets"))
122+
@force_fp32(apply_to=("outputs", "targets"))
140123
def compute_class_loss(self, outputs, targets):
141-
return self.focal_loss(outputs, targets.long())
124+
return self.dice_loss(outputs, targets.long())
142125

143126
def dice_loss(self, outputs, targets):
144127
# https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08
145128
smooth = 1
146129
outputs = F.softmax(outputs, dim=1)
147-
targets = F.one_hot(targets)
148-
149-
outputs = outputs.view(-1)
150-
targets = targets.view(-1)
130+
targets = F.one_hot(targets).reshape(-1, 1)
151131

152132
intersection = (outputs * targets).sum()
153133

smart_tree/model/model_blocks.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,19 @@ def __init__(
127127
output_channels,
128128
kernel_size=1,
129129
padding=1,
130-
bias=False,
130+
bias=bias,
131131
algo=algo,
132132
)
133133
)
134134

135135
self.sequence = spconv.SparseSequential(
136136
spconv.SubMConv3d(
137-
input_channels, output_channels, kernel_size, bias=False, algo=algo
137+
input_channels, output_channels, kernel_size, bias=bias, algo=algo
138138
),
139139
norm_fn(output_channels),
140140
activation_fn(),
141141
spconv.SubMConv3d(
142-
output_channels, output_channels, kernel_size, bias=False, algo=algo
142+
output_channels, output_channels, kernel_size, bias=bias, algo=algo
143143
),
144144
norm_fn(output_channels),
145145
)
@@ -244,43 +244,23 @@ def forward(self, input):
244244
return output
245245

246246

247-
class SparseFC(nn.Module):
247+
class MLP(nn.Module):
248248
def __init__(
249249
self,
250250
n_planes,
251251
norm_fn,
252252
activation_fn=None,
253-
kernel_size=1,
254-
algo=spconv.ConvAlgo.Native,
255-
bias=False,
256253
):
257254
super().__init__()
258255

259256
self.sequence = spconv.SparseSequential()
257+
260258
for i in range(len(n_planes) - 2):
261-
self.sequence.add(
262-
spconv.SubMConv3d(
263-
n_planes[i],
264-
n_planes[i + 1],
265-
kernel_size=kernel_size,
266-
bias=False,
267-
algo=algo,
268-
padding=0,
269-
)
270-
)
259+
self.sequence.add(nn.Linear(n_planes[i], n_planes[i + 1]))
271260
self.sequence.add(norm_fn(n_planes[i + 1]))
272261
self.sequence.add(activation_fn())
273262

274-
self.sequence.add(
275-
spconv.SubMConv3d(
276-
n_planes[-2],
277-
n_planes[-1],
278-
kernel_size=kernel_size,
279-
bias=False,
280-
algo=algo,
281-
padding=0,
282-
)
283-
)
263+
self.sequence.add(nn.Linear(n_planes[-2], n_planes[-1]))
284264

285265
def forward(self, input):
286266
return self.sequence(input)

smart_tree/model/model_inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from omegaconf import DictConfig, OmegaConf
1010
from py_structs.torch import map_tensors
1111
from torch.utils.data import DataLoader
12+
import torch.nn.functional as F
13+
1214
from tqdm import tqdm
1315

1416
import wandb
@@ -159,7 +161,7 @@ def test():
159161

160162
xyz = xyz[:, :3] # [mask]
161163
radius = out[:, [0]] # [mask]
162-
direction = out[:, 1:4] # [mask]
164+
direction = F.normalize(out[:, 1:4]) # [mask]
163165

164166
new_xyz = xyz + torch.exp(radius) * direction
165167

0 commit comments

Comments
 (0)