Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 98 additions & 165 deletions brain_deform/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,31 @@
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union, Dict
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import kornia.augmentation as K
import torch
from dataclasses import dataclass

from brain_deform.cuda.deform import Interpolation
from brain_deform.matching import match_intensity_batch
from brain_deform.registration import (
register_linear,
register_nonlinear_from_coefs,
register_nonlinear_from_field,
xregister_nonlinear_from_coefs,
xregister_nonlinear_invert_from_coefs,
xregister_nonlinear_invert_from_field,
register_nonlinear,
xregister_nonlinear,
xregister_nonlinear_invert,
)


@dataclass
class RequiredParams:
registration: str
random_warp: float
cross_warp: float
interpolation: str
target_shape: Tuple[int, int, int]
spacing: Tuple[float, float, float]
deformations_are_coefs: bool
invert_on_the_fly: bool
use_isotropic_scaling: bool
augmentation_probability: float
cross_intensity: float
translation: Optional[int]
rotation: Optional[int]
scale: Optional[int]
flip: bool


def process_image(
img_main: torch.Tensor,
deformation_main: torch.Tensor,
premat_main: torch.Tensor,
postmat_main: torch.Tensor,
img_aux: torch.Tensor,
deformation_aux: torch.Tensor,
premat_aux: torch.Tensor,
postmat_aux: torch.Tensor,
deformation_aug: torch.Tensor,
premat_aug: torch.Tensor,
hparams: RequiredParams,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_main,
coefs_main,
premat_main,
postmat_main,
img_aux,
coefs_aux,
premat_aux,
postmat_aux,
coefs_aug,
premat_aug,
hparams,
):
# Final resampling interpolation method
if hparams.interpolation == "nearest":
interpolation = Interpolation.Nearest
Expand All @@ -58,17 +36,16 @@ def process_image(

warp_args = (
img_main,
deformation_main,
coefs_main,
premat_main,
postmat_main,
deformation_aug,
coefs_aug,
premat_aug,
hparams.invert_on_the_fly,
interpolation,
hparams.target_shape,
hparams.spacing,
hparams.augmentation_probability,
hparams.deformations_are_coefs,
)

if hparams.registration == "nonlinear":
Expand All @@ -91,17 +68,16 @@ def process_image(
if hparams.cross_intensity > 0:
warp_args = (
img_aux,
deformation_aux,
coefs_aux,
premat_aux,
postmat_aux,
deformation_aug,
coefs_aug,
premat_aug,
hparams.invert_on_the_fly,
interpolation,
hparams.target_shape,
hparams.spacing,
hparams.augmentation_probability,
hparams.deformations_are_coefs,
)
if hparams.registration == "nonlinear":
img_aux_registered = warp_augmentation("nonlinear", 0, *warp_args)
Expand Down Expand Up @@ -139,22 +115,22 @@ def affine_augmentation(
target_shape: Tuple[int, int, int] = (182, 218, 182),
use_isotropic_scaling: bool = False,
augmentation_probability: float = 1,
) -> torch.Tensor:
) -> Optional[Union[K.AugmentationSequential, K.RandomAffine3D]]:
augmentation = None

if (translation is not None) and translation > 0:
if translation > 0:
translation = (
translation / target_shape[0],
translation / target_shape[1],
translation / target_shape[2],
) # type: ignore
)
else:
translation = None

if (scale is not None) and scale > 0:
scale = (1 - scale, 1 + scale) # type: ignore
if scale > 0:
scale = (1 - scale, 1 + scale)
if not use_isotropic_scaling:
scale = (scale, scale, scale) # type: ignore
scale = (scale, scale, scale)
else:
scale = None

Expand All @@ -175,16 +151,12 @@ def affine_augmentation(
augmentation = K.AugmentationSequential(flip_aug, augmentation)
else:
augmentation = flip_aug
return augmentation(img) if augmentation is not None else img # type: ignore
return augmentation(img) if augmentation is not None else img


def intensity_augmentation(
img_main: torch.Tensor,
img_aux: torch.Tensor,
p_max: float,
augmentation_probability: float,
mask_thresh: float = 1e-4,
) -> torch.Tensor:
img_main, img_aux, p_max, augmentation_probability, mask_thresh=1e-4
):

