From 7f723086590b6ed25097cb0c3c15fca68ff4393e Mon Sep 17 00:00:00 2001
From: iback
Date: Wed, 14 Jan 2026 10:48:28 +0000
Subject: [PATCH 1/7] fixed _apply_region_mode logic by cloning the
 segmentation so as not to change it, using argmax instead of sum, and
 inverting in the correct order to avoid multiplying the effect of the
 transformation out of proportion

---
 auglab/transforms/gpu/contrast.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/auglab/transforms/gpu/contrast.py b/auglab/transforms/gpu/contrast.py
index 4638a1b..86b4ce4 100644
--- a/auglab/transforms/gpu/contrast.py
+++ b/auglab/transforms/gpu/contrast.py
@@ -40,7 +40,7 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
     """
     if seg_mask is None or mode == 'all':
         return transformed
-    
+
     # Rescale transformed based on min max orig
     # Needed due to the important change in the image
     if orig.dim() == 4:
@@ -51,17 +51,19 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
         transformed_max = torch.amax(transformed, dim=tuple(range(1, transformed.dim())), keepdim=True)
         transformed = (transformed - transformed_min) / (transformed_max - transformed_min + 1e-8) * (orig_max - orig_min) + orig_min
 
-        m = seg_mask.to(transformed.dtype)
-        if mode == 'out':
-            m = 1.0 - m
+        m = seg_mask.to(transformed.dtype).clone()
         if mix_in_out:
             for i in range(seg_mask.shape[0]):
                 # Create a tensor with random one and zero
-                
+
                 o = torch.randint(0, 2, (seg_mask.shape[1],), device=seg_mask.device, dtype=seg_mask.dtype)
                 m[i] = m[i] * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-        
-        m = torch.sum(m, axis=1)
+
+        m = torch.argmax(m, axis=1) > 0
+        m = m.to(transformed.dtype)
+        if mode == "out":
+            m = 1.0 - m
+
     elif orig.dim() == 3:
@@ -69,15 +71,17 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
         if normalize:
             orig_min = torch.amin(orig)
             orig_max = torch.amax(orig)
             transformed_min = torch.amin(transformed)
             transformed_max = torch.amax(transformed)
             transformed = (transformed - transformed_min) / (transformed_max - transformed_min + 1e-8) * (orig_max - orig_min) + orig_min
-        
-        m = seg_mask.to(transformed.dtype)
-        if mode == 'out':
-            m = 1.0 - m
+
+        m = seg_mask.to(transformed.dtype).clone()
         if mix_in_out:
             # Create a tensor with random one and zero
             o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=seg_mask.dtype)
             m = m * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-        m = torch.sum(m, axis=0)
+        m = torch.argmax(m, axis=0) > 0
+        m = m.to(transformed.dtype)
+        if mode == "out":
+            m = 1.0 - m
+
     else:
         raise ValueError(f"Only 4D and 3D images are supported. Got {orig.dim()}D.")
Got {orig.dim()}D.") From 95a2a7c1fdbc354b2ea466cc024747649e6626f7 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 14 Jan 2026 14:54:17 +0000 Subject: [PATCH 2/7] fixed ordering of mix_prob and added it to inverse and histogram equalization --- auglab/transforms/gpu/contrast.py | 51 +++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/auglab/transforms/gpu/contrast.py b/auglab/transforms/gpu/contrast.py index 86b4ce4..570162b 100644 --- a/auglab/transforms/gpu/contrast.py +++ b/auglab/transforms/gpu/contrast.py @@ -29,7 +29,15 @@ def _choose_region_mode(p_in: float, p_out: float, seg_mask: Optional[torch.Tens return 'out' return 'all' -def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask: Optional[torch.Tensor], mode: str, normalize: bool = True, mix_in_out: bool = False) -> torch.Tensor: + +def _apply_region_mode( + orig: torch.Tensor, + transformed: torch.Tensor, + seg_mask: Optional[torch.Tensor], + mode: str, + normalize: bool = False, + mix_in_out: bool = False, +) -> torch.Tensor: """Blend transformed with orig based on region selection mode. - mode 'all': return transformed @@ -87,6 +95,7 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask: return m * transformed + (1.0 - m) * orig + ## Convolution transform class RandomConvTransformGPU(ImageOnlyTransform): """Apply convolution to image. @@ -195,7 +204,7 @@ def apply_transform( ) -> Tensor: # Initialize kernel kernel = self.get_kernel(device=input.device) - + # Load segmentation seg_mask = params.get('seg', None) @@ -238,6 +247,11 @@ def apply_transform( x = torch.stack(out, dim=0) + # Mix with original based on mix_prob + if torch.rand(1).item() < self.mix_prob: + alpha = torch.rand(1, device=input.device) + x = alpha * orig + (1 - alpha) * x + if self.retain_stats: # Adjust mean and std to match original eps = 1e-8 @@ -251,23 +265,18 @@ def apply_transform( om = orig_means.view(shape) os = orig_stds.view(shape) x = (x - nm) / (ns + eps) * os + om - + # Apply region selection if not seg_mask is None: region_mode = _choose_region_mode(self.in_seg, self.out_seg, seg_mask) x = _apply_region_mode(orig, x, seg_mask, region_mode, mix_in_out=self.mix_in_out) - - # Mix with original based on mix_prob - if torch.rand(1).item() < self.mix_prob: - alpha = torch.rand(1, device=input.device) - x = alpha * orig + (1 - alpha) * x - + # Final safety: check if nan/inf appeared if torch.isnan(x).any() or torch.isinf(x).any(): print(f"Warning nan: {self.__class__.__name__} with kernel={self.kernel_type}", flush=True) continue input[:, c] = x - + return input @@ -794,6 +803,7 @@ def __init__( in_seg: float = 0.0, out_seg: float = 0.0, mix_in_out: bool = False, + mix_prob: float = 0.0, p: float = 1.0, keepdim: bool = True, **kwargs, @@ -803,6 +813,7 @@ def __init__( self.retain_stats = retain_stats self.in_seg = in_seg self.out_seg = out_seg + self.mix_prob = mix_prob self.mix_in_out = mix_in_out @torch.no_grad() # disable gradients for efficiency @@ -821,12 +832,19 @@ def apply_transform( orig_stds = x.std() max_val = x.max() x = max_val - x + if self.retain_stats: # Adjust mean and std to match original eps = 1e-8 new_mean = x.mean() # scalar new_std = x.std() # scalar x = (x - new_mean) / (new_std + eps) * orig_stds + orig_means + + # Mix with original based on mix_prob + if torch.rand(1).item() < self.mix_prob: + alpha = torch.rand(1, device=input.device) + x = alpha * orig + (1 - alpha) * x + if not seg_mask is None: region_mode = 
                    x = _apply_region_mode(orig, x, seg_mask[i], region_mode, mix_in_out=self.mix_in_out)
@@ -835,7 +853,7 @@ def apply_transform(
                     print(f"Warning nan: {self.__class__.__name__}", flush=True)
                     continue
                 input[i, c] = x
-        
+
         return input
@@ -863,6 +881,7 @@ def __init__(
         in_seg: float = 0.0,
         out_seg: float = 0.0,
         mix_in_out: bool = False,
+        mix_prob: float = 0.0,
         p: float = 1.0,
         keepdim: bool = True,
         **kwargs,
@@ -873,6 +892,7 @@ def __init__(
         self.in_seg = in_seg
         self.out_seg = out_seg
         self.mix_in_out = mix_in_out
+        self.mix_prob = mix_prob
 
     @torch.no_grad()  # disable gradients for efficiency
     def apply_transform(
@@ -884,7 +904,7 @@ def apply_transform(
         for c in self.apply_to_channel:
             channel_data = input[:, c]  # shape [N, ...spatial...]
             orig = channel_data.clone()
-            
+
             if self.retain_stats:
                 reduce_dims = tuple(range(1, channel_data.dim()))
                 # store per-sample mean/std (shape [N])
@@ -917,6 +937,11 @@ def apply_transform(
                 img_eq = cdf[indices]
                 channel_data[b] = img_eq.reshape(img_b.shape)
 
+                # Mix with original based on mix_prob
+                if torch.rand(1).item() < self.mix_prob:
+                    alpha = torch.rand(1, device=input.device)
+                    channel_data[b] = alpha * orig[b] + (1 - alpha) * channel_data[b]
+
             if self.retain_stats:
                 # Adjust mean and std to match original
                 eps = 1e-8
@@ -930,7 +955,7 @@ def apply_transform(
                 om = orig_means.view(shape)
                 os = orig_stds.view(shape)
                 channel_data = (channel_data - nm) / (ns + eps) * os + om
-            
+
             if not seg_mask is None:
                 region_mode = _choose_region_mode(self.in_seg, self.out_seg, seg_mask)
                 channel_data = _apply_region_mode(orig, channel_data, seg_mask, region_mode, mix_in_out=self.mix_in_out)

From d72198020a2d15918e1f8bbd532e2a73fdc9de63 Mon Sep 17 00:00:00 2001
From: iback
Date: Wed, 14 Jan 2026 14:54:44 +0000
Subject: [PATCH 3/7] embedded new arguments

---
 auglab/transforms/gpu/transforms.py | 56 ++++++++++++++++-------------
 1 file changed, 31 insertions(+), 25 deletions(-)

diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py
index 79eac7f..e150264 100644
--- a/auglab/transforms/gpu/transforms.py
+++ b/auglab/transforms/gpu/transforms.py
@@ -19,7 +19,7 @@ def __init__(self, json_path: str):
         config_path = os.path.join(json_path)
         with open(config_path, 'r') as f:
             config = json.load(f)
-        
+
         if 'GPU' in config.keys():
             self.transform_params = config['GPU']
         else:
@@ -52,30 +52,36 @@ def _build_transforms(self) -> list[nn.Module]:
                 resample=affine_params.get('resample', "bilinear"),
                 p=affine_params.get('probability', 0)
             ))
-        
+
         ## Transfer augmentations (TA)
         # Inverse transform (max - pixel_value)
         inverse_params = self.transform_params.get('InverseTransform')
         if inverse_params is not None:
-            transforms.append(RandomInverseGPU(
-                p=inverse_params.get('probability', 0),
-                in_seg=inverse_params.get('in_seg', 0.0),
-                out_seg=inverse_params.get('out_seg', 0.0),
-                mix_in_out=inverse_params.get('mix_in_out', False),
-                retain_stats=inverse_params.get('retain_stats', False),
-            ))
-        
+            transforms.append(
+                RandomInverseGPU(
+                    p=inverse_params.get("probability", 0),
+                    in_seg=inverse_params.get("in_seg", 0.0),
+                    out_seg=inverse_params.get("out_seg", 0.0),
+                    mix_in_out=inverse_params.get("mix_in_out", False),
+                    mix_prob=inverse_params.get("mix_prob", 0.0),
+                    retain_stats=inverse_params.get("retain_stats", False),
+                )
+            )
+
         # Histogram manipulations
         histo_params = self.transform_params.get('HistogramEqualizationTransform')
         if histo_params is not None:
-            transforms.append(RandomHistogramEqualizationGPU(
-                p=histo_params.get('probability', 0),
-                in_seg=histo_params.get('in_seg', 0.0),
-                out_seg=histo_params.get('out_seg', 0.0),
-                mix_in_out=histo_params.get('mix_in_out', False),
-                retain_stats=histo_params.get('retain_stats', False),
-            ))
-        
+            transforms.append(
+                RandomHistogramEqualizationGPU(
+                    p=histo_params.get("probability", 0),
+                    in_seg=histo_params.get("in_seg", 0.0),
+                    out_seg=histo_params.get("out_seg", 0.0),
+                    mix_in_out=histo_params.get("mix_in_out", False),
+                    mix_prob=histo_params.get("mix_prob", 0.0),
+                    retain_stats=histo_params.get("retain_stats", False),
+                )
+            )
+
         # Redistribute segmentation values transform
         redistribute_params = self.transform_params.get('RedistributeSegTransform')
         if redistribute_params is not None:
@@ -112,7 +118,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 unsharp_amount=unsharp_params.get('unsharp_amount', 1.5),
                 mix_prob=unsharp_params.get('mix_prob', 0.0),
             ))
-        
+
         # RandomConv transform
         randconv_params = self.transform_params.get('RandomConvTransform')
         if randconv_params is not None:
@@ -126,7 +132,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 kernel_sizes=randconv_params.get('kernel_sizes', [1,3,5,7]),
                 mix_prob=randconv_params.get('mix_prob', 0.0),
             ))
-        
+
         ## General enhancement (GE)
         # Clamping transform
         clamp_params = self.transform_params.get('ClampTransform')
@@ -151,7 +157,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 mix_in_out=noise_params.get('mix_in_out', False),
                 p=noise_params.get('probability', 0),
             ))
-        
+
         # Gaussian blur
         gaussianblur_params = self.transform_params.get('GaussianBlurTransform')
         if gaussianblur_params is not None:
@@ -199,7 +205,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 invert_image=True,
                 retain_stats=inv_gamma_params.get('retain_stats', False),
             ))
-        
+
         # nnUNetV2 Contrast transforms
         contrast_params = self.transform_params.get('ContrastTransform')
         if contrast_params is not None:
@@ -251,7 +257,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 one_dim=True,
                 same_on_batch=acq_params.get('same_on_batch', False)
             ))
-        
+
         # Bias field artifact
         bias_field_params = self.transform_params.get('BiasFieldTransform')
         if bias_field_params is not None:
@@ -263,7 +269,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 retain_stats=bias_field_params.get('retain_stats', False),
                 coefficients=bias_field_params.get('coefficients', 0.5),
             ))
-        
+
         ## Random Z-score normalization
         zscore_params = self.transform_params.get('ZscoreNormalizationTransform')
         if zscore_params is not None:
@@ -412,4 +418,4 @@ def pad_numpy_array(arr, shape):
 
     combined_img2 = np.concatenate([img_line2, seg_line2, augmented_img_line2, augmented_seg_line2, not_augmented_channel_line2], axis=0)
     cv2.imwrite('img/combined2.png', combined_img2*255)
-    print(augmentor)
\ No newline at end of file
+    print(augmentor)

From 4fa9840b9c7fcd873c4c4d13d5dd09f68f29b5f7 Mon Sep 17 00:00:00 2001
From: iback
Date: Wed, 14 Jan 2026 14:54:54 +0000
Subject: [PATCH 4/7] updated config file to the newest observations

---
 auglab/configs/transform_params_gpu.json | 114 +++++++++++------------
 1 file changed, 53 insertions(+), 61 deletions(-)

diff --git a/auglab/configs/transform_params_gpu.json b/auglab/configs/transform_params_gpu.json
index 739636b..5c6ea2e 100644
--- a/auglab/configs/transform_params_gpu.json
+++ b/auglab/configs/transform_params_gpu.json
@@ -3,9 +3,9 @@
     "kernel_type": "Scharr",
     "absolute": true,
     "retain_stats": true,
-    "mix_prob": 0.90,
-    "in_seg": 0.5,
-    "out_seg": 0.5,
+    "mix_prob": 0.50,
+    "in_seg": 0.1,
+    "out_seg": 0.1,
     "mix_in_out": true,
"probability": 0.15 }, @@ -14,18 +14,18 @@ "sigma": 1.0, "retain_stats": false, "mix_prob": 0.50, - "in_seg": 0.5, - "out_seg": 0.5, - "mix_in_out": false, - "probability": 0.20 + "in_seg": 0.1, + "out_seg": 0.1, + "mix_in_out": true, + "probability": 0.15 }, "UnsharpMaskTransform": { "kernel_type": "UnsharpMask", "sigma": 1.0, "unsharp_amount": 1.5, "retain_stats": false, - "in_seg": 0.5, - "out_seg": 0.5, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, "mix_prob": 0.50, "probability": 0.10 @@ -34,83 +34,85 @@ "kernel_type": "RandConv", "kernel_sizes": [3,5,7], "retain_stats": true, - "in_seg": 0.5, - "out_seg": 0.5, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, "mix_prob": 0.50, "probability": 0.10 }, "RedistributeSegTransform": { - "in_seg": 0.2, + "in_seg": 0.25, "retain_stats": true, - "probability": 0.5 + "probability": 0.1 }, "GaussianNoiseTransform": { "mean": 0.0, - "std": 0.1, + "std": 1.0, "in_seg": 0.5, "out_seg": 0.5, "mix_in_out": true, - "probability": 0.10 + "probability": 0.15 }, "ClampTransform": { "max_clamp_amount": 0.2, - "retain_stats": false, - "in_seg": 0.5, - "out_seg": 0.5, + "retain_stats": true, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, - "probability": 0.40 + "probability": 0.05 }, "BrightnessTransform": { "brightness_range": [0.75, 1.25], - "in_seg": 0.0, - "out_seg": 0.0, - "mix_in_out": false, - "probability": 0.15 + "in_seg": 0.1, + "out_seg": 0.1, + "mix_in_out": true, + "probability": 0.5 }, "GammaTransform": { "gamma_range": [0.7, 1.5], "retain_stats": true, - "in_seg": 0.5, - "out_seg": 0.5, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, - "probability": 0.30 + "probability": 0.50 }, "InvGammaTransform": { "gamma_range": [0.7, 1.5], "retain_stats": true, - "in_seg": 0.5, - "out_seg": 0.5, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, - "probability": 0.10 + "probability": 0.30 }, "ContrastTransform": { "contrast_range": [0.75, 1.25], "retain_stats": false, - "in_seg": 0.0, - "out_seg": 0.0, - "mix_in_out": false, - "probability": 0.15 + "in_seg": 0.1, + "out_seg": 0.1, + "mix_in_out": true, + "probability": 0.5 }, "FunctionTransform": { "retain_stats": true, - "in_seg": 0.5, - "out_seg": 0.5, - "mix_in_out": true, - "probability": 0.05 + "in_seg": 0, + "out_seg": 0, + "mix_in_out": false, + "probability": 0.025 }, "InverseTransform": { "retain_stats": true, - "in_seg": 0.5, - "out_seg": 0.5, + "in_seg": 0.1, + "out_seg": 0.1, "mix_in_out": true, - "probability": 0.05 + "mix_prob": 0.50, + "probability": 0.1 }, "HistogramEqualizationTransform": { - "retain_stats": false, - "in_seg": 0.5, - "out_seg": 0.5, + "retain_stats": true, + "in_seg": 0, + "out_seg": 0, "mix_in_out": true, + "mix_prob": 0.80, "probability": 0.10 }, "SimulateLowResTransform": { @@ -123,14 +125,14 @@ "scale": [0.6, 1.0], "crop": [1.0, 1.0], "same_on_batch": false, - "probability": 0.40 + "probability": 0.05 }, "BiasFieldTransform": { - "retain_stats": false, - "coefficients": 0.5, - "in_seg": 0.5, - "out_seg": 0.5, - "mix_in_out": false, + "retain_stats": true, + "coefficients": 0.2, + "in_seg": 0.1, + "out_seg": 0.1, + "mix_in_out": true, "probability": 0.10 }, "FlipTransform": { @@ -145,19 +147,9 @@ "scale": [0.7, 1.4], "shear": [-5, 5, -5, 5, -5, 5], "resample": "bilinear", - "probability": 0 - }, - "nnUNetSpatialTransform": { - "patch_center_dist_from_border": 80, - "random_crop": true, - "p_elastic_deform": 0.2, - "p_rotation": 0.5, - "p_scaling": 0.5, - "scaling": [0.7, 1.4], - "p_synchronize_scaling_across_axes": 1, - 
"bg_style_seg_sampling": false + "probability": 0.5 }, "ZscoreNormalizationTransform": { - "probability": 0.3 + "probability": 0.0 } } \ No newline at end of file From 20c93e8f5ecc96e6106420a3fc36eae1f6e6658f Mon Sep 17 00:00:00 2001 From: Nathan Molinier Date: Thu, 15 Jan 2026 16:45:12 -0500 Subject: [PATCH 5/7] add RandomChooseXTransformsGPU --- auglab/transforms/gpu/transforms.py | 78 +++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py index e150264..e18ef55 100644 --- a/auglab/transforms/gpu/transforms.py +++ b/auglab/transforms/gpu/transforms.py @@ -4,6 +4,10 @@ import torch import numpy as np +from auglab.transforms.gpu.base import ImageOnlyTransform +from typing import Any, Dict, Optional, Tuple, Union, List +from kornia.core import Tensor + from auglab.transforms.gpu.contrast import RandomConvTransformGPU, RandomGaussianNoiseGPU, RandomBrightnessGPU, RandomGammaGPU, RandomFunctionGPU, \ RandomHistogramEqualizationGPU, RandomInverseGPU, RandomBiasFieldGPU, RandomContrastGPU, ZscoreNormalizationGPU, RandomClampGPU from auglab.transforms.gpu.spatial import RandomAffine3DCustom, RandomLowResTransformGPU, RandomFlipTransformGPU, RandomAcqTransformGPU @@ -279,6 +283,80 @@ def _build_transforms(self) -> list[nn.Module]: return transforms +class RandomChooseXTransformsGPU(ImageOnlyTransform): + """Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version). + + Args: + transforms_list: List of initialized ImageOnlyTransform to choose from. + num_transforms: Number of transforms to randomly select and apply. + same_on_batch: apply the same transformation across the batch. + p: probability for applying the X transforms to a batch. This param controls the augmentation + probabilities batch-wise. + keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch + form ``False``. + + """ + + def __init__( + self, + transforms_list: List[ImageOnlyTransform], + num_transforms: int = 1, + same_on_batch: bool = False, + p: float = 1.0, + keepdim: bool = True, + **kwargs, + ) -> None: + super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim) + if not isinstance(num_transforms, int) or num_transforms < 0: + raise ValueError(f"num_transforms must be a non-negative int. Got {num_transforms!r}.") + self.transforms_list = nn.ModuleList(transforms_list) + self.num_transforms = num_transforms + + def _apply_mix(self, x: Tensor, seg: Optional[Tensor]) -> Tensor: + if self.num_transforms == 0 or len(self.transforms_list) == 0: + return x + + k = min(self.num_transforms, len(self.transforms_list)) + # sample without replacement + idx = torch.randperm(len(self.transforms_list), device=x.device)[:k] + + child_params: Dict[str, Tensor] = {} + if seg is not None: + child_params["seg"] = seg + + for j in idx.tolist(): + t = self.transforms_list[j] + if not hasattr(t, "apply_transform"): + raise TypeError( + f"All transforms must implement apply_transform like ImageOnlyTransform. Got {type(t)}" + ) + # Most contrast transforms perform their random sampling inside apply_transform. 
+            t_flags = getattr(t, "flags", {})
+            x = t.apply_transform(x, child_params, t_flags, transform=None)
+        return x
+
+    @torch.no_grad()  # disable gradients for efficiency
+    def apply_transform(
+        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
+    ) -> Tensor:
+        seg = params.get("seg", None)
+
+        if self.same_on_batch:
+            return self._apply_mix(input, seg)
+
+        batch_size = input.shape[0]
+        out = input
+        for i in range(batch_size):
+            xi = out[i : i + 1]
+            seg_i = None
+            if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] == batch_size:
+                seg_i = seg[i : i + 1]
+            else:
+                seg_i = seg
+            xi = self._apply_mix(xi, seg_i)
+            out[i : i + 1] = xi
+        return out
+
 def normalize(arr: np.ndarray) -> np.ndarray:
     """
     Normalize a tensor to the range [0, 1].

From 859c229a06e77f46bb30362d8123d4b630db2e11 Mon Sep 17 00:00:00 2001
From: iback
Date: Wed, 21 Jan 2026 13:14:07 +0000
Subject: [PATCH 6/7] removed ZscoreNormalization as the final training and
 validation transform

---
 auglab/trainers/nnUNetTrainerDAExt.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/auglab/trainers/nnUNetTrainerDAExt.py b/auglab/trainers/nnUNetTrainerDAExt.py
index 2d224f3..536ce22 100644
--- a/auglab/trainers/nnUNetTrainerDAExt.py
+++ b/auglab/trainers/nnUNetTrainerDAExt.py
@@ -210,7 +210,7 @@ def get_training_transforms(
         else:
             patch_size_spatial = patch_size
             ignore_axes = None
-        
+
         if 'nnUNetSpatialTransform' in config:
             spatial_params = config['nnUNetSpatialTransform']
         else:
@@ -234,7 +234,7 @@ def get_training_transforms(
 
         if do_dummy_2d_data_aug:
             transforms.append(Convert2DTo3DTransform())
-        
+
         if use_mask_for_norm is not None and any(use_mask_for_norm):
             transforms.append(MaskImageTransform(
                 apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
@@ -284,8 +284,8 @@ def get_training_transforms(
                     channel_in_seg=0
                 )
             )
-        
-        transforms.append(ZscoreNormalization())
+
+        # transforms.append(ZscoreNormalization())
 
         # NOTE: DownsampleSegForDSTransform is now handled in train_step for GPU augmentations
         # if deep_supervision_scales is not None:
@@ -323,13 +323,13 @@ def get_validation_transforms(
                     channel_in_seg=0
                 )
             )
-        
-        transforms.append(ZscoreNormalization())
+
+        # transforms.append(ZscoreNormalization())
 
         if deep_supervision_scales is not None:
             transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
         return ComposeTransforms(transforms)
-    
+
     def train_step(self, batch: dict) -> dict:
         data = batch['data']
         target = batch['target']
@@ -350,13 +350,13 @@ def train_step(self, batch: dict) -> dict:
         with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
             # Apply GPU augmentations to full-resolution data/target
             data, target = self.transforms(data, target)
-            
+
             # Create multi-scale targets for deep supervision after augmentation
             deep_supervision_scales = self._get_deep_supervision_scales()
             if deep_supervision_scales is not None:
                 ds_transform = DownsampleSegForDSTransformCustom(ds_scales=deep_supervision_scales)
                 target = ds_transform(target)
-            
+
             output = self.network(data)
             # del data
             l = self.loss(output, target)
@@ -519,4 +519,3 @@ def train_step(self, batch: dict) -> dict:
             torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
             self.optimizer.step()
         return {'loss': l.detach().cpu().numpy()}
-

From cd4330b833bc8ad20391dbe72c3a67f736f928a7 Mon Sep 17 00:00:00 2001
From: Nathan Molinier
Date: Wed, 21 Jan 2026 11:09:55 -0500
Subject: [PATCH 7/7] Account for transform probability in
 RandomChooseXTransformsGPU
---
 auglab/transforms/gpu/transforms.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py
index e18ef55..2ec84fd 100644
--- a/auglab/transforms/gpu/transforms.py
+++ b/auglab/transforms/gpu/transforms.py
@@ -326,6 +326,8 @@ def _apply_mix(self, x: Tensor, seg: Optional[Tensor]) -> Tensor:
 
         for j in idx.tolist():
             t = self.transforms_list[j]
+            if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
+                continue
             if not hasattr(t, "apply_transform"):
                 raise TypeError(
                     f"All transforms must implement apply_transform like ImageOnlyTransform. Got {type(t)}"
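
Note on patches 1 and 2: the net effect of the _apply_region_mode fix can be summarized in a standalone sketch. The function name, the (L, X, Y, Z) mask layout, and the background-at-channel-0 convention below are illustrative assumptions, not code from the series.

import torch

def blend_by_region(orig, transformed, seg_mask, mode="in", mix_in_out=False):
    # Sketch of the patched masking path for a (X, Y, Z) image with a
    # (L, X, Y, Z) label mask; channel 0 is assumed to be background.
    # Clone first so the caller's segmentation is never modified in place.
    m = seg_mask.to(transformed.dtype).clone()
    if mix_in_out:
        # Randomly keep or drop whole label channels.
        o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=m.dtype)
        m = m * o.view(-1, 1, 1, 1)
    # argmax over the label axis instead of sum: the collapsed mask stays
    # binary even where label channels overlap, so the blend weights cannot
    # exceed 1 and amplify the transformation's effect out of proportion.
    m = (torch.argmax(m, dim=0) > 0).to(transformed.dtype)
    # Invert only after collapsing, so 'out' is the exact complement of 'in'.
    if mode == "out":
        m = 1.0 - m
    return m * transformed + (1.0 - m) * orig

img = torch.rand(16, 16, 16)
seg = torch.zeros(2, 16, 16, 16)
seg[1, 4:12] = 1.0
blended = blend_by_region(img, 1.0 - img, seg, mode="in")

Inverting before the collapse, as the pre-patch code did, would have handed an almost-everywhere-one mask to the sum/argmax step and broken the complement property that mode='out' relies on.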
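Note on patches 5 and 7: RandomChooseXTransformsGPU samples k indices without replacement and, after patch 7, gates every pick by that transform's own probability. A minimal self-contained sketch of just that selection logic (the stub class and function name are hypothetical; only the randperm selection and the t.p gate mirror the series):

import torch

class StubTransform:
    # Hypothetical stand-in exposing only what the selection loop relies on.
    def __init__(self, p):
        self.p = p
        self.calls = 0

    def apply_transform(self, x, params, flags, transform=None):
        self.calls += 1  # identity transform; we only count invocations
        return x

def choose_and_apply(x, transforms_list, num_transforms):
    k = min(num_transforms, len(transforms_list))
    # Sample without replacement, as in _apply_mix.
    idx = torch.randperm(len(transforms_list), device=x.device)[:k]
    for j in idx.tolist():
        t = transforms_list[j]
        if torch.rand(1, device=x.device) > t.p:
            continue  # selected but skipped: per-transform probability kept
        x = t.apply_transform(x, {}, getattr(t, "flags", {}), transform=None)
    return x

stubs = [StubTransform(p=0.5) for _ in range(4)]
x = torch.zeros(1, 1, 8, 8, 8)
for _ in range(1000):
    choose_and_apply(x, stubs, num_transforms=2)
print([s.calls for s in stubs])  # roughly 1000 * 2 * 0.5 / 4 = 250 each

A consequence of the gate: with num_transforms=k, the expected number of transforms actually applied per call is the sum of p over the k sampled picks (k*p for uniform p), not k itself, so the per-transform probabilities configured in transform_params_gpu.json keep their meaning inside the chooser.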