diff --git a/brain_deform/augmentation.py b/brain_deform/augmentation.py index 405afc1..2dbb6f7 100644 --- a/brain_deform/augmentation.py +++ b/brain_deform/augmentation.py @@ -1,53 +1,31 @@ -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union, Dict +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import kornia.augmentation as K import torch -from dataclasses import dataclass from brain_deform.cuda.deform import Interpolation from brain_deform.matching import match_intensity_batch from brain_deform.registration import ( register_linear, - register_nonlinear_from_coefs, - register_nonlinear_from_field, - xregister_nonlinear_from_coefs, - xregister_nonlinear_invert_from_coefs, - xregister_nonlinear_invert_from_field, + register_nonlinear, + xregister_nonlinear, + xregister_nonlinear_invert, ) -@dataclass -class RequiredParams: - registration: str - random_warp: float - cross_warp: float - interpolation: str - target_shape: Tuple[int, int, int] - spacing: Tuple[float, float, float] - deformations_are_coefs: bool - invert_on_the_fly: bool - use_isotropic_scaling: bool - augmentation_probability: float - cross_intensity: float - translation: Optional[int] - rotation: Optional[int] - scale: Optional[int] - flip: bool - - def process_image( - img_main: torch.Tensor, - deformation_main: torch.Tensor, - premat_main: torch.Tensor, - postmat_main: torch.Tensor, - img_aux: torch.Tensor, - deformation_aux: torch.Tensor, - premat_aux: torch.Tensor, - postmat_aux: torch.Tensor, - deformation_aug: torch.Tensor, - premat_aug: torch.Tensor, - hparams: RequiredParams, -) -> Tuple[torch.Tensor, torch.Tensor]: + img_main, + coefs_main, + premat_main, + postmat_main, + img_aux, + coefs_aux, + premat_aux, + postmat_aux, + coefs_aug, + premat_aug, + hparams, +): # Final resampling interpolation method if hparams.interpolation == "nearest": interpolation = Interpolation.Nearest @@ -58,17 +36,16 @@ def process_image( warp_args = 
( img_main, - deformation_main, + coefs_main, premat_main, postmat_main, - deformation_aug, + coefs_aug, premat_aug, hparams.invert_on_the_fly, interpolation, hparams.target_shape, hparams.spacing, hparams.augmentation_probability, - hparams.deformations_are_coefs, ) if hparams.registration == "nonlinear": @@ -91,17 +68,16 @@ def process_image( if hparams.cross_intensity > 0: warp_args = ( img_aux, - deformation_aux, + coefs_aux, premat_aux, postmat_aux, - deformation_aug, + coefs_aug, premat_aug, hparams.invert_on_the_fly, interpolation, hparams.target_shape, hparams.spacing, hparams.augmentation_probability, - hparams.deformations_are_coefs, ) if hparams.registration == "nonlinear": img_aux_registered = warp_augmentation("nonlinear", 0, *warp_args) @@ -139,22 +115,22 @@ def affine_augmentation( target_shape: Tuple[int, int, int] = (182, 218, 182), use_isotropic_scaling: bool = False, augmentation_probability: float = 1, -) -> torch.Tensor: +) -> Optional[Union[K.AugmentationSequential, K.RandomAffine3D]]: augmentation = None - if (translation is not None) and translation > 0: + if translation > 0: translation = ( translation / target_shape[0], translation / target_shape[1], translation / target_shape[2], - ) # type: ignore + ) else: translation = None - if (scale is not None) and scale > 0: - scale = (1 - scale, 1 + scale) # type: ignore + if scale > 0: + scale = (1 - scale, 1 + scale) if not use_isotropic_scaling: - scale = (scale, scale, scale) # type: ignore + scale = (scale, scale, scale) else: scale = None @@ -175,16 +151,12 @@ def affine_augmentation( augmentation = K.AugmentationSequential(flip_aug, augmentation) else: augmentation = flip_aug - return augmentation(img) if augmentation is not None else img # type: ignore + return augmentation(img) if augmentation is not None else img def intensity_augmentation( - img_main: torch.Tensor, - img_aux: torch.Tensor, - p_max: float, - augmentation_probability: float, - mask_thresh: float = 1e-4, -) -> 
torch.Tensor: + img_main, img_aux, p_max, augmentation_probability, mask_thresh=1e-4 +): main_mask = img_main > mask_thresh aux_mask = img_aux > mask_thresh @@ -199,141 +171,102 @@ def warp_augmentation( mode: str, mode_factor: float, img: torch.Tensor, - deformation1: torch.Tensor, + coefs1: torch.Tensor, premat1: torch.Tensor, postmat1: torch.Tensor, - deformation2: torch.Tensor, + coefs2: torch.Tensor, premat2: torch.Tensor, - invert_on_the_fly: bool, - interpolation: Interpolation, - target_shape: Tuple[int, int, int], - spacing: Tuple[float, float, float], - augmentation_probability: float, - deformations_are_coefs: bool, -) -> torch.Tensor: + invert_on_the_fly, + interpolation, + target_shape, + spacing, + augmentation_probability, +) -> Tuple[torch.Tensor, ...]: device = img.device - batch_size = img.size(0) - + + if len(img.shape) == 3: + batch_size = 1 + elif len(img.shape) == 4: + batch_size = img.size(0) + else: + raise ValueError("Invalid shape for image_data. Expected 3 or 4 dimensions.") + if mode == "linear": # note: nonlinear transform is used, but with 0 as scale, wich results in only a linear transform scale_vec = torch.zeros(batch_size, device=device) - if deformations_are_coefs: - return register_nonlinear_from_coefs( - img, - deformation1, - premat1, - postmat1, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - spacing=spacing, - ) - else: - return register_nonlinear_from_field( - img, - deformation1, - premat1, - postmat1, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - ) - + return register_nonlinear( + img, + coefs1, + premat1, + postmat1, + scale_vec, + interpolation=interpolation, + target_shape=target_shape, + spacing=spacing, + ) elif mode == "nonlinear": scale_vec = torch.ones(batch_size, device=device) - if deformations_are_coefs: - return register_nonlinear_from_coefs( + return register_nonlinear( + img, + coefs1, + premat1, + postmat1, + scale_vec, + interpolation=interpolation, + 
target_shape=target_shape, + spacing=spacing, + ) + elif mode == "cross_warp": + # S ~ U(0, factor) + p_mask = torch.rand(batch_size, device=device) < augmentation_probability + scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask + if invert_on_the_fly: + return xregister_nonlinear_invert( img, - deformation1, + coefs1, premat1, postmat1, + coefs2, scale_vec, interpolation=interpolation, target_shape=target_shape, spacing=spacing, ) else: - return register_nonlinear_from_field( + return xregister_nonlinear( img, - deformation1, + coefs1, premat1, postmat1, + coefs2, + premat2, scale_vec, interpolation=interpolation, target_shape=target_shape, + spacing=spacing, ) - elif mode == "cross_warp": + elif mode == "random_warp": + # Assuming 10mm, max def = 10/2 + # X ~ U(-5, 5) + max_displacement = np.mean(spacing) / 2 + # print("using rand max disp", max_displacement) + rnd_coefs = ( + torch.rand_like(coefs1, dtype=torch.float32, device=device) * 2 - 1 + ) * max_displacement + # S ~ U(0, factor) p_mask = torch.rand(batch_size, device=device) < augmentation_probability scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask - if deformations_are_coefs: - if invert_on_the_fly: - return xregister_nonlinear_invert_from_coefs( - img, - deformation1, - premat1, - postmat1, - deformation2, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - spacing=spacing, - ) - else: - return xregister_nonlinear_from_coefs( - img, - deformation1, - premat1, - postmat1, - deformation2, - premat2, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - spacing=spacing, - ) - else: - if invert_on_the_fly: - return xregister_nonlinear_invert_from_field( - img, - deformation1, - deformation2, - premat1, - postmat1, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - ) - else: - # Left as an exercise for the reader - raise NotImplementedError - - elif mode == "random_warp": - if 
deformations_are_coefs: - # Assuming 10mm, max def = 10/2 - # X ~ U(-5, 5) - max_displacement = np.mean(spacing) / 2 - # print("using rand max disp", max_displacement) - rnd_coefs = ( - torch.rand_like(deformation1, dtype=torch.float32, device=device) * 2 - - 1 - ) * max_displacement - - # S ~ U(0, factor) - p_mask = torch.rand(batch_size, device=device) < augmentation_probability - scale_vec = torch.rand(batch_size, device=device) * mode_factor * p_mask - return register_nonlinear_from_coefs( - img, - rnd_coefs, - premat1, - postmat1, - scale_vec, - interpolation=interpolation, - target_shape=target_shape, - spacing=spacing, - ) - else: - raise NotImplementedError + return register_nonlinear( + img, + rnd_coefs, + premat1, + postmat1, + scale_vec, + interpolation=interpolation, + target_shape=target_shape, + spacing=spacing, + ) else: - raise ValueError("invalid warp augmentation mode") + raise ValueError("invalid mode") diff --git a/brain_deform/datasets.py b/brain_deform/datasets.py index 6745920..0bea3cd 100644 --- a/brain_deform/datasets.py +++ b/brain_deform/datasets.py @@ -7,11 +7,7 @@ import torch.nn.functional as F from fsl.data.image import Image from fsl.transform.fnirt import readFnirt -from fsl.transform import affine -from fsl.transform.nonlinear import CoefficientField, DeformationField from torch.utils.data import Dataset -import fsl.data.constants as constants -from brain_deform.ants import read_ants_affine from brain_deform import registration @@ -34,15 +30,15 @@ def get_mni(template: str) -> Image: def get_field_to_ref_mat( - src_image: Union[str, Image], deformation_path: str, ref: Image + src_image: Union[str, Image], coefs_path: str, ref: Image ) -> npt.NDArray[Any]: if isinstance(src_image, str): src_image = Image(src_image) - deform = readFnirt(deformation_path, src=src_image, ref=ref) - if not registration.is_diag(deform.fieldToRefMat): + coefs = readFnirt(coefs_path, src=src_image, ref=ref) + if not 
registration.is_diag(coefs.fieldToRefMat): raise ValueError("fieldToRefMat: only scaling matrices are supported") - return deform.fieldToRefMat # type: ignore + return coefs.fieldToRefMat # type: ignore def multi_modal_collate_fn( @@ -59,23 +55,21 @@ def multi_modal_collate_fn( num_modalities = (len(datum) - 1) // 4 - deformation_shape = data[0][1].shape + coefs_shape = data[0][1].shape premat_shape = data[0][2].shape postmat_shape = data[0][3].shape - label_shape = data[0][-1].shape + label_shape = data[0][-1][-1].shape batch = torch.zeros((len(data), num_modalities, max_x, max_y, max_z)) - prewarp_deformation_tensor = torch.zeros( - (len(data), num_modalities, *deformation_shape) - ) + prewarp_coefs_tensor = torch.zeros((len(data), num_modalities, *coefs_shape)) prewarp_premat_tensor = torch.zeros((len(data), num_modalities, *premat_shape)) prewarp_postmat_tensor = torch.zeros((len(data), num_modalities, *postmat_shape)) labels = torch.zeros((len(data), *label_shape)) - + for i, datum in enumerate(data): for j in range((len(datum) - 1) // 4): img = datum[j * 4] - deform = datum[j * 4 + 1] + coef = datum[j * 4 + 1] premat = datum[j * 4 + 2] postmat = datum[j * 4 + 3] @@ -84,16 +78,16 @@ def multi_modal_collate_fn( offset_z = max_z - img.shape[2] batch[i, j] = F.pad(img, (0, offset_z, 0, offset_y, 0, offset_x)) - prewarp_deformation_tensor[i, j] = deform + prewarp_coefs_tensor[i, j] = coef prewarp_premat_tensor[i, j] = premat prewarp_postmat_tensor[i, j] = postmat - labels[i] = datum[-1] + labels[i] = datum[-1][-1] # Reshape to (batch * dim, shape) batch = batch.reshape((len(data) * num_modalities, max_x, max_y, max_z)) - prewarp_deformation_tensor = prewarp_deformation_tensor.reshape( - (len(data) * num_modalities, *deformation_shape) + prewarp_coefs_tensor = prewarp_coefs_tensor.reshape( + (len(data) * num_modalities, *coefs_shape) ) prewarp_premat_tensor = prewarp_premat_tensor.reshape( (len(data) * num_modalities, *premat_shape) @@ -104,7 +98,7 @@ def 
multi_modal_collate_fn( return ( batch, - prewarp_deformation_tensor, + prewarp_coefs_tensor, prewarp_premat_tensor, prewarp_postmat_tensor, labels, @@ -114,29 +108,38 @@ def multi_modal_collate_fn( def multi_modal_inv_collate_fn( data: List[Tuple[torch.Tensor, ...]] ) -> Tuple[torch.Tensor, torch.Tensor]: - + num_modalities = len(data[0]) // 2 + # Subtract one because of the label + element_length = len(data[0]) - 1 + num_modalities = element_length // 2 + - deformation_shape = data[0][0].shape + coefs_shape = data[0][0].shape premat_shape = data[0][1].shape + label_shape = data[0][-1][-1].shape - deformation_tensor = torch.zeros((len(data), num_modalities, *deformation_shape)) + coefs_tensor = torch.zeros((len(data), num_modalities, *coefs_shape)) premat_tensor = torch.zeros((len(data), num_modalities, *premat_shape)) + labels = torch.zeros((len(data), *label_shape)) for i, datum in enumerate(data): - for j in range(len(datum) // 2): - deform = datum[j * 2] + for j in range(num_modalities): + coef = datum[j * 2] premat = datum[j * 2 + 1] - deformation_tensor[i, j] = deform + coefs_tensor[i, j] = coef premat_tensor[i, j] = premat + + labels[i] = datum[-1][-1] - deformation_tensor = deformation_tensor.reshape((len(data) * num_modalities, *deformation_shape)) + coefs_tensor = coefs_tensor.reshape((len(data) * num_modalities, *coefs_shape)) premat_tensor = premat_tensor.reshape((len(data) * num_modalities, *premat_shape)) return ( - deformation_tensor, + coefs_tensor, premat_tensor, + labels ) @@ -144,95 +147,44 @@ class MultiModalWarpDataset(Dataset[Tuple[torch.Tensor, ...]]): def __init__( self, t1_paths: Sequence[str], - t1_deformation_paths: Sequence[str], - labels: npt.NDArray[Any], + t1_coefs_paths: Sequence[str], + labels: npt.NDArray[Tuple[Any, Any]], + extra_modalities: Sequence[Mapping[str, Sequence[str]]], - t1_to_mni_matrix_paths: Optional[Sequence[str]] = None, - deformations_are_coefs: bool = True, - ants_mode: bool = False ): self.t1_paths = t1_paths - 
self.t1_deformation_paths = t1_deformation_paths - self.t1_to_mni_matrix_paths = t1_to_mni_matrix_paths + self.t1_coefs_paths = t1_coefs_paths self.extra_modalities = extra_modalities self.labels = labels - self.deformations_are_coefs = deformations_are_coefs - self.ants_mode = ants_mode self.mni_ref_image = get_mni("brain") - - # Determine initial fieldToRefMat - # We assume the same spacing across all samples from here on - if self.deformations_are_coefs: - self.fieldToRefMat = get_field_to_ref_mat( - self.t1_paths[0], self.t1_deformation_paths[0], self.mni_ref_image - ) + self.fieldToRefMat = get_field_to_ref_mat( + self.t1_paths[0], self.t1_coefs_paths[0], self.mni_ref_image + ) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]: - t1_path, t1_deformation_path, label = ( + t1_path, t1_coefs_path, label = ( self.t1_paths[idx], - self.t1_deformation_paths[idx], - self.labels[idx], + self.t1_coefs_paths[idx], + self.labels[idx], ) # Get T1 image t1_image = Image(t1_path) - intent = None if self.deformations_are_coefs else constants.FSL_FNIRT_DISPLACEMENT_FIELD - t1_deformation = readFnirt( - t1_deformation_path, src=t1_image, ref=self.mni_ref_image, intent=intent - ) - - # Verify that correct deformation type is used - assert ( - isinstance(t1_deformation, DeformationField) - and not self.deformations_are_coefs - or isinstance(t1_deformation, CoefficientField) - and self.deformations_are_coefs - ), "Mismatch, please specify the correct type of deformation (deformations_are_coefs)" - - if isinstance(t1_deformation, DeformationField): - #print("field!") - # TODO: somehow check this for ants? 
- if not self.ants_mode and t1_deformation.absolute: - raise ValueError("please only use relative deformation fields without included affines") - assert ( - self.t1_to_mni_matrix_paths is not None - ), "You must supply T1 to MNI matrices when using deformation fields" - - # Get T1 matrix - t1_to_mni_matrix_path = self.t1_to_mni_matrix_paths[idx] - if self.ants_mode: - t1_premat = affine.invert(read_ants_affine(t1_to_mni_matrix_path, src=t1_image, ref=self.mni_ref_image)) - else: - t1_premat = affine.invert(np.loadtxt(t1_to_mni_matrix_path)) - - t1_postmat = registration.get_postmat(t1_deformation) - else: - #print("coefs!") - t1_premat, t1_postmat = registration.get_premat_and_postmat(t1_deformation) - - if not np.allclose(t1_deformation.fieldToRefMat, self.fieldToRefMat): - raise ValueError("found incompatible field2ref matrix") - - t1_image_tensor = torch.from_numpy(t1_image.data) - t1_deformation_tensor = torch.from_numpy(t1_deformation.data) - t1_premat_tensor = torch.from_numpy(t1_premat) - t1_postmat_tensor = torch.from_numpy(t1_postmat) + t1_coefs = readFnirt(t1_coefs_path, src=t1_image, ref=self.mni_ref_image) + # premat includes: going to the right coordinate space + going to linear mni space + # postmat: applied after the non-liner transform; used for go back to the original coordinate system + t1_premat, t1_postmat = registration.get_premat_and_postmat(t1_coefs) - if self.ants_mode: - # We need to squeeze flip the y-axis for ants - t1_deformation_tensor = t1_deformation_tensor.squeeze(3) - t1_deformation_tensor[..., 1] = -t1_deformation_tensor[..., 1] + t1_image_tensor = torch.FloatTensor(t1_image.data) + t1_coefs_tensor = torch.FloatTensor(t1_coefs.data) + t1_premat_tensor = torch.FloatTensor(t1_premat) + t1_postmat_tensor = torch.FloatTensor(t1_postmat) # Add to modality list tuple_list = [] tuple_list.append( - ( - t1_image_tensor, - t1_deformation_tensor, - t1_premat_tensor, - t1_postmat_tensor, - ) + (t1_image_tensor, t1_coefs_tensor, 
t1_premat_tensor, t1_postmat_tensor) ) # Iterate extra modalities @@ -243,7 +195,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]: # Load image modality_image_path = modality_image_list[idx] modality_image = Image(modality_image_path) - modality_image_tensor = torch.from_numpy(modality_image.data) + modality_image_tensor = torch.FloatTensor(modality_image.data) # Load premat if available # Premats are assumed to map from the modality's space to T1 @@ -253,14 +205,10 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]: # Premats map from FSL to FSL space modality_premat_path = modality_premat_list[idx] - if self.ants_mode: - #modality_premat = torch.from_numpy(affine.invert(read_ants_affine(modality_premat_path))).float() - raise NotImplementedError - else: - with open(modality_premat_path, "r") as f: - modality_premat = torch.inverse( - torch.from_numpy(np.loadtxt(f)) - ).float() + with open(modality_premat_path, "r") as f: + modality_premat = torch.inverse( + torch.from_numpy(np.loadtxt(f)) + ).float() # Convert from voxel space to FSL and back # Note: out final target is MNI, thus vox2fsl uses MNI @@ -274,140 +222,92 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]: ).float() # Get raw FSL -> FSL matrix - t1_raw_mat = torch.from_numpy(t1_deformation.refToSrcMat).float() + t1_raw_mat = torch.from_numpy(t1_coefs.refToSrcMat).float() modality_premat_tensor = torch.matmul(t1_raw_mat, vox2fsl) modality_premat_tensor = torch.matmul( modality_premat, modality_premat_tensor ) - tuple_list.append( - ( - modality_image_tensor, - t1_deformation_tensor, - modality_premat_tensor, - fsl2vox, - ) - ) + tuple_list.append((modality_image_tensor, t1_coefs_tensor, modality_premat_tensor, fsl2vox)) # type: ignore else: tuple_list.append( ( modality_image_tensor, - t1_deformation_tensor, + t1_coefs_tensor, t1_premat_tensor, t1_postmat_tensor, ) ) + if not np.allclose(t1_coefs.fieldToRefMat, self.fieldToRefMat): + raise ValueError("found 
incompatible field2ref matrix") + # print('warp dataset:',idx, src_image.data.shape, prewarp_coefs.data.shape,prewarp_premat.shape,prewarp_postmat.shape) return tuple(sum(tuple_list, ()) + (label,)) - def get_ctrl_spacing(self) -> Optional[npt.NDArray[Any]]: - if self.deformations_are_coefs: - return np.diag(self.fieldToRefMat)[:3] - else: - return None + def get_ctrl_spacing(self) -> npt.NDArray[Any]: + return np.diag(self.fieldToRefMat)[:3] def __len__(self) -> int: return len(self.t1_paths) class MultiModalInvWarpDataset(Dataset[Tuple[torch.Tensor, ...]]): - """Supports two modes of operation: - Either supply coefficient fields of inverse warp fields in relative coordinates (MNI -> MNI). - This does not require T1 paths. - Alternatively, supply T1 -> MNI coefficient fields that are then inverted. - These require the T1 paths to load the coefficient field correctly. - - Third option: - Load absolute coordinate warp fields along with T1 -> MNI matrices - """ - def __init__( self, - t1_deformation_paths: List[str], + t1_coefs_paths: List[str], extra_modalities: List[Dict[str, List[str]]], + labels: npt.NDArray[Tuple[Any, Any]], t1_paths: Optional[List[str]] = None, - t1_to_mni_matrix_paths: Optional[Sequence[str]] = None, - deformations_are_coefs: bool = True, - ants_mode: bool = False ): - self.t1_deformation_paths = t1_deformation_paths + self.t1_coefs_paths = t1_coefs_paths self.extra_modalities = extra_modalities self.t1_paths = t1_paths - self.t1_to_mni_matrix_paths = t1_to_mni_matrix_paths - self.deformations_are_coefs = deformations_are_coefs - self.ants_mode = ants_mode self.mni_ref_image = get_mni("brain") + self.labels = labels - if self.deformations_are_coefs: - if self.t1_paths is not None: - self.fieldToRefMat = get_field_to_ref_mat( - self.t1_paths[0], self.t1_deformation_paths[0], self.mni_ref_image - ) - else: - self.fieldToRefMat = get_field_to_ref_mat( - self.mni_ref_image, self.t1_deformation_paths[0], self.mni_ref_image - ) + if self.t1_paths is 
not None: + self.fieldToRefMat = get_field_to_ref_mat( + self.t1_paths[0], self.t1_coefs_paths[0], self.mni_ref_image + ) else: - if t1_paths is None: - raise ValueError("Please supply T1 paths for full warp fields") + self.fieldToRefMat = get_field_to_ref_mat( + self.mni_ref_image, self.t1_coefs_paths[0], self.mni_ref_image + ) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]: # Construct T1 gold standard - intent = None if self.deformations_are_coefs else constants.FSL_FNIRT_DISPLACEMENT_FIELD if self.t1_paths is not None: - src = Image(self.t1_paths[idx]) - else: - src = self.mni_ref_image - - t1_prewarp_deformation = readFnirt( - self.t1_deformation_paths[idx], - src=src, - ref=self.mni_ref_image, - intent=intent - ) - - if self.deformations_are_coefs: - t1_prewarp_premat, _ = registration.get_premat_and_postmat(t1_prewarp_deformation) - - if not np.allclose(t1_prewarp_deformation.fieldToRefMat, self.fieldToRefMat): - raise ValueError("found incompatible field2ref matrix") + t1_prewarp_coefs = readFnirt( + self.t1_coefs_paths[idx], + src=Image(self.t1_paths[idx]), + ref=self.mni_ref_image, + ) else: - assert ( - self.t1_to_mni_matrix_paths is not None - ), "You must supply T1 to MNI matrices when using deformation fields" - - t1_to_mni_matrix_path = self.t1_to_mni_matrix_paths[idx] - - if self.ants_mode: - t1_prewarp_premat = affine.invert(read_ants_affine(t1_to_mni_matrix_path, src=src, ref=self.mni_ref_image)) - else: - t1_prewarp_premat = affine.invert(np.loadtxt(t1_to_mni_matrix_path)) - - t1_prewarp_deformation_tensor = torch.from_numpy(t1_prewarp_deformation.data) - t1_prewarp_premat_tensor = torch.from_numpy(t1_prewarp_premat) - - if self.ants_mode: - # We need to squeeze flip the y-axis for ants - t1_prewarp_deformation_tensor = t1_prewarp_deformation_tensor.squeeze(3) - t1_prewarp_deformation_tensor[..., 1] = -t1_prewarp_deformation_tensor[..., 1] + t1_prewarp_coefs = readFnirt( + self.t1_coefs_paths[idx], src=self.mni_ref_image, 
ref=self.mni_ref_image + ) + t1_prewarp_premat, _ = registration.get_premat_and_postmat(t1_prewarp_coefs) + t1_prewarp_coefs_tensor = torch.FloatTensor(t1_prewarp_coefs.data) + t1_prewarp_premat_tensor = torch.FloatTensor(t1_prewarp_premat) + label = self.labels[idx] # Add T1 to list tuple_list = [] - tuple_list.append((t1_prewarp_deformation_tensor, t1_prewarp_premat_tensor)) + tuple_list.append((t1_prewarp_coefs_tensor, t1_prewarp_premat_tensor)) # Repeat for remaining modalities for modality in self.extra_modalities: - tuple_list.append((t1_prewarp_deformation_tensor, t1_prewarp_premat_tensor)) + tuple_list.append((t1_prewarp_coefs_tensor, t1_prewarp_premat_tensor)) - return tuple(sum(tuple_list, ())) + if not np.allclose(t1_prewarp_coefs.fieldToRefMat, self.fieldToRefMat): + raise ValueError("found incompatible field2ref matrix") - def get_ctrl_spacing(self) -> Optional[npt.NDArray[Any]]: - if self.deformations_are_coefs: - return np.diag(self.fieldToRefMat)[:3] - else: - return None + return tuple(sum(tuple_list, ())) + (label,) + + def get_ctrl_spacing(self) -> npt.NDArray[Any]: + return np.diag(self.fieldToRefMat)[:3] def __len__(self) -> int: - return len(self.t1_deformation_paths) + return len(self.t1_coefs_paths) \ No newline at end of file diff --git a/brain_deform/lightning.py b/brain_deform/lightning.py index 6455987..86bde19 100644 --- a/brain_deform/lightning.py +++ b/brain_deform/lightning.py @@ -1,9 +1,12 @@ import json from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple, Union - import pandas as pd -import torch import pytorch_lightning as pl +import torch + +import numpy as np +import random + from pytorch_lightning.trainer.supporters import CombinedLoader from torch.utils.data import DataLoader, Dataset @@ -16,6 +19,112 @@ multi_modal_inv_collate_fn, ) +def combined_collate_fn(args): + main, aux, aug = zip(*args) + #main, aux, aug_main, aug_aux = zip(*args) + + main = multi_modal_collate_fn(main) + aux = 
multi_modal_collate_fn(aux) + aug = multi_modal_inv_collate_fn(aug) + #aug_aux = multi_modal_inv_collate_fn(aug_aux) + + return main, aux, aug + #return main, aux, aug_main, aug_aux + +class CombinedDataset(Dataset): + def __init__(self, ds_main, ds_aux, ds_aug, condition=None): + self.ds_main = ds_main + self.ds_aux = ds_aux + self.ds_aug = ds_aug + self.condition = condition + + def __len__(self): + return len(self.ds_main) + + # This function performs conditional sampling, and if no candidates are, n-nearest neighbours of the same sex are sampled + # Assuming labels and __getitem__ have the same order. + def sample_conditionally(self, idx, dataset): + + all_file_paths = dataset.t1_paths # Unique paths to files + current_file_path = all_file_paths[idx] # Current path + + all_indices = np.arange(len(all_file_paths)) # Get all possible indices + other_indices = [index for index in all_indices if index != idx] # Get indices excluding the current one + + all_labels = dataset.labels[:, -1] # Get all labels in the dataset + current_label = all_labels[idx] # Current label + + all_sex_labels = dataset.labels[:, -2] # Get all sex labels in the dataset + current_sex_label = all_sex_labels[idx] # Get the current sex label + + other_sex_labels = all_sex_labels[other_indices] # Get sex labels excluding the current one + + matching_sex_labels = all_labels[other_indices][other_sex_labels == current_sex_label] # Get labels of same sex excluding the current one + matching_sex_indices = all_indices[other_indices][other_sex_labels == current_sex_label] # Get indices of same sex excluding the current one + + # Apply the condition to the labels of the same sex. 
+ condition_mask = self.condition(current_label, matching_sex_labels) + matching_condition_indices = np.where(condition_mask)[0] # Get indices that satisfy the condition for augmentation + + num_matching_indices = len(matching_condition_indices) + + other_index_to_all_index_map = {} # Dictionary that maps from other_indices to indices + j = 0 + for i in range(len(all_indices)): + if j >= len(other_indices) or all_indices[i] != other_indices[j]: + continue + other_index_to_all_index_map[j] = i + j += 1 + + matching_index_to_all_index_map = {} # Dictionary that maps from matching_sex_indices to all_indices + j = 0 + for i in range(len(all_indices)): + if j >= len(matching_sex_indices) or all_indices[i] != matching_sex_indices[j]: + continue + matching_index_to_all_index_map[j] = i + j += 1 + + + if num_matching_indices == 0: # If no candidates meet the criterion + num_nearest_neighbours = 3 # Number of nearest neighbours to consider + + label_differences = np.abs(np.array(matching_sex_labels) - current_label) # Measure the dissimilarity between the labels + nearest_neighbours_indices = np.argsort(label_differences)[:num_nearest_neighbours] # Get indices of labels closest to 'current_label' + random_nearest_neighbour_index = np.random.choice(nearest_neighbours_indices) # Randomly select one of the nearest neighbours indices + + #nearest_neighbour_labels = np.array(matching_sex_labels)[nearest_neighbours_indices] + + return matching_index_to_all_index_map[random_nearest_neighbour_index] + + else: + random_matching_condition_index = np.random.choice(matching_condition_indices) + + return matching_index_to_all_index_map[random_matching_condition_index] + + + + def __getitem__(self, idx): + + aux_idx = random.randint(0, len(self.ds_aux)-1) ### Should be random in both cases? 
+ sample_main = self.ds_main[idx] + sample_aux = self.ds_main[aux_idx] + + if self.condition is not None: + + main_aug_idx = self.sample_conditionally(idx, self.ds_aug) + aux_aug_idx = self.sample_conditionally(aux_idx, self.ds_aux) + + sample_main_aug = self.ds_aug[main_aug_idx] + sample_aux_aug = self.ds_main[aux_aug_idx] + #return sample_main, sample_aux_aug, sample_main_aug #sample_main_aug, sample_aux_aug + + else: + sample_main_aug = self.ds_aug[idx] + sample_aux_aug = self.ds_aux[aux_idx] + #return sample_main, sample_aux, sample_aug + + #return sample_main, sample_aux, sample_main_aug, sample_aux_aug + return sample_main, sample_aux_aug, sample_main_aug class BrainDataModule(pl.LightningDataModule): def __init__( @@ -23,13 +132,11 @@ def __init__( data_table_path: str, split_path: str, index_column: str = "index", - t1_column: str = "t1", - deformation_to_mni_column: str = "deformation_to_mni", - deformation_from_mni_column: Optional[str] = None, - deformation_matrix_column: Optional[str] = None, - deformations_are_coefs: bool = True, + t1_column: str = "raw", + coefs_to_mni_column: str = "coefs_to_mni", + coefs_from_mni_column: Optional[str] = None, extra_modalities: List[Mapping[str, str]] = [], - target_column: str = "target", + target_columns: List[str] = ["sex", "target"], invert_on_the_fly: bool = True, interpolation: Literal["linear", "nearest"] = "linear", batch_size: int = 8, @@ -46,7 +153,8 @@ def __init__( random_warp: Union[float, int] = 0, cross_intensity: Union[float, int] = 0, augmentation_probability: Union[float, int] = 1, - ants_mode: bool = False + conditional: bool = False, + conditional_value: int = 10 ): """pytorch-lightning style DataModule to easily augment brain images on the gpu. @@ -59,13 +167,13 @@ def __init__( For each batch, the module returns `(batch_main, batch_aux)`, with each batch consisting of `(registered_image, augmented_registered_image, label)`. 
If `auxiliary_image=False`, `batch_aux` is simply `(None, None, None)` Args: - data_table_path (str): csv file that tells the module where the images are. table minimally requires the following columns: index,t1,deformation_to_mni,target + data_table_path (str): csv file that tells the module where the images are. table minimally requires the following columns: index,t1,coefs_to_mni,target split_path (str): yaml file with split into train, val, test, augmentation, auxiliary sets. augmentation->warp targets, auxiliary->unlabeled images for semisupervised learning index_column (str, optional): in case your data_table columns are named differently. Defaults to "index". t1_column (str, optional): in case your data_table columns are named differently. Defaults to "t1". - deformation_to_mni_column (str, optional): in case your data_table columns are named differently. Defaults to "deformation_to_mni". - deformation_from_mni_column (Optional[str], optional): in case you want to use precomputed MNI-to-T1-linear warps, specify the data_table column name here. you probably don't need this. Defaults to None. - extra_modalities (List[Mapping[str, str]], optional): specify column names for additional modalities here. e.g., [{'image_column': 'flair', {image_column: 'swi', premat_column: 'swi_to_T1'}]. Defaults to []. + coefs_to_mni_column (str, optional): in case your data_table columns are named differently. Defaults to "coefs_to_mni". + coefs_from_mni_column (Optional[str], optional): in case you want to use precomputed MNI-to-T1-linear warps, specify the data_table column name here. you probably don't need this. Defaults to None. + extra_modalities (List[Mapping[str, str]], optional): specify column names for additional modalities here. e.g., [{'image_column': 'flair'}, {image_column: 'swi', premat_column: 'swi_to_T1'}]. Defaults to []. target_column (str, optional): in case your data_table columns are named differently. Defaults to "target". 
invert_on_the_fly (bool, optional): invert warp fields on the fly, or use precomputed files. you probably want to keep the default. Defaults to True. interpolation (Literal[ "linear", "nearest" ], optional): interpolation method to use for remapping coordinates. linear or nearest neighbour. spline not yet implemented. Defaults to "linear". @@ -79,14 +187,13 @@ def __init__( cross_warp (Union[float, int], optional): max factor for subject-to-subject warping images to targets in augmentation set. Defaults to 0. random_warp (Union[float, int], optional): max factor for random warping. Defaults to 0. cross_intensity (Union[float, int], optional): max factor for subject-to-subject transforming intensity to targets in augmentation set. Defaults to 0. - ants_mode: run in ANTs compatiblity mode """ super().__init__() assert not (cross_warp and random_warp) assert not (cross_warp and random_warp) - if (cross_warp or random_warp) and registration == "nonlinear": - raise ValueError("Cross warps and random warps can only be applied to linearly registered images") + if cross_warp or random_warp: + assert registration != "nonlinear" self.save_hyperparameters() @@ -109,22 +216,18 @@ def setup(self, stage: Optional[str] = None) -> None: # self.target_shape = datasets.get_mni("brain").shape self.hparams.target_shape = datasets.get_mni("brain").shape - # TODO: deformations_are_coefs could be unique for both the main dataset and the invwarp dataset - if self.hparams.deformations_are_coefs: - # Get spacing for both main and augmentation - # Assert both have the same resolution - df_ = self.df.loc[self.idx["train"]] - self.main_ctrl_spacing = self.get_main_ds(df_).get_ctrl_spacing() # type: ignore - self.augmentation_ctrl_spacing = self.get_augmentation_ds( - df_ - ).get_ctrl_spacing() # type: ignore - if (self.main_ctrl_spacing != self.augmentation_ctrl_spacing).any(): - raise ValueError("main and aug datasets must have the same resolution") - - # Spacing is set to main by default - 
self.hparams.spacing = self.main_ctrl_spacing - else: - self.hparams.spacing = None + # Get spacing for both main and augmentation + # Assert both have the same resolution + df_ = self.df.loc[self.idx["train"]] + self.main_ctrl_spacing = self.get_main_ds(df_).get_ctrl_spacing() + self.augmentation_ctrl_spacing = self.get_augmentation_ds( + df_ + ).get_ctrl_spacing() + if (self.main_ctrl_spacing != self.augmentation_ctrl_spacing).any(): + raise ValueError("main and aug datasets must have the same resolution") + + # Spacing is set to main by default + self.hparams.spacing = self.main_ctrl_spacing def train_dataloader(self) -> CombinedLoader: # type: ignore return self.create_dataloader(self.idx["train"], shuffle=True) @@ -148,14 +251,10 @@ def get_main_ds(self, df: pd.DataFrame) -> Dataset[Tuple[torch.Tensor, ...]]: return MultiModalWarpDataset( t1_paths=df[self.hparams.t1_column].values, - t1_deformation_paths=df[self.hparams.deformation_to_mni_column].values, - labels=df[self.hparams.target_column].values, + t1_coefs_paths=df[self.hparams.coefs_to_mni_column].values, + labels=df[self.hparams.target_columns].values, extra_modalities=extra_modalities, - t1_to_mni_matrix_paths=df[self.hparams.deformation_matrix_column].values - if self.hparams.deformation_matrix_column is not None - else None, - deformations_are_coefs=self.hparams.deformations_are_coefs, - ants_mode=self.hparams.ants_mode + ) def get_main_dataloader( @@ -187,24 +286,16 @@ def get_augmentation_ds( if self.hparams.invert_on_the_fly: # T1 -> coef -> MNI return MultiModalInvWarpDataset( - t1_deformation_paths=df[self.hparams.deformation_to_mni_column].values, + t1_coefs_paths=df[self.hparams.coefs_to_mni_column].values, extra_modalities=extra_modalities, t1_paths=df[self.hparams.t1_column].values, - t1_to_mni_matrix_paths=df[self.hparams.deformation_matrix_column].values - if self.hparams.deformation_matrix_column is not None - else None, - deformations_are_coefs=self.hparams.deformations_are_coefs, - 
ants_mode=self.hparams.ants_mode + labels=df[self.hparams.target_columns].values, ) else: return MultiModalInvWarpDataset( - t1_deformation_paths=df[self.hparams.deformation_from_mni_column].values, + t1_coefs_paths=df[self.hparams.coefs_from_mni_column].values, extra_modalities=extra_modalities, - t1_to_mni_matrix_paths=df[self.hparams.deformation_matrix_column].values - if self.hparams.deformation_matrix_column is not None - else None, - deformations_are_coefs=self.hparams.deformations_are_coefs, - ants_mode=self.hparams.ants_mode + labels=df[self.hparams.target_columns].values, ) def get_augmentation_dataloader( @@ -224,67 +315,122 @@ def get_augmentation_dataloader( ) def create_dataloader(self, idx: Sequence[int], shuffle: bool) -> CombinedLoader: - dl_main = self.get_main_dataloader(self.df.loc[idx], shuffle) + if self.hparams.conditional is True: + ds_main = self.get_main_ds(self.df.loc[idx]) + + df_ = ( + self.df.loc[self.idx["auxiliary"]] + if self.idx["auxiliary"] is not False + else self.df + ) + ds_aux = self.get_main_ds(df_) - df_ = ( - self.df.loc[self.idx["auxiliary"]] - if self.idx["auxiliary"] is not False - else self.df - ) - dl_aux = self.get_main_dataloader(df_, shuffle) + df_ = ( + self.df.loc[self.idx["augmentation"]] + if self.idx["augmentation"] is not False + else self.df + + ) + ds_aug = self.get_augmentation_ds(df_) + + ### PLACE THE CONDITION FUNCTION HERE ### + threshold_condition = lambda label,labels: np.abs(labels - label) <= self.hparams.conditional_value + + + ds_combined = CombinedDataset(ds_main, ds_aux, ds_aug, threshold_condition) + + return DataLoader( + ds_combined, + collate_fn=combined_collate_fn, + shuffle=shuffle, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + persistent_workers=True, + pin_memory=self.hparams.pin_memory, + drop_last=True, + ) + else: + ds_main = self.get_main_ds(self.df.loc[idx]) + + df_ = ( + self.df.loc[self.idx["auxiliary"]] + if self.idx["auxiliary"] is not False + 
else self.df + ) + ds_aux = self.get_main_ds(df_) + + df_ = ( + self.df.loc[self.idx["augmentation"]] + if self.idx["augmentation"] is not False + else self.df + ) + ds_aug = self.get_augmentation_ds(df_) + + # In pytorch 2.0 CombinedLoader only supports "sequential" mode + ds_combined = CombinedDataset(ds_main, ds_aux, ds_aug) + + return DataLoader( + ds_combined, + collate_fn=combined_collate_fn, + shuffle=shuffle, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + persistent_workers=True, + pin_memory=self.hparams.pin_memory, + drop_last=True, + ) - df_ = ( - self.df.loc[self.idx["augmentation"]] - if self.idx["augmentation"] is not False - else self.df - ) - dl_augment = self.get_augmentation_dataloader(df_, shuffle) - return CombinedLoader([dl_main, dl_aux, dl_augment]) def on_after_batch_transfer( self, batch: Tuple[torch.Tensor, ...], dataloader_idx: int - ) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor], - ]: + ) -> Tuple[torch.Tensor, ...]: + #data_main, data_aux, data_main_aug , data_aux_aug = batch data_main, data_aux, data_aug = batch - img_main, deformation_main, premat_main, postmat_main, label_main = data_main - img_aux, deformation_aux, premat_aux, postmat_aux, label_aux = data_aux - deformation_aug, premat_aug = data_aug + + img_main, coefs_main, premat_main, postmat_main, label_main = data_main + img_aux, coefs_aux, premat_aux, postmat_aux, label_aux = data_aux + coefs_aug, premat_aug, label_aug = data_aug + #coefs_main_aug, premat_main_aug, label_main_aug = data_main_aug + #coefs_aux_aug, premat_aux_aug, label_aux_aug = data_aux_aug img_main_registered, img_main_registered_augmented = process_image( img_main, - deformation_main, + coefs_main, premat_main, postmat_main, img_aux, - deformation_aux, + coefs_aux, premat_aux, postmat_aux, - deformation_aug, + coefs_aug, + #coefs_main_aug, # premat_aug, + #premat_main_aug, # self.hparams, ) 
if self.hparams.auxiliary_image: img_aux_registered, img_aux_registered_augmented = process_image( img_aux, - deformation_aux, + coefs_aux, premat_aux, postmat_aux, img_main, - deformation_main, + coefs_main, premat_main, postmat_main, - deformation_aug, + #coefs_aux_aug, # + coefs_aug, premat_aug, + #premat_aux_aug, # self.hparams, ) else: img_aux_registered, img_aux_registered_augmented = None, None - return (img_main_registered, img_main_registered_augmented, label_main), ( + return (img_main_registered, img_main_registered_augmented, label_main, label_aug), ( img_aux_registered, img_aux_registered_augmented, label_aux, + label_aug )