Skip to content

Commit ea5bbad

Browse files
author
Donglai Wei
committed
fix hydra-lv.yaml bug: loss function overloading
1 parent 89937ca commit ea5bbad

File tree

9 files changed

+491
-43
lines changed

9 files changed

+491
-43
lines changed

connectomics/config/hydra_config.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ class SystemConfig:
107107
print_auto_plan: bool = True # Print auto-planning results
108108

109109

110+
@dataclass
class LossBalancingConfig:
    """Configuration for adaptive loss weighting.

    A default-constructed instance is embedded in ``ModelConfig`` as the
    ``loss_balancing`` field, alongside the static ``loss_weights`` list.
    This class only declares the options; nothing is validated here, so the
    code that consumes the config is responsible for rejecting unknown values.
    """

    # Balancing algorithm to use.
    # NOTE(review): None presumably disables adaptive balancing so that the
    # static loss_weights apply unchanged — confirm against the consumer.
    strategy: Optional[str] = None  # None, "uncertainty", or "gradnorm"
    # GradNorm hyperparameters.
    # NOTE(review): presumably only read when strategy == "gradnorm"; the
    # exact roles of alpha and lambda are not visible here — confirm.
    gradnorm_alpha: float = 0.5
    gradnorm_lambda: float = 1.0
    # NOTE(review): looks like this selects which shared parameters GradNorm
    # measures gradient norms on — confirm against the implementation.
    gradnorm_parameter_strategy: str = "last"  # "first", "last", or "all"
118+
119+
110120
@dataclass
111121
class ModelConfig:
112122
"""Model architecture configuration.
@@ -199,6 +209,7 @@ class ModelConfig:
199209
loss_functions: List[str] = field(default_factory=lambda: ["DiceLoss", "BCEWithLogitsLoss"])
200210
loss_weights: List[float] = field(default_factory=lambda: [1.0, 1.0])
201211
loss_kwargs: List[dict] = field(default_factory=lambda: [{}, {}]) # Per-loss kwargs
212+
loss_balancing: LossBalancingConfig = field(default_factory=LossBalancingConfig)
202213

203214
# Multi-task learning configuration
204215
# Defines which output channels correspond to which targets
@@ -881,11 +892,18 @@ class TestTimeAugmentationConfig:
881892
"""Test-time augmentation configuration.
882893
883894
Note: Saving predictions is now handled by SavePredictionConfig.
895+
896+
Axis Indexing:
897+
- flip_axes: Uses MONAI Flip spatial_axis indices on the batched tensor, i.e. tensor dim minus one (e.g., [2, 3] for H, W in 5D tensor (B, C, D, H, W), where 1=D, 2=H, 3=W)
898+
- rotation90_axes: Uses spatial-only indices (e.g., [1, 2] for H-W plane where 0=D, 1=H, 2=W)
884899
"""
885900

886901
enabled: bool = False
887902
flip_axes: Any = (
888-
None # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[0], [1], [2]]
903+
None # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[2], [3]] (MONAI Flip spatial_axis indices; for 5D input 1=D, 2=H, 3=W)
904+
)
905+
rotation90_axes: Any = (
906+
None # TTA rotation90 strategy: "all" (3 planes × 4 rotations), null, or list like [[1, 2]] (spatial indices: 0=D, 1=H, 2=W)
889907
)
890908
channel_activations: Optional[List[Any]] = (
891909
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
@@ -1058,9 +1076,11 @@ class InferenceConfig:
10581076
@dataclass
10591077
class TestDataConfig:
10601078
"""Test data configuration."""
1061-
test_image: Optional[str] = None
1062-
test_label: Optional[str] = None
1063-
test_mask: Optional[str] = None
1079+
# These can be strings (single file), lists (multiple files), or None
1080+
# Using Any to support both str and List[str] (OmegaConf doesn't support Union of containers)
1081+
test_image: Any = None # str, List[str], or None
1082+
test_label: Any = None # str, List[str], or None
1083+
test_mask: Any = None # str, List[str], or None
10641084
test_resolution: Optional[List[float]] = None
10651085
test_transpose: Optional[List[int]] = None
10661086
output_path: Optional[str] = None
@@ -1080,9 +1100,11 @@ class TestConfig:
10801100
@dataclass
10811101
class TuneDataConfig:
10821102
"""Tuning data configuration."""
1083-
tune_image: Optional[str] = None
1084-
tune_label: Optional[str] = None
1085-
tune_mask: Optional[str] = None
1103+
# These can be strings (single file), lists (multiple files), or None
1104+
# Using Any to support both str and List[str] (OmegaConf doesn't support Union of containers)
1105+
tune_image: Any = None # str, List[str], or None
1106+
tune_label: Any = None # str, List[str], or None
1107+
tune_mask: Any = None # str, List[str], or None
10861108
tune_resolution: Optional[List[int]] = None
10871109
# Image transformation (applied to tune images during inference)
10881110
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)

connectomics/inference/tta.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
import torch
1414
from monai.transforms import Flip
1515

16+
try:
17+
from omegaconf import ListConfig
18+
HAS_OMEGACONF = True
19+
except ImportError:
20+
HAS_OMEGACONF = False
21+
ListConfig = list # Fallback
22+
1623

1724
class TTAPredictor:
1825
"""Encapsulates TTA preprocessing and flip ensemble logic."""
@@ -126,7 +133,7 @@ def _sliding_window_predict(self, inputs: torch.Tensor) -> torch.Tensor:
126133

127134
def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
128135
"""
129-
Perform test-time augmentation using flips and ensemble predictions.
136+
Perform test-time augmentation using flips, rotations, and ensemble predictions.
130137
131138
Args:
132139
images: Input volume (B, C, D, H, W) or (B, D, H, W) or (D, H, W)
@@ -153,17 +160,24 @@ def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
153160
if getattr(self.cfg.data, "do_2d", False) and images.size(2) == 1:
154161
images = images.squeeze(2)
155162

