diff --git a/auglab/configs/transform_params_gpu.json b/auglab/configs/transform_params_gpu.json
index 739636b..5c6ea2e 100644
--- a/auglab/configs/transform_params_gpu.json
+++ b/auglab/configs/transform_params_gpu.json
@@ -3,9 +3,9 @@
         "kernel_type": "Scharr",
         "absolute": true,
         "retain_stats": true,
-        "mix_prob": 0.90,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "mix_prob": 0.50,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
         "probability": 0.15
     },
@@ -14,18 +14,18 @@
         "sigma": 1.0,
         "retain_stats": false,
         "mix_prob": 0.50,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
-        "mix_in_out": false,
-        "probability": 0.20
+        "in_seg": 0.1,
+        "out_seg": 0.1,
+        "mix_in_out": true,
+        "probability": 0.15
     },
     "UnsharpMaskTransform": {
         "kernel_type": "UnsharpMask",
         "sigma": 1.0,
         "unsharp_amount": 1.5,
         "retain_stats": false,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
         "mix_prob": 0.50,
         "probability": 0.10
@@ -34,83 +34,85 @@
         "kernel_type": "RandConv",
         "kernel_sizes": [3,5,7],
         "retain_stats": true,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
         "mix_prob": 0.50,
         "probability": 0.10
     },
     "RedistributeSegTransform": {
-        "in_seg": 0.2,
+        "in_seg": 0.25,
         "retain_stats": true,
-        "probability": 0.5
+        "probability": 0.1
     },
     "GaussianNoiseTransform": {
         "mean": 0.0,
-        "std": 0.1,
+        "std": 1.0,
         "in_seg": 0.5,
         "out_seg": 0.5,
         "mix_in_out": true,
-        "probability": 0.10
+        "probability": 0.15
     },
     "ClampTransform": {
         "max_clamp_amount": 0.2,
-        "retain_stats": false,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "retain_stats": true,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
-        "probability": 0.40
+        "probability": 0.05
     },
     "BrightnessTransform": {
         "brightness_range": [0.75, 1.25],
-        "in_seg": 0.0,
-        "out_seg": 0.0,
-        "mix_in_out": false,
-        "probability": 0.15
+        "in_seg": 0.1,
+        "out_seg": 0.1,
+        "mix_in_out": true,
+        "probability": 0.5
     },
     "GammaTransform": {
         "gamma_range": [0.7, 1.5],
         "retain_stats": true,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
-        "probability": 0.30
+        "probability": 0.50
     },
     "InvGammaTransform": {
         "gamma_range": [0.7, 1.5],
         "retain_stats": true,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
-        "probability": 0.10
+        "probability": 0.30
     },
     "ContrastTransform": {
         "contrast_range": [0.75, 1.25],
         "retain_stats": false,
-        "in_seg": 0.0,
-        "out_seg": 0.0,
-        "mix_in_out": false,
-        "probability": 0.15
+        "in_seg": 0.1,
+        "out_seg": 0.1,
+        "mix_in_out": true,
+        "probability": 0.5
     },
     "FunctionTransform": {
         "retain_stats": true,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
-        "mix_in_out": true,
-        "probability": 0.05
+        "in_seg": 0,
+        "out_seg": 0,
+        "mix_in_out": false,
+        "probability": 0.025
     },
     "InverseTransform": {
         "retain_stats": true,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
         "mix_in_out": true,
-        "probability": 0.05
+        "mix_prob": 0.50,
+        "probability": 0.1
     },
     "HistogramEqualizationTransform": {
-        "retain_stats": false,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
+        "retain_stats": true,
+        "in_seg": 0,
+        "out_seg": 0,
         "mix_in_out": true,
+        "mix_prob": 0.80,
         "probability": 0.10
     },
     "SimulateLowResTransform": {
@@ -123,14 +125,14 @@
         "scale": [0.6, 1.0],
         "crop": [1.0, 1.0],
         "same_on_batch": false,
-        "probability": 0.40
+        "probability": 0.05
     },
     "BiasFieldTransform": {
-        "retain_stats": false,
-        "coefficients": 0.5,
-        "in_seg": 0.5,
-        "out_seg": 0.5,
-        "mix_in_out": false,
+        "retain_stats": true,
+        "coefficients": 0.2,
+        "in_seg": 0.1,
+        "out_seg": 0.1,
+        "mix_in_out": true,
         "probability": 0.10
     },
     "FlipTransform": {
@@ -145,19 +147,9 @@
         "scale": [0.7, 1.4],
         "shear": [-5, 5, -5, 5, -5, 5],
         "resample": "bilinear",
-        "probability": 0
-    },
-    "nnUNetSpatialTransform": {
-        "patch_center_dist_from_border": 80,
-        "random_crop": true,
-        "p_elastic_deform": 0.2,
-        "p_rotation": 0.5,
-        "p_scaling": 0.5,
-        "scaling": [0.7, 1.4],
-        "p_synchronize_scaling_across_axes": 1,
-        "bg_style_seg_sampling": false
+        "probability": 0.5
     },
     "ZscoreNormalizationTransform": {
-        "probability": 0.3
+        "probability": 0.0
     }
 }
\ No newline at end of file
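Review note on the retuned keys: `probability`, `in_seg`/`out_seg`, `mix_prob`, and `mix_in_out` are adjusted together across almost every transform, so it helps to state how they compose. A minimal sketch of the intended interaction, inferred from `_choose_region_mode` and the new `mix_prob` blending in `contrast.py` (treating `in_seg`/`out_seg` as stacked probability buckets is an assumption, not something this diff confirms):

```python
import torch

def sample_behaviour(probability: float, in_seg: float, out_seg: float, mix_prob: float):
    """Illustrative only: how one transform's config keys are expected to compose."""
    if torch.rand(1).item() >= probability:
        return None                           # transform skipped for this batch
    r = torch.rand(1).item()
    if r < in_seg:
        region = 'in'                         # restrict the effect to inside the seg mask
    elif r < in_seg + out_seg:                # assumption: buckets are stacked
        region = 'out'                        # restrict the effect to outside the seg mask
    else:
        region = 'all'
    mixed = torch.rand(1).item() < mix_prob   # alpha-blend the result with the original
    return region, mixed
```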
diff --git a/auglab/trainers/nnUNetTrainerDAExt.py b/auglab/trainers/nnUNetTrainerDAExt.py
index 2d224f3..536ce22 100644
--- a/auglab/trainers/nnUNetTrainerDAExt.py
+++ b/auglab/trainers/nnUNetTrainerDAExt.py
@@ -210,7 +210,7 @@ def get_training_transforms(
         else:
             patch_size_spatial = patch_size
             ignore_axes = None
-        
+
         if 'nnUNetSpatialTransform' in config:
             spatial_params = config['nnUNetSpatialTransform']
         else:
@@ -234,7 +234,7 @@ def get_training_transforms(
 
         if do_dummy_2d_data_aug:
             transforms.append(Convert2DTo3DTransform())
-        
+
         if use_mask_for_norm is not None and any(use_mask_for_norm):
             transforms.append(MaskImageTransform(
                 apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
@@ -284,8 +284,8 @@ def get_training_transforms(
                     channel_in_seg=0
                 )
             )
-        
-        transforms.append(ZscoreNormalization())
+
+        # transforms.append(ZscoreNormalization())
 
         # NOTE: DownsampleSegForDSTransform is now handled in train_step for GPU augmentations
         # if deep_supervision_scales is not None:
@@ -323,13 +323,13 @@ def get_validation_transforms(
                     channel_in_seg=0
                 )
             )
-        
-        transforms.append(ZscoreNormalization())
+
+        # transforms.append(ZscoreNormalization())
 
         if deep_supervision_scales is not None:
             transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
 
         return ComposeTransforms(transforms)
-    
+
     def train_step(self, batch: dict) -> dict:
         data = batch['data']
         target = batch['target']
@@ -350,13 +350,13 @@ def train_step(self, batch: dict) -> dict:
         with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
             # Apply GPU augmentations to full-resolution data/target
             data, target = self.transforms(data, target)
-            
+
             # Create multi-scale targets for deep supervision after augmentation
             deep_supervision_scales = self._get_deep_supervision_scales()
             if deep_supervision_scales is not None:
                 ds_transform = DownsampleSegForDSTransformCustom(ds_scales=deep_supervision_scales)
                 target = ds_transform(target)
-            
+
             output = self.network(data)
             # del data
             l = self.loss(output, target)
@@ -519,4 +519,3 @@ def train_step(self, batch: dict) -> dict:
             torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
             self.optimizer.step()
         return {'loss': l.detach().cpu().numpy()}
-
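Review note: with `ZscoreNormalization` commented out of both transform pipelines, `train_step` now owns deep supervision: GPU augmentations run on the full-resolution target first, and the multi-scale targets are derived afterwards, so every scale stays consistent with the augmented segmentation. For reference, a hedged sketch of what a `DownsampleSegForDSTransformCustom`-style op presumably does (nearest-neighbour downsampling per scale; this is an assumption, not the code in this PR):

```python
import torch
import torch.nn.functional as F

def downsample_seg_for_ds(target: torch.Tensor, ds_scales) -> list[torch.Tensor]:
    """Sketch: one nearest-neighbour copy of the seg per deep-supervision scale."""
    out = []
    for scale in ds_scales:
        if all(s == 1 for s in scale):
            out.append(target)                # full resolution kept as-is
        else:
            size = [round(d * s) for d, s in zip(target.shape[2:], scale)]
            out.append(F.interpolate(target.float(), size=size, mode='nearest'))
    return out
```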
diff --git a/auglab/transforms/gpu/contrast.py b/auglab/transforms/gpu/contrast.py
index 4638a1b..570162b 100644
--- a/auglab/transforms/gpu/contrast.py
+++ b/auglab/transforms/gpu/contrast.py
@@ -29,7 +29,15 @@ def _choose_region_mode(p_in: float, p_out: float, seg_mask: Optional[torch.Tens
             return 'out'
     return 'all'
 
-def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask: Optional[torch.Tensor], mode: str, normalize: bool = True, mix_in_out: bool = False) -> torch.Tensor:
+
+def _apply_region_mode(
+    orig: torch.Tensor,
+    transformed: torch.Tensor,
+    seg_mask: Optional[torch.Tensor],
+    mode: str,
+    normalize: bool = False,
+    mix_in_out: bool = False,
+) -> torch.Tensor:
     """Blend transformed with orig based on region selection mode.
 
     - mode 'all': return transformed
@@ -40,7 +48,7 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
     """
     if seg_mask is None or mode == 'all':
         return transformed
-    
+
     # Rescale transformed based on min max orig
    # Needed due to the important change in the image
     if orig.dim() == 4:
@@ -51,17 +59,19 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
             transformed_max = torch.amax(transformed, dim=tuple(range(1, transformed.dim())), keepdim=True)
             transformed = (transformed - transformed_min) / (transformed_max - transformed_min + 1e-8) * (orig_max - orig_min) + orig_min
 
-        m = seg_mask.to(transformed.dtype)
-        if mode == 'out':
-            m = 1.0 - m
+        m = seg_mask.to(transformed.dtype).clone()
         if mix_in_out:
             for i in range(seg_mask.shape[0]):
                 # Create a tensor with random one and zero
-                
+
                 o = torch.randint(0, 2, (seg_mask.shape[1],), device=seg_mask.device, dtype=seg_mask.dtype)
                 m[i] = m[i] * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-        
-        m = torch.sum(m, axis=1)
+
+        m = torch.argmax(m, axis=1) > 0
+        m = m.to(transformed.dtype)
+        if mode == "out":
+            m = 1.0 - m
+
     elif orig.dim() == 3:
         if normalize:
             orig_min = torch.amin(orig)
@@ -69,20 +79,23 @@ def _apply_region_mode(orig: torch.Tensor, transformed: torch.Tensor, seg_mask:
             orig_max = torch.amax(orig)
             transformed_min = torch.amin(transformed)
             transformed_max = torch.amax(transformed)
             transformed = (transformed - transformed_min) / (transformed_max - transformed_min + 1e-8) * (orig_max - orig_min) + orig_min
-        
-        m = seg_mask.to(transformed.dtype)
-        if mode == 'out':
-            m = 1.0 - m
+
+        m = seg_mask.to(transformed.dtype).clone()
         if mix_in_out:
             # Create a tensor with random one and zero
             o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=seg_mask.dtype)
             m = m * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-        m = torch.sum(m, axis=0)
+        m = torch.argmax(m, axis=0) > 0
+        m = m.to(transformed.dtype)
+        if mode == "out":
+            m = 1.0 - m
+
     else:
         raise ValueError(f"Only 4D and 3D images are supported. Got {orig.dim()}D.")
 
     return m * transformed + (1.0 - m) * orig
 
+
 ## Convolution transform
 class RandomConvTransformGPU(ImageOnlyTransform):
     """Apply convolution to image.
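Review note on the `sum` → `argmax` change above: with a one-hot seg mask whose channel 0 is background, `torch.sum(m, axis=1)` marks every voxel (background included), whereas `torch.argmax(m, axis=1) > 0` keeps only foreground; moving the `mode == 'out'` inversion after binarization also makes 'out' meaningful for multi-channel masks. A toy check (the channel-0-is-background convention is an assumption):

```python
import torch

seg = torch.zeros(1, 3, 2, 2, 2)        # [B, C, D, H, W], one-hot
seg[0, 0] = 1.0                         # start with everything background
seg[0, 0, 0, 0, 0] = 0.0
seg[0, 1, 0, 0, 0] = 1.0                # one foreground voxel of class 1

m_sum = torch.sum(seg, dim=1)           # all ones: selects the whole volume
m_arg = torch.argmax(seg, dim=1) > 0    # True only at the foreground voxel
print(m_sum.unique(), m_arg.sum())      # tensor([1.]) tensor(1)
```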
@@ -191,7 +204,7 @@ def apply_transform(
     ) -> Tensor:
         # Initialize kernel
         kernel = self.get_kernel(device=input.device)
-        
+
         # Load segmentation
         seg_mask = params.get('seg', None)
 
@@ -234,6 +247,11 @@ def apply_transform(
 
             x = torch.stack(out, dim=0)
 
+            # Mix with original based on mix_prob
+            if torch.rand(1).item() < self.mix_prob:
+                alpha = torch.rand(1, device=input.device)
+                x = alpha * orig + (1 - alpha) * x
+
             if self.retain_stats:
                 # Adjust mean and std to match original
                 eps = 1e-8
@@ -247,23 +265,18 @@ def apply_transform(
                 om = orig_means.view(shape)
                 os = orig_stds.view(shape)
                 x = (x - nm) / (ns + eps) * os + om
-            
+
             # Apply region selection
             if not seg_mask is None:
                 region_mode = _choose_region_mode(self.in_seg, self.out_seg, seg_mask)
                 x = _apply_region_mode(orig, x, seg_mask, region_mode, mix_in_out=self.mix_in_out)
-            
-            # Mix with original based on mix_prob
-            if torch.rand(1).item() < self.mix_prob:
-                alpha = torch.rand(1, device=input.device)
-                x = alpha * orig + (1 - alpha) * x
-            
+
             # Final safety: check if nan/inf appeared
             if torch.isnan(x).any() or torch.isinf(x).any():
                 print(f"Warning nan: {self.__class__.__name__} with kernel={self.kernel_type}", flush=True)
                 continue
             input[:, c] = x
-        
+
         return input
 
@@ -790,6 +803,7 @@ def __init__(
         in_seg: float = 0.0,
         out_seg: float = 0.0,
         mix_in_out: bool = False,
+        mix_prob: float = 0.0,
         p: float = 1.0,
         keepdim: bool = True,
         **kwargs,
@@ -799,6 +813,7 @@ def __init__(
         self.retain_stats = retain_stats
         self.in_seg = in_seg
         self.out_seg = out_seg
+        self.mix_prob = mix_prob
         self.mix_in_out = mix_in_out
 
     @torch.no_grad()  # disable gradients for efficiency
@@ -817,12 +832,19 @@ def apply_transform(
                 orig_stds = x.std()
                 max_val = x.max()
                 x = max_val - x
+
                 if self.retain_stats:
                     # Adjust mean and std to match original
                     eps = 1e-8
                     new_mean = x.mean()  # scalar
                     new_std = x.std()  # scalar
                     x = (x - new_mean) / (new_std + eps) * orig_stds + orig_means
+
+                # Mix with original based on mix_prob
+                if torch.rand(1).item() < self.mix_prob:
+                    alpha = torch.rand(1, device=input.device)
+                    x = alpha * orig + (1 - alpha) * x
+
                 if not seg_mask is None:
                     region_mode = _choose_region_mode(self.in_seg, self.out_seg, seg_mask[i])
                     x = _apply_region_mode(orig, x, seg_mask[i], region_mode, mix_in_out=self.mix_in_out)
@@ -831,7 +853,7 @@ def apply_transform(
                     print(f"Warning nan: {self.__class__.__name__}", flush=True)
                     continue
                 input[i, c] = x
-        
+
         return input
 
@@ -859,6 +881,7 @@ def __init__(
         in_seg: float = 0.0,
         out_seg: float = 0.0,
         mix_in_out: bool = False,
+        mix_prob: float = 0.0,
         p: float = 1.0,
         keepdim: bool = True,
         **kwargs,
@@ -869,6 +892,7 @@ def __init__(
         self.in_seg = in_seg
         self.out_seg = out_seg
         self.mix_in_out = mix_in_out
+        self.mix_prob = mix_prob
 
     @torch.no_grad()  # disable gradients for efficiency
     def apply_transform(
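Review note: these hunks add the same `mix_prob` blend to several transforms and, in the conv transform, move it ahead of `retain_stats` and the region selection, so statistics matching and seg-mask blending now operate on the already-mixed tensor. The blend in isolation, as a self-contained sketch:

```python
import torch

def maybe_mix(orig: torch.Tensor, x: torch.Tensor, mix_prob: float) -> torch.Tensor:
    """With probability mix_prob, return a random convex combination of the
    original and transformed tensors (alpha ~ U[0, 1)), softening the effect."""
    if torch.rand(1).item() < mix_prob:
        alpha = torch.rand(1, device=x.device)
        x = alpha * orig + (1 - alpha) * x
    return x
```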
@@ -880,7 +904,7 @@ def apply_transform(
         self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
     ) -> Tensor:
         for c in self.apply_to_channel:
             channel_data = input[:, c]  # shape [N, ...spatial...]
             orig = channel_data.clone()
-            
+
             if self.retain_stats:
                 reduce_dims = tuple(range(1, channel_data.dim()))
                 # store per-sample mean/std (shape [N])
@@ -913,6 +937,11 @@ def apply_transform(
                 img_eq = cdf[indices]
                 channel_data[b] = img_eq.reshape(img_b.shape)
 
+                # Mix with original based on mix_prob
+                if torch.rand(1).item() < self.mix_prob:
+                    alpha = torch.rand(1, device=input.device)
+                    channel_data[b] = alpha * orig[b] + (1 - alpha) * channel_data[b]
+
             if self.retain_stats:
                 # Adjust mean and std to match original
                 eps = 1e-8
@@ -926,7 +955,7 @@ def apply_transform(
                 om = orig_means.view(shape)
                 os = orig_stds.view(shape)
                 channel_data = (channel_data - nm) / (ns + eps) * os + om
-            
+
             if not seg_mask is None:
                 region_mode = _choose_region_mode(self.in_seg, self.out_seg, seg_mask)
                 channel_data = _apply_region_mode(orig, channel_data, seg_mask, region_mode, mix_in_out=self.mix_in_out)
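Review note: in the histogram-equalization transform the blend lands inside the per-sample loop, right after `cdf[indices]` writes the equalized sample back. For context, a minimal torch sketch of the equalization step itself (bin count and range handling are assumptions; the full implementation is not shown in this diff):

```python
import torch

def hist_equalize(img: torch.Tensor, bins: int = 256) -> torch.Tensor:
    """Sketch: map each voxel through the normalized cumulative histogram."""
    flat = img.flatten()
    lo, hi = flat.min(), flat.max()
    hist = torch.histc(flat, bins=bins, min=lo.item(), max=hi.item())
    cdf = torch.cumsum(hist, dim=0)
    cdf = (cdf - cdf[0]) / (cdf[-1] - cdf[0] + 1e-8)            # normalize to [0, 1]
    idx = ((flat - lo) / (hi - lo + 1e-8) * (bins - 1)).long()  # voxel -> bin index
    return (cdf[idx] * (hi - lo) + lo).reshape(img.shape)       # restore value range
```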
diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py
index 79eac7f..2ec84fd 100644
--- a/auglab/transforms/gpu/transforms.py
+++ b/auglab/transforms/gpu/transforms.py
@@ -4,6 +4,10 @@
 import torch
 import numpy as np
 
+from auglab.transforms.gpu.base import ImageOnlyTransform
+from typing import Any, Dict, Optional, Tuple, Union, List
+from kornia.core import Tensor
+
 from auglab.transforms.gpu.contrast import RandomConvTransformGPU, RandomGaussianNoiseGPU, RandomBrightnessGPU, RandomGammaGPU, RandomFunctionGPU, \
     RandomHistogramEqualizationGPU, RandomInverseGPU, RandomBiasFieldGPU, RandomContrastGPU, ZscoreNormalizationGPU, RandomClampGPU
 from auglab.transforms.gpu.spatial import RandomAffine3DCustom, RandomLowResTransformGPU, RandomFlipTransformGPU, RandomAcqTransformGPU
@@ -19,7 +23,7 @@ def __init__(self, json_path: str):
         config_path = os.path.join(json_path)
         with open(config_path, 'r') as f:
             config = json.load(f)
-        
+
         if 'GPU' in config.keys():
             self.transform_params = config['GPU']
         else:
@@ -52,30 +56,36 @@ def _build_transforms(self) -> list[nn.Module]:
                 resample=affine_params.get('resample', "bilinear"),
                 p=affine_params.get('probability', 0)
             ))
-        
+
         ## Transfer augmentations (TA)
         # Inverse transform (max - pixel_value)
         inverse_params = self.transform_params.get('InverseTransform')
         if inverse_params is not None:
-            transforms.append(RandomInverseGPU(
-                p=inverse_params.get('probability', 0),
-                in_seg=inverse_params.get('in_seg', 0.0),
-                out_seg=inverse_params.get('out_seg', 0.0),
-                mix_in_out=inverse_params.get('mix_in_out', False),
-                retain_stats=inverse_params.get('retain_stats', False),
-            ))
-        
+            transforms.append(
+                RandomInverseGPU(
+                    p=inverse_params.get("probability", 0),
+                    in_seg=inverse_params.get("in_seg", 0.0),
+                    out_seg=inverse_params.get("out_seg", 0.0),
+                    mix_in_out=inverse_params.get("mix_in_out", False),
+                    mix_prob=inverse_params.get("mix_prob", 0.0),
+                    retain_stats=inverse_params.get("retain_stats", False),
+                )
+            )
+
         # Histogram manipulations
         histo_params = self.transform_params.get('HistogramEqualizationTransform')
         if histo_params is not None:
-            transforms.append(RandomHistogramEqualizationGPU(
-                p=histo_params.get('probability', 0),
-                in_seg=histo_params.get('in_seg', 0.0),
-                out_seg=histo_params.get('out_seg', 0.0),
-                mix_in_out=histo_params.get('mix_in_out', False),
-                retain_stats=histo_params.get('retain_stats', False),
-            ))
-        
+            transforms.append(
+                RandomHistogramEqualizationGPU(
+                    p=histo_params.get("probability", 0),
+                    in_seg=histo_params.get("in_seg", 0.0),
+                    out_seg=histo_params.get("out_seg", 0.0),
+                    mix_in_out=histo_params.get("mix_in_out", False),
+                    mix_prob=histo_params.get("mix_prob", 0.0),
+                    retain_stats=histo_params.get("retain_stats", False),
+                )
+            )
+
         # Redistribute segmentation values transform
         redistribute_params = self.transform_params.get('RedistributeSegTransform')
         if redistribute_params is not None:
@@ -112,7 +122,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 unsharp_amount=unsharp_params.get('unsharp_amount', 1.5),
                 mix_prob=unsharp_params.get('mix_prob', 0.0),
             ))
-        
+
         # RandomConv transform
         randconv_params = self.transform_params.get('RandomConvTransform')
         if randconv_params is not None:
@@ -126,7 +136,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 kernel_sizes=randconv_params.get('kernel_sizes', [1,3,5,7]),
                 mix_prob=randconv_params.get('mix_prob', 0.0),
             ))
-        
+
         ## General enhancement (GE)
         # Clamping transform
         clamp_params = self.transform_params.get('ClampTransform')
@@ -151,7 +161,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 mix_in_out=noise_params.get('mix_in_out', False),
                 p=noise_params.get('probability', 0),
             ))
-        
+
         # Gaussian blur
         gaussianblur_params = self.transform_params.get('GaussianBlurTransform')
         if gaussianblur_params is not None:
@@ -199,7 +209,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 invert_image=True,
                 retain_stats=inv_gamma_params.get('retain_stats', False),
             ))
-        
+
         # nnUNetV2 Contrast transforms
         contrast_params = self.transform_params.get('ContrastTransform')
         if contrast_params is not None:
@@ -251,7 +261,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 one_dim=True,
                 same_on_batch=acq_params.get('same_on_batch', False)
             ))
-        
+
         # Bias field artifact
         bias_field_params = self.transform_params.get('BiasFieldTransform')
         if bias_field_params is not None:
@@ -263,7 +273,7 @@ def _build_transforms(self) -> list[nn.Module]:
                 retain_stats=bias_field_params.get('retain_stats', False),
                 coefficients=bias_field_params.get('coefficients', 0.5),
             ))
-        
+
         ## Random Z-score normalization
         zscore_params = self.transform_params.get('ZscoreNormalizationTransform')
         if zscore_params is not None:
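Review note: every builder block follows the same `dict.get` pattern, which is what keeps the new `mix_prob` key backward compatible: a missing JSON block skips the transform entirely, and a missing key falls back to the constructor default. Schematically (`name` and `cls` here are placeholders, not part of this PR):

```python
def build_one(transform_params: dict, name: str, cls):
    """Illustrative builder: mirrors the pattern used in _build_transforms."""
    params = transform_params.get(name)
    if params is None:
        return None                              # absent block: transform not added
    return cls(
        p=params.get('probability', 0),          # disabled unless configured
        mix_prob=params.get('mix_prob', 0.0),    # old configs without the key still load
    )
```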
Got {num_transforms!r}.") + self.transforms_list = nn.ModuleList(transforms_list) + self.num_transforms = num_transforms + + def _apply_mix(self, x: Tensor, seg: Optional[Tensor]) -> Tensor: + if self.num_transforms == 0 or len(self.transforms_list) == 0: + return x + + k = min(self.num_transforms, len(self.transforms_list)) + # sample without replacement + idx = torch.randperm(len(self.transforms_list), device=x.device)[:k] + + child_params: Dict[str, Tensor] = {} + if seg is not None: + child_params["seg"] = seg + + for j in idx.tolist(): + t = self.transforms_list[j] + if torch.rand(1, device=x.device, dtype=x.dtype) > t.p: + continue + if not hasattr(t, "apply_transform"): + raise TypeError( + f"All transforms must implement apply_transform like ImageOnlyTransform. Got {type(t)}" + ) + # Most contrast transforms perform their random sampling inside apply_transform. + t_flags = getattr(t, "flags", {}) + x = t.apply_transform(x, child_params, t_flags, transform=None) + return x + + @torch.no_grad() # disable gradients for efficiency + def apply_transform( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None + ) -> Tensor: + seg = params.get("seg", None) + + if self.same_on_batch: + return self._apply_mix(input, seg) + + batch_size = input.shape[0] + out = input + for i in range(batch_size): + xi = out[i : i + 1] + seg_i = None + if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] == batch_size: + seg_i = seg[i : i + 1] + else: + seg_i = seg + xi = self._apply_mix(xi, seg_i) + out[i : i + 1] = xi + return out + def normalize(arr: np.ndarray) -> np.ndarray: """ Normalize a tensor to the range [0, 1]. @@ -412,4 +498,4 @@ def pad_numpy_array(arr, shape): combined_img2 = np.concatenate([img_line2, seg_line2, augmented_img_line2, augmented_seg_line2, not_augmented_channel_line2], axis=0) cv2.imwrite('img/combined2.png', combined_img2*255) - print(augmentor) \ No newline at end of file + print(augmentor)