diff --git a/src/bridge_utils.py b/src/bridge_utils.py new file mode 100644 index 0000000..55f67a1 --- /dev/null +++ b/src/bridge_utils.py @@ -0,0 +1,51 @@ +import itk +import torch +import numpy as np +from monai.data import ITKReader +from monai.data.meta_tensor import MetaTensor +from monai.transforms import EnsureChannelFirst +from monai.utils import convert_to_dst_type + + +def metatensor_to_array(metatensor): + metatensor = metatensor.squeeze() + metatensor = metatensor.permute(*torch.arange(metatensor.ndim - 1, -1, -1)) + + return metatensor.get_array() + + +def image_to_metatensor(image): + """ + Converts an ITK image to a MetaTensor object. + + Args: + image: The ITK image to be converted. + + Returns: + A MetaTensor object containing the array data and metadata. + """ + reader = ITKReader(affine_lps_to_ras=False) + image_array, meta_data = reader.get_data(image) + image_array = convert_to_dst_type(image_array, dst=image_array, dtype=itk.D)[0] + metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data) + metatensor = EnsureChannelFirst()(metatensor) + + return metatensor + + + +def remove_border(image): + """ + MONAI seems to have different behavior in the borders of the image than ITK. + This helper function sets the border of the ITK image as 0 (padding but keeping + the same image size) in order to allow numerical comparison between the + result from resampling with ITK/Elastix and resampling with MONAI. + To use: image[:] = remove_border(image) + Args: + image: The ITK image to be padded. + + Returns: + The padded array of data. + """ + return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim==3 else image[1:-1, 1:-1], + pad_width=1) diff --git a/src/itk_torch_ddf_bridge.py b/src/itk_torch_ddf_bridge.py new file mode 100644 index 0000000..298ba35 --- /dev/null +++ b/src/itk_torch_ddf_bridge.py @@ -0,0 +1,82 @@ +import itk +import torch +import monai +import numpy as np +import matplotlib.pyplot as plt + +def monai_to_itk_ddf(image, ddf): + """ + converting the dense displacement field from the MONAI space to the ITK + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + displacement_field: itk image of the corresponding displacement field + + """ + # 3, D, H, W -> D, H, W, 3 + ndim = image.ndim + ddf = ddf.transpose(tuple(list(range(1, ndim+1)) + [0])) + # x, y, z -> z, x, y + ddf = ddf[..., ::-1] + + # Correct for spacing + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + ddf *= np.array(spacing, ndmin=ndim+1) + + # Correct for direction + direction = np.asarray(image.GetDirection(), dtype=np.float64) + ddf = np.einsum('ij,...j->...i', direction, ddf, dtype=np.float64).astype(np.float32) + + # initialise displacement field - + vector_component_type = itk.F + vector_pixel_type = itk.Vector[vector_component_type, ndim] + displacement_field_type = itk.Image[vector_pixel_type, ndim] + displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type) + + # Set image metadata + displacement_field.SetSpacing(image.GetSpacing()) + displacement_field.SetOrigin(image.GetOrigin()) + displacement_field.SetDirection(image.GetDirection()) + + return displacement_field + + +def itk_warp(image, ddf): + """ + warping with python itk + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + # MONAI->ITK ddf + displacement_field = monai_to_itk_ddf(image, ddf) + + # Resample using the ddf + interpolator = itk.LinearInterpolateImageFunction.New(image) + warped_image = itk.warp_image_filter(image, + interpolator=interpolator, + displacement_field=displacement_field, + output_parameters_from_image=image) + + return np.asarray(warped_image) + + +def monai_wrap(image_tensor, ddf_tensor): + """ + warping with MONAI + Args: + image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W) + ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + warp = monai.networks.blocks.Warp(mode='bilinear', padding_mode="zeros") + warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64)) + + return warped_image.to(torch.float32).squeeze().numpy() + + + diff --git a/src/run_tests.py b/src/run_tests.py new file mode 100644 index 0000000..ad09885 --- /dev/null +++ b/src/run_tests.py @@ -0,0 +1,17 @@ +from test_cases import * +import test_utils + +test_utils.download_test_data() + +filepath_2D = str(test_utils.TEST_DATA_DIR / 'CT_2D_head_fixed.mha') +filepath_3D = str(test_utils.TEST_DATA_DIR / 'copd1_highres_INSP_STD_COPD_img.nii.gz') + +# 2D cases +test_random_array(ndim=2) +test_real_data(filepath=filepath_2D) + +# 3D cases +test_random_array(ndim=3) +test_real_data(filepath=filepath_3D) + + diff --git a/src/test_cases.py b/src/test_cases.py new file mode 100644 index 0000000..41b23d2 --- /dev/null +++ b/src/test_cases.py @@ -0,0 +1,66 @@ +from itk_torch_ddf_bridge import * +from bridge_utils import remove_border + +def test_random_array(ndim): + print("\nTest: Random array with random spacing, direction and origin, ndim={}".format(ndim)) + + # Create image/array with random size and pixel intensities + s = torch.randint(low=2, high=20, size=(ndim,)) + img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32) + + # Pad at the edges because ITK and MONAI have different behavior there + # during resampling + img = torch.nn.functional.pad(img, pad=ndim*(1, 1)) + ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5 + + # Warp with MONAI + img_resampled = monai_wrap(img, ddf) + + # Create ITK image + itk_img = itk.GetImageFromArray(img.squeeze().numpy()) + + # Set random spacing + spacing = 3 * np.random.rand(ndim) + itk_img.SetSpacing(spacing) + + # Set random direction + direction = 5 * np.random.rand(ndim, ndim) - 5 + direction = itk.matrix_from_array(direction) + itk_img.SetDirection(direction) + + # Set random origin + origin = 100 * np.random.rand(ndim) - 100 + itk_img.SetOrigin(origin) + + # Warp with ITK + itk_img_resampled = itk_warp(itk_img, ddf.squeeze().numpy()) + + # Compare + print("All close: ", np.allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3)) + diff = img_resampled - itk_img_resampled + print(diff.min(), diff.max()) + + +def test_real_data(filepath): + print("\nTEST: Real data with random deformation field") + # Read image + image = itk.imread(filepath, itk.F) + image[:] = remove_border(image) + ndim = image.ndim + + # Random ddf + ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10 + + # Warp with MONAI + image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0) + img_resampled = monai_wrap(image_tensor, ddf) + + # Warp with ITK + itk_img_resampled = itk_warp(image, ddf.squeeze().numpy()) + + # Compare + print("All close: ", np.allclose(img_resampled, itk_img_resampled)) + diff = img_resampled - itk_img_resampled + print(diff.min(), diff.max()) + + diff --git a/src/test_utils.py b/src/test_utils.py new file mode 100644 index 0000000..c69df16 --- /dev/null +++ b/src/test_utils.py @@ -0,0 +1,18 @@ +import pathlib +import subprocess +# import sys + +TEST_DATA_DIR = pathlib.Path(__file__).parent.parent / "test_files" + +def download_test_data(): + subprocess.run( + [ + "girder-client", + "--api-url", + "https://data.kitware.com/api/v1", + "localsync", + "62a0efe5bddec9d0c4175c1f", + str(TEST_DATA_DIR), + ], + #stdout=sys.stdout, + ) diff --git a/test_files/CT_2D_head_fixed.mha b/test_files/CT_2D_head_fixed.mha new file mode 100644 index 0000000..4f90b93 Binary files /dev/null and b/test_files/CT_2D_head_fixed.mha differ diff --git a/test_files/CT_2D_head_moving.mha b/test_files/CT_2D_head_moving.mha new file mode 100644 index 0000000..b04fc78 Binary files /dev/null and b/test_files/CT_2D_head_moving.mha differ