163+
# Get TTA configuration
156164
if hasattr(self.cfg, "inference") and hasattr(self.cfg.inference, "test_time_augmentation"):
157165
tta_flip_axes_config = getattr(
158166
self.cfg.inference.test_time_augmentation, "flip_axes", None
159167
)
168+
tta_rotation90_axes_config = getattr(
169+
self.cfg.inference.test_time_augmentation, "rotation90_axes", None
170+
)
160171
else:
161172
tta_flip_axes_config = None
173+
tta_rotation90_axes_config = None
162174

163-
if tta_flip_axes_config is None:
175+
# If no augmentation configured, run network once
176+
if tta_flip_axes_config is None and tta_rotation90_axes_config is None:
164177
pred = self._run_network(images)
165178
ensemble_result = self.apply_preprocessing(pred)
166179
else:
180+
# Parse flip axes configuration
167181
if tta_flip_axes_config == "all" or tta_flip_axes_config == []:
168182
if images.dim() == 5:
169183
spatial_axes = [1, 2, 3]
@@ -178,34 +192,127 @@ def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
178192

179193
for combo in combinations(spatial_axes, r):
180194
tta_flip_axes.append(list(combo))
195+
elif HAS_OMEGACONF and isinstance(tta_flip_axes_config, ListConfig):
196+
# OmegaConf ListConfig - convert to regular list
197+
tta_flip_axes_config = [
198+
list(item) if isinstance(item, ListConfig) else item
199+
for item in tta_flip_axes_config
200+
]
201+
tta_flip_axes = [[]] + tta_flip_axes_config
181202
elif isinstance(tta_flip_axes_config, (list, tuple)):
182203
tta_flip_axes = [[]] + list(tta_flip_axes_config)
204+
elif tta_flip_axes_config is None:
205+
tta_flip_axes = [[]] # No flip augmentation
183206
else:
184207
raise ValueError(
185208
f"Invalid tta_flip_axes: {tta_flip_axes_config}. "
186209
f"Expected 'all' (8 flips), null (no aug), or list of flip axes."
187210
)
188211

