55import torch .nn as nn
66import torch .nn .functional as F
77
8- from smart_tree .model .model_blocks import SparseFC , UBlock , SubMConvBlock
8+ from smart_tree .model .model_blocks import MLP , UBlock , SubMConvBlock
99from smart_tree .util .math .maths import torch_normalized
1010
1111spconv .constants .SPCONV_ALLOW_TF32 = True
@@ -30,9 +30,7 @@ def __init__(
3030
3131 self .branch_classes = torch .tensor (branch_classes , device = device )
3232
33- norm_fn = functools .partial (
34- nn .BatchNorm1d , eps = 1e-4 , momentum = 0.1
35- ) # , momentum=0.99)
33+ norm_fn = functools .partial (nn .BatchNorm1d , eps = 1e-4 ) # , momentum=0.99)
3634 activation_fn = nn .ReLU
3735
3836 self .radius_loss = nn .L1Loss ()
@@ -55,24 +53,9 @@ def __init__(
5553 algo = algo ,
5654 )
5755
58- self .radius_head = SparseFC (
59- radius_fc_planes ,
60- norm_fn ,
61- activation_fn ,
62- algo = algo ,
63- )
64- self .direction_head = SparseFC (
65- direction_fc_planes ,
66- norm_fn ,
67- activation_fn ,
68- algo = algo ,
69- )
70- self .class_head = SparseFC (
71- class_fc_planes ,
72- norm_fn ,
73- activation_fn ,
74- algo = algo ,
75- )
56+ self .radius_head = MLP (radius_fc_planes , norm_fn , activation_fn )
57+ self .direction_head = MLP (direction_fc_planes , norm_fn , activation_fn )
58+ self .class_head = MLP (class_fc_planes , norm_fn , activation_fn )
7659
7760 self .apply (self .set_bn_init )
7861
@@ -87,12 +70,12 @@ def forward(self, input):
8770 x = self .input_conv (input )
8871 unet_out = self .UNet (x )
8972
90- radius = self .radius_head (unet_out )
91- direction = self .direction_head (unet_out )
92- class_l = self .class_head (unet_out )
73+ radius = self .radius_head (unet_out ). features
74+ direction = self .direction_head (unet_out ). features
75+ class_l = self .class_head (unet_out ). features
9376
9477 return torch .cat (
95- [radius . features , F . normalize ( direction . features ) , class_l . features ],
78+ [radius , direction , class_l ],
9679 dim = 1 ,
9780 )
9881
@@ -105,7 +88,7 @@ def compute_loss(self, outputs, targets, mask=None):
10588 targets = targets [mask ]
10689
10790 radius_pred = outputs [:, [0 ]]
108- direction_pred = outputs [:, 1 :4 ]
91+ direction_pred = F . normalize ( outputs [:, 1 :4 ])
10992 class_pred = outputs [:, 4 :]
11093
11194 class_target = targets [:, [3 ]]
@@ -128,26 +111,23 @@ def compute_loss(self, outputs, targets, mask=None):
128111
129112 return losses
130113
131- # @force_fp32(apply_to=("outputs", "targets"))
114+ @force_fp32 (apply_to = ("outputs" , "targets" ))
132115 def compute_radius_loss (self , outputs , targets ):
133116 return self .radius_loss (outputs , torch .log (targets ))
134117
135- # @force_fp32(apply_to=("outputs", "targets"))
118+ @force_fp32 (apply_to = ("outputs" , "targets" ))
136119 def compute_direction_loss (self , outputs , targets ):
137120 return torch .mean (1 - self .direction_loss (outputs , targets ))
138121
139- # @force_fp32(apply_to=("outputs", "targets"))
122+ @force_fp32 (apply_to = ("outputs" , "targets" ))
140123 def compute_class_loss (self , outputs , targets ):
141- return self .focal_loss (outputs , targets .long ())
124+ return self .dice_loss (outputs , targets .long ())
142125
143126 def dice_loss (self , outputs , targets ):
144127 # https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08
145128 smooth = 1
146129 outputs = F .softmax (outputs , dim = 1 )
147- targets = F .one_hot (targets )
148-
149- outputs = outputs .view (- 1 )
150- targets = targets .view (- 1 )
130+ targets = F .one_hot (targets ).reshape (- 1 , 1 )
151131
152132 intersection = (outputs * targets ).sum ()
153133