diff --git a/.github/workflows/test-action.yml b/.github/workflows/test-action.yml new file mode 100644 index 0000000..3952eaa --- /dev/null +++ b/.github/workflows/test-action.yml @@ -0,0 +1,28 @@ +name: test + +on: + pull_request: + push: + branches: + - master + +jobs: + test-linux: + runs-on: ubuntu-latest + + strategy: + max-parallel: 5 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + - name: Install dependencies + run: | + pip install -e ".[testing]" + + - name: test + run: | + python -m unittest discover diff --git a/setup.cfg b/setup.cfg index c14eff5..778008c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,6 +21,10 @@ python_requires = >=3.4 install_requires = torch itk==5.3rc4 + +[options.extras_require] +testing = + monai [options.packages.find] where = src diff --git a/src/itk_torch_transform_bridge.py b/src/itk_torch_transform_bridge.py index fe1ad6a..c4bdd05 100644 --- a/src/itk_torch_transform_bridge.py +++ b/src/itk_torch_transform_bridge.py @@ -1,5 +1,6 @@ import itk import torch +import torch.nn.functional as F def monai_warp_to_itk_transform(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", network: "torch.nn.Module", **kwargs)->"itk.Transform": tensor_fixed, tensor_moving, convert_back = itk_transform_bridge(image_fixed, image_moving, network_shape, phi_type="displacement_field", order="vector_first", **kwargs) @@ -13,18 +14,66 @@ def grid_sample_to_itk_transform(image_fixed: "itk.Image", image_moving:"itk.Ima return convert_back(phi) -def itk_transform_bridge(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", **kwargs)->"(torch.Tensor, torch.Tensor, Callable[[torch.Tensor], itk.Transform])": - # Convert images to tensors +def itk_transform_bridge(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", phi_type="displacement_field", range=(-1, 1))->"(torch.Tensor, torch.Tensor, Callable[[torch.Tensor], itk.Transform])": + to_network_space = resampling_transform(image_moving, network_shape) + from_network_space = resampling_transform(image_fixed, network_shape).GetInverse() + + moving_npy = np.array(image_moving) + fixed_npy = np.array(image_fixed) + + # turn images into torch Tensors: add feature and batch dimensions (each of length 1) + moving_trch = torch.Tensor(moving_npy)[None, None] + fixed_trch = torch.Tensor(fixed_npy)[None, None] - # ... + + # Here we resize the input images to the shape expected by the neural network. This affects the + # pixel stride as well as the magnitude of the displacement vectors of the resulting displacement field, which + # convert_back will have to compensate for. + + #TODO: it is crucial to blur before this step if we are downsampling! + moving_resized = F.interpolate(moving_trch, size=network_shape, mode="trilinear", align_corners=False) + fixed_resized = F.interpolate(fixed_trch, size=network_shape, mode="trilinear", align_corners=False) + # Create convert_back function def convert_back(phi: "torch.Tensor") -> "itk.Transform": + phi = phi.cpu().detach() + + if phi_type == "coordinate_field" and range == (-1, 1): + # itk.DeformationFieldTransform expects a displacement field, so we subtract off the identity map. + disp = (phi - ) + + dimension = len(network_shape_list) + + + # We convert the displacement field into an itk Vector Image. + scale = torch.Tensor(network_shape_list) + + for _ in network_shape_list: + scale = scale[:, None] + disp *= scale + + # disp is a shape [3, H, W, D] tensor with vector components in the order [vi, vj, vk] + disp_itk_format = disp.double().numpy()[list(reversed(range(dimension)))].transpose(list(range(1, dimension + 1)) + [0]) + # disp_itk_format is a shape [H, W, D, 3] array with vector components in the order [vk, vj, vi] + # as expected by itk. + + itk_disp_field = itk.image_from_array(disp_itk_format, is_vector=True) + + deformable_transform = itk.DisplacementFieldTransform[(itk.D, dimension)].New() + + deformable_transform.SetDisplacementField(itk_disp_field) + + final_transform = itk.CompositeTransform[itk.D, dimension].New() + + final_transform.PrependTransform(from_network_space) + final_transform.PrependTransform(deformable_transform) + final_transform.PrependTransform(to_network_space) - return itk.CompositeTransform(some_stuff) + return final_transform - return tensor_fixed, tensor_moving, convert_back + return fixed_resized, moving_resized, convert_back def resampling_transform(image, shape) -> itk.Transform: diff --git a/test/test_grid_sample.py b/test/test_grid_sample.py new file mode 100644 index 0000000..530117d --- /dev/null +++ b/test/test_grid_sample.py @@ -0,0 +1,4 @@ +import itk_torch_transform +import torch +import itk +import torch.nn.functional as F diff --git a/test/test_monai.py b/test/test_monai.py new file mode 100644 index 0000000..e389d78 --- /dev/null +++ b/test/test_monai.py @@ -0,0 +1,6 @@ +import torch +import monai +import unittest + +class TestMonaiWarp(unittest.TestCase): +