212+
# Parse rotation90 axes configuration
213+
# NOTE: We use torch.rot90 which expects full tensor axes
214+
# For 5D tensor (B, C, D, H, W): D=2, H=3, W=4
215+
# For 4D tensor (B, C, H, W): H=2, W=3
216+
# Spatial axes from config (0=D, 1=H, 2=W) need to be converted
217+
spatial_offset = 2 # Offset for batch and channel dimensions
218+
219+
if tta_rotation90_axes_config == "all":
220+
if images.dim() == 5:
221+
# For 3D data (B, C, D, H, W), all possible rotation planes
222+
tta_rotation90_axes = [
223+
(2, 3), # D-H plane
224+
(2, 4), # D-W plane
225+
(3, 4), # H-W plane
226+
]
227+
elif images.dim() == 4:
228+
# For 2D data (B, C, H, W), only one rotation plane
229+
tta_rotation90_axes = [(2, 3)] # H-W plane
230+
else:
231+
raise ValueError(f"Unsupported data dimensions: {images.dim()}")
232+
elif HAS_OMEGACONF and isinstance(tta_rotation90_axes_config, ListConfig):
233+
# OmegaConf ListConfig - convert to list and process
234+
tta_rotation90_axes_config = list(tta_rotation90_axes_config)
235+
if len(tta_rotation90_axes_config) > 0:
236+
tta_rotation90_axes = []
237+
for axes in tta_rotation90_axes_config:
238+
if HAS_OMEGACONF and isinstance(axes, ListConfig):
239+
axes = list(axes)
240+
if not isinstance(axes, (list, tuple)) or len(axes) != 2:
241+
raise ValueError(
242+
f"Invalid rotation plane: {axes}. Each plane must be a list/tuple of 2 axes."
243+
)
244+
# Convert spatial axes to full tensor axes
245+
full_axes = tuple(a + spatial_offset for a in axes)
246+
tta_rotation90_axes.append(full_axes)
247+
else:
248+
tta_rotation90_axes = []
249+
elif isinstance(tta_rotation90_axes_config, (list, tuple)) and len(tta_rotation90_axes_config) > 0:
250+
# User-specified rotation planes in spatial indices (0=D, 1=H, 2=W): e.g., [[1, 2], [0, 2]]
251+
# Validate that each entry is a list/tuple of length 2
252+
tta_rotation90_axes = []
253+
for axes in tta_rotation90_axes_config:
254+
if not isinstance(axes, (list, tuple)) or len(axes) != 2:
255+
raise ValueError(
256+
f"Invalid rotation plane: {axes}. Each plane must be a list/tuple of 2 axes."
257+
)
258+
# Convert spatial axes to full tensor axes
259+
full_axes = tuple(a + spatial_offset for a in axes)
260+
tta_rotation90_axes.append(full_axes)
261+
elif tta_rotation90_axes_config is None:
262+
tta_rotation90_axes = [] # No rotation augmentation
263+
else:
264+
raise ValueError(
265+
f"Invalid tta_rotation90_axes: {tta_rotation90_axes_config}. "
266+
f"Expected 'all', null (no rotation), or list of rotation planes like [[1, 2]]."
267+
)
268+
189269
ensemble_mode = getattr(
190270
self.cfg.inference.test_time_augmentation, "ensemble_mode", "mean"
191271
)
192272

193273
ensemble_result = None
194274
num_predictions = 0
195275

276+
# Generate all combinations of (flip_axes, rotation_plane, k_rotations)
277+
# For each rotation plane, we try k=0,1,2,3 (0°, 90°, 180°, 270°)
278+
augmentation_combinations = []
279+
196280
for flip_axes in tta_flip_axes:
197-
if flip_axes:
198-
x_aug = Flip(spatial_axis=flip_axes)(images)
281+
if not tta_rotation90_axes:
282+
# No rotation: just add flip augmentation
283+
augmentation_combinations.append((flip_axes, None, 0))
199284
else:
200-
x_aug = images
285+
# Add all rotation combinations for this flip
286+
for rotation_plane in tta_rotation90_axes:
287+
for k in range(4): # 0, 1, 2, 3 rotations (0°, 90°, 180°, 270°)
288+
augmentation_combinations.append((flip_axes, rotation_plane, k))
201289