main_mask = img_main > mask_thresh
aux_mask = img_aux > mask_thresh
Expand All @@ -199,141 +171,102 @@ def warp_augmentation(
mode: str,
mode_factor: float,
img: torch.Tensor,
deformation1: torch.Tensor,
coefs1: torch.Tensor,
premat1: torch.Tensor,
postmat1: torch.Tensor,
deformation2: torch.Tensor,
coefs2: torch.Tensor,
premat2: torch.Tensor,
invert_on_the_fly: bool,
interpolation: Interpolation,
target_shape: Tuple[int, int, int],
spacing: Tuple[float, float, float],
augmentation_probability: float,
deformations_are_coefs: bool,
) -> torch.Tensor:
invert_on_the_fly,
interpolation,
target_shape,
spacing,
augmentation_probability,
) -> Tuple[torch.Tensor, ...]:

device = img.device
batch_size = img.size(0)


if len(img.shape) == 3:
batch_size = 1
elif len(img.shape) == 4:
batch_size = img.size(0)
else:
raise ValueError("Invalid shape for image_data. Expected 3 or 4 dimensions.")

if mode == "linear":
# note: nonlinear transform is used, but with 0 as scale, wich results in only a linear transform
scale_vec = torch.zeros(batch_size, device=device)
if deformations_are_coefs:
return register_nonlinear_from_coefs(
img,
deformation1,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
return register_nonlinear_from_field(
img,
deformation1,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
)

return register_nonlinear(
img,
coefs1,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
elif mode == "nonlinear":
scale_vec = torch.ones(batch_size, device=device)
if deformations_are_coefs:
return register_nonlinear_from_coefs(
return register_nonlinear(
img,
coefs1,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
elif mode == "cross_warp":
# S ~ U(0, factor)
p_mask = torch.rand(batch_size, device=device) < augmentation_probability
scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask
if invert_on_the_fly:
return xregister_nonlinear_invert(
img,
deformation1,
coefs1,
premat1,
postmat1,
coefs2,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
return register_nonlinear_from_field(
return xregister_nonlinear(
img,
deformation1,
coefs1,
premat1,
postmat1,
coefs2,
premat2,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
elif mode == "cross_warp":
elif mode == "random_warp":
# Assuming 10mm, max def = 10/2
# X ~ U(-5, 5)
max_displacement = np.mean(spacing) / 2
# print("using rand max disp", max_displacement)
rnd_coefs = (
torch.rand_like(coefs1, dtype=torch.float32, device=device) * 2 - 1
) * max_displacement

# S ~ U(0, factor)
p_mask = torch.rand(batch_size, device=device) < augmentation_probability
scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask
if deformations_are_coefs:
if invert_on_the_fly:
return xregister_nonlinear_invert_from_coefs(
img,
deformation1,
premat1,
postmat1,
deformation2,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
return xregister_nonlinear_from_coefs(
img,
deformation1,
premat1,
postmat1,
deformation2,
premat2,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
if invert_on_the_fly:
return xregister_nonlinear_invert_from_field(
img,
deformation1,
deformation2,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
)
else:
# Left as an exercise for the reader
raise NotImplementedError

elif mode == "random_warp":
if deformations_are_coefs:
# Assuming 10mm, max def = 10/2
# X ~ U(-5, 5)
max_displacement = np.mean(spacing) / 2
# print("using rand max disp", max_displacement)
rnd_coefs = (
torch.rand_like(deformation1, dtype=torch.float32, device=device) * 2
- 1
) * max_displacement

# S ~ U(0, factor)
p_mask = torch.rand(batch_size, device=device) < augmentation_probability
scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask
return register_nonlinear_from_coefs(
img,
rnd_coefs,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
raise NotImplementedError
return register_nonlinear(
img,
rnd_coefs,
premat1,
postmat1,
scale_vec,
interpolation=interpolation,
target_shape=target_shape,
spacing=spacing,
)
else:
raise ValueError("invalid warp augmentation mode")
raise ValueError("invalid mode")
Loading