1313import torch
1414from monai .transforms import Flip
1515
16+ try :
17+ from omegaconf import ListConfig
18+ HAS_OMEGACONF = True
19+ except ImportError :
20+ HAS_OMEGACONF = False
21+ ListConfig = list # Fallback
22+
1623
1724class TTAPredictor :
1825 """Encapsulates TTA preprocessing and flip ensemble logic."""
@@ -126,7 +133,7 @@ def _sliding_window_predict(self, inputs: torch.Tensor) -> torch.Tensor:
126133
127134 def predict (self , images : torch .Tensor , mask : Optional [torch .Tensor ] = None ) -> torch .Tensor :
128135 """
129- Perform test-time augmentation using flips and ensemble predictions.
136+ Perform test-time augmentation using flips, rotations, and ensemble predictions.
130137
131138 Args:
132139 images: Input volume (B, C, D, H, W) or (B, D, H, W) or (D, H, W)
@@ -153,17 +160,24 @@ def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
153160 if getattr (self .cfg .data , "do_2d" , False ) and images .size (2 ) == 1 :
154161 images = images .squeeze (2 )
155162
163+ # Get TTA configuration
156164 if hasattr (self .cfg , "inference" ) and hasattr (self .cfg .inference , "test_time_augmentation" ):
157165 tta_flip_axes_config = getattr (
158166 self .cfg .inference .test_time_augmentation , "flip_axes" , None
159167 )
168+ tta_rotation90_axes_config = getattr (
169+ self .cfg .inference .test_time_augmentation , "rotation90_axes" , None
170+ )
160171 else :
161172 tta_flip_axes_config = None
173+ tta_rotation90_axes_config = None
162174
163- if tta_flip_axes_config is None :
175+ # If no augmentation configured, run network once
176+ if tta_flip_axes_config is None and tta_rotation90_axes_config is None :
164177 pred = self ._run_network (images )
165178 ensemble_result = self .apply_preprocessing (pred )
166179 else :
180+ # Parse flip axes configuration
167181 if tta_flip_axes_config == "all" or tta_flip_axes_config == []:
168182 if images .dim () == 5 :
169183 spatial_axes = [1 , 2 , 3 ]
@@ -178,34 +192,127 @@ def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
178192
179193 for combo in combinations (spatial_axes , r ):
180194 tta_flip_axes .append (list (combo ))
195+ elif HAS_OMEGACONF and isinstance (tta_flip_axes_config , ListConfig ):
196+ # OmegaConf ListConfig - convert to regular list
197+ tta_flip_axes_config = [
198+ list (item ) if isinstance (item , ListConfig ) else item
199+ for item in tta_flip_axes_config
200+ ]
201+ tta_flip_axes = [[]] + tta_flip_axes_config
181202 elif isinstance (tta_flip_axes_config , (list , tuple )):
182203 tta_flip_axes = [[]] + list (tta_flip_axes_config )
204+ elif tta_flip_axes_config is None :
205+ tta_flip_axes = [[]] # No flip augmentation
183206 else :
184207 raise ValueError (
185208 f"Invalid tta_flip_axes: { tta_flip_axes_config } . "
186209 f"Expected 'all' (8 flips), null (no aug), or list of flip axes."
187210 )
188211
212+ # Parse rotation90 axes configuration
213+ # NOTE: We use torch.rot90 which expects full tensor axes
214+ # For 5D tensor (B, C, D, H, W): D=2, H=3, W=4
215+ # For 4D tensor (B, C, H, W): H=2, W=3
216+ # Spatial axes from config (0=D, 1=H, 2=W) need to be converted
217+ spatial_offset = 2 # Offset for batch and channel dimensions
218+
219+ if tta_rotation90_axes_config == "all" :
220+ if images .dim () == 5 :
221+ # For 3D data (B, C, D, H, W), all possible rotation planes
222+ tta_rotation90_axes = [
223+ (2 , 3 ), # D-H plane
224+ (2 , 4 ), # D-W plane
225+ (3 , 4 ), # H-W plane
226+ ]
227+ elif images .dim () == 4 :
228+ # For 2D data (B, C, H, W), only one rotation plane
229+ tta_rotation90_axes = [(2 , 3 )] # H-W plane
230+ else :
231+ raise ValueError (f"Unsupported data dimensions: { images .dim ()} " )
232+ elif HAS_OMEGACONF and isinstance (tta_rotation90_axes_config , ListConfig ):
233+ # OmegaConf ListConfig - convert to list and process
234+ tta_rotation90_axes_config = list (tta_rotation90_axes_config )
235+ if len (tta_rotation90_axes_config ) > 0 :
236+ tta_rotation90_axes = []
237+ for axes in tta_rotation90_axes_config :
238+ if HAS_OMEGACONF and isinstance (axes , ListConfig ):
239+ axes = list (axes )
240+ if not isinstance (axes , (list , tuple )) or len (axes ) != 2 :
241+ raise ValueError (
242+ f"Invalid rotation plane: { axes } . Each plane must be a list/tuple of 2 axes."
243+ )
244+ # Convert spatial axes to full tensor axes
245+ full_axes = tuple (a + spatial_offset for a in axes )
246+ tta_rotation90_axes .append (full_axes )
247+ else :
248+ tta_rotation90_axes = []
249+ elif isinstance (tta_rotation90_axes_config , (list , tuple )) and len (tta_rotation90_axes_config ) > 0 :
250+ # User-specified rotation planes: e.g., [[1, 2], [2, 3]]
251+ # Validate that each entry is a list/tuple of length 2
252+ tta_rotation90_axes = []
253+ for axes in tta_rotation90_axes_config :
254+ if not isinstance (axes , (list , tuple )) or len (axes ) != 2 :
255+ raise ValueError (
256+ f"Invalid rotation plane: { axes } . Each plane must be a list/tuple of 2 axes."
257+ )
258+ # Convert spatial axes to full tensor axes
259+ full_axes = tuple (a + spatial_offset for a in axes )
260+ tta_rotation90_axes .append (full_axes )
261+ elif tta_rotation90_axes_config is None :
262+ tta_rotation90_axes = [] # No rotation augmentation
263+ else :
264+ raise ValueError (
265+ f"Invalid tta_rotation90_axes: { tta_rotation90_axes_config } . "
266+ f"Expected 'all', null (no rotation), or list of rotation planes like [[1, 2]]."
267+ )
268+
189269 ensemble_mode = getattr (
190270 self .cfg .inference .test_time_augmentation , "ensemble_mode" , "mean"
191271 )
192272
193273 ensemble_result = None
194274 num_predictions = 0
195275
276+ # Generate all combinations of (flip_axes, rotation_plane, k_rotations)
277+ # For each rotation plane, we try k=0,1,2,3 (0°, 90°, 180°, 270°)
278+ augmentation_combinations = []
279+
196280 for flip_axes in tta_flip_axes :
197- if flip_axes :
198- x_aug = Flip (spatial_axis = flip_axes )(images )
281+ if not tta_rotation90_axes :
282+ # No rotation: just add flip augmentation
283+ augmentation_combinations .append ((flip_axes , None , 0 ))
199284 else :
200- x_aug = images
285+ # Add all rotation combinations for this flip
286+ for rotation_plane in tta_rotation90_axes :
287+ for k in range (4 ): # 0, 1, 2, 3 rotations (0°, 90°, 180°, 270°)
288+ augmentation_combinations .append ((flip_axes , rotation_plane , k ))
201289
290+ # Apply each augmentation combination
291+ for flip_axes , rotation_plane , k_rotations in augmentation_combinations :
292+ x_aug = images
293+
294+ # Apply flip augmentation
295+ if flip_axes :
296+ x_aug = Flip (spatial_axis = flip_axes )(x_aug )
297+
298+ # Apply rotation augmentation using torch.rot90
299+ if rotation_plane is not None and k_rotations > 0 :
300+ x_aug = torch .rot90 (x_aug , k = k_rotations , dims = rotation_plane )
301+
302+ # Run network
202303 pred = self ._run_network (x_aug )
203304
305+ # Reverse rotation augmentation
306+ if rotation_plane is not None and k_rotations > 0 :
307+ pred = torch .rot90 (pred , k = - k_rotations , dims = rotation_plane )
308+
309+ # Reverse flip augmentation
204310 if flip_axes :
205311 pred = Flip (spatial_axis = flip_axes )(pred )
206312
207313 pred_processed = self .apply_preprocessing (pred )
208314
315+ # Ensemble predictions
209316 if ensemble_result is None :
210317 ensemble_result = pred_processed .clone ()
211318 else :
0 commit comments