290+
# Apply each augmentation combination
291+
for flip_axes, rotation_plane, k_rotations in augmentation_combinations:
292+
x_aug = images
293+
294+
# Apply flip augmentation
295+
if flip_axes:
296+
x_aug = Flip(spatial_axis=flip_axes)(x_aug)
297+
298+
# Apply rotation augmentation using torch.rot90
299+
if rotation_plane is not None and k_rotations > 0:
300+
x_aug = torch.rot90(x_aug, k=k_rotations, dims=rotation_plane)
301+
302+
# Run network
202303
pred = self._run_network(x_aug)
203304

305+
# Reverse rotation augmentation
306+
if rotation_plane is not None and k_rotations > 0:
307+
pred = torch.rot90(pred, k=-k_rotations, dims=rotation_plane)
308+
309+
# Reverse flip augmentation
204310
if flip_axes:
205311
pred = Flip(spatial_axis=flip_axes)(pred)
206312

207313
pred_processed = self.apply_preprocessing(pred)
208314

315+
# Ensemble predictions
209316
if ensemble_result is None:
210317
ensemble_result = pred_processed.clone()
211318
else:

connectomics/models/loss/build.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
WeightedBCEWithLogitsLoss,
2929
WeightedMSELoss,
3030
WeightedMAELoss,
31+
SmoothL1Loss,
3132
GANLoss,
3233
)
3334

@@ -77,6 +78,7 @@ def create_loss(
7778
'CrossEntropyLoss': CrossEntropyLossWrapper, # Use wrapper for shape handling
7879
'MSELoss': nn.MSELoss,
7980
'L1Loss': nn.L1Loss,
81+
'SmoothL1Loss': SmoothL1Loss,
8082

8183
# Custom connectomics losses
8284
'WeightedBCEWithLogitsLoss': WeightedBCEWithLogitsLoss,
@@ -320,4 +322,4 @@ def list_available_losses() -> List[str]:
320322
'create_multiclass_segmentation_loss',
321323
'create_focal_loss',
322324
'list_available_losses',
323-
]
325+
]

connectomics/models/loss/losses.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,41 @@ def forward(
226226
return mae
227227

228228

229+
class SmoothL1Loss(nn.Module):
    """
    Smooth L1 (Huber) loss with optional tanh activation and spatial weighting.

    Useful for distance transform regression where large outliers should be
    down-weighted relative to MSE.

    Args:
        beta: Threshold forwarded to ``F.smooth_l1_loss``.
        reduction: ``'mean'``, ``'sum'``, or ``'none'`` (return the
            per-element loss tensor).
        tanh: If True, ``torch.tanh`` is applied to the prediction before
            the loss is computed.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, beta: float = 1.0, reduction: str = 'mean', tanh: bool = False):
        super().__init__()
        # Fail fast: an unrecognised mode (e.g. a typo such as 'avg') would
        # otherwise silently fall through to the unreduced branch of forward().
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Invalid reduction: {reduction!r}. Expected one of {self._REDUCTIONS}."
            )
        self.beta = beta
        self.reduction = reduction
        self.tanh = tanh

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Compute the (optionally weighted) smooth L1 loss.

        Args:
            pred: Raw network output.
            target: Regression target passed to ``F.smooth_l1_loss`` with ``pred``.
            weight: Optional tensor multiplied element-wise into the
                per-element loss before reduction; skipped when None.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, otherwise the
            per-element loss tensor.
        """
        if self.tanh:
            pred = torch.tanh(pred)

        # Always compute element-wise so the weight can be applied before reducing.
        loss = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none')

        if weight is not None:
            loss = loss * weight

        if self.reduction == 'mean':
            # Mean over all elements; not normalised by the sum of weights.
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
262+
263+
229264
class GANLoss(nn.Module):
230265
"""
231266
GAN loss for adversarial training.
@@ -316,4 +351,4 @@ def forward(
316351
'WeightedMSELoss',
317352
'WeightedMAELoss',
318353
'GANLoss',
319-
]
354+
]

0 commit comments

Comments
 (0)