diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9f1429d477..680e6b806b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -952,14 +952,15 @@ def create_shear( Args: spatial_dims: spatial rank - coefs: shearing factors, a tuple of 2 floats for 2D, a tuple of 6 floats for 3D), - take a 3D affine as example:: + coefs: shearing factors, a tuple of 2 floats for 2D, a tuple of 6 floats for 3D). + Individual single-axis shear matrices are composed (multiplied) in + coefficient order so that the result is a proper shear with determinant 1. + For 2D with coefs ``(Sx, Sy)`` the composed matrix is:: [ - [1.0, coefs[0], coefs[1], 0.0], - [coefs[2], 1.0, coefs[3], 0.0], - [coefs[4], coefs[5], 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], + [1.0, Sx, 0.0], + [Sy, 1.0 + Sx*Sy, 0.0], + [0.0, 0.0, 1.0], ] device: device to compute and store the output (when the backend is "torch"). @@ -982,17 +983,30 @@ def create_shear( def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np.eye) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) - out = eye_func(3) - out[0, 1], out[1, 0] = coefs[0], coefs[1] - return out # type: ignore - if spatial_dims == 3: + rank = 3 + shear_indices = [(0, 1, coefs[0]), (1, 0, coefs[1])] + elif spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) - out = eye_func(4) - out[0, 1], out[0, 2] = coefs[0], coefs[1] - out[1, 0], out[1, 2] = coefs[2], coefs[3] - out[2, 0], out[2, 1] = coefs[4], coefs[5] - return out # type: ignore - raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") + rank = 4 + shear_indices = [ + (0, 1, coefs[0]), + (0, 2, coefs[1]), + (1, 0, coefs[2]), + (1, 2, coefs[3]), + (2, 0, coefs[4]), + (2, 1, coefs[5]), + ] + else: + raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") + # Compose individual single-axis shear matrices so that the result is a + # proper (area/volume-preserving) shear with determinant 1. Each elementary + # shear is pre-multiplied, so the first coefficient is applied first. + out = eye_func(rank) + for i, j, c in shear_indices: + s = eye_func(rank) + s[i, j] = c + out = s @ out + return out # type: ignore def create_scale( diff --git a/tests/transforms/test_create_grid_and_affine.py b/tests/transforms/test_create_grid_and_affine.py index f4793cabe0..d636b5a516 100644 --- a/tests/transforms/test_create_grid_and_affine.py +++ b/tests/transforms/test_create_grid_and_affine.py @@ -224,13 +224,31 @@ def test_create_rotate(self): def test_create_shear(self): test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) - test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) + test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 7.0, 0.0], [0.0, 0.0, 1.0]])) test_assert( create_shear, (3, 1.0), np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), ) + def test_create_shear_determinant(self): + """Composed shear must be area/volume-preserving (determinant == 1).""" + for coefs in [(0.3, 0.5), (1.0, 2.0), (-0.5, 0.7)]: + m = create_shear(2, coefs, backend="numpy") + assert_allclose(np.linalg.det(m), 1.0, atol=1e-10) + for coefs in [(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)]: + m = create_shear(3, coefs, backend="numpy") + assert_allclose(np.linalg.det(m), 1.0, atol=1e-10) + + def test_create_shear_sequential_equivalence(self): + """Composing single-axis shears must equal a single multi-axis shear.""" + sx, sy = 0.3, 0.5 + shear_x = create_shear(2, (sx, 0.0), backend="numpy") + shear_y = create_shear(2, (0.0, sy), backend="numpy") + combined = create_shear(2, (sx, sy), backend="numpy") + # shear_y applied after shear_x + assert_allclose(shear_y @ shear_x, combined, atol=1e-10) + def test_create_scale(self): test_assert(create_scale, (2, 2), np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) test_assert(create_scale, (2, [2, 2, 2]), np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]])) diff --git a/tests/transforms/test_rand_affine.py b/tests/transforms/test_rand_affine.py index 7b07d5f09d..6985912b5c 100644 --- a/tests/transforms/test_rand_affine.py +++ b/tests/transforms/test_rand_affine.py @@ -102,7 +102,7 @@ {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, p( torch.tensor( - [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + [[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]] ) ), ] @@ -122,7 +122,7 @@ {"img": p(torch.arange(64).reshape((1, 8, 8)))}, p( torch.tensor( - [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + [[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]] ) ), ] diff --git a/tests/transforms/test_rand_affined.py b/tests/transforms/test_rand_affined.py index 1c55a936d8..a3c5398ee2 100644 --- a/tests/transforms/test_rand_affined.py +++ b/tests/transforms/test_rand_affined.py @@ -93,7 +93,7 @@ "img": MetaTensor(torch.arange(64).reshape((1, 8, 8))), "seg": MetaTensor(torch.arange(64).reshape((1, 8, 8))), }, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), + torch.tensor([[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]]), ] ) TESTS.append( @@ -118,14 +118,14 @@ torch.tensor( [ [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], + [18.757534, 15.55517, 12.352802], + [27.420181, 24.217815, 21.01545], + [36.08283, 32.880463, 29.678097], ] ] ) ), - "seg": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + "seg": MetaTensor(torch.tensor([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), }, ] ) @@ -168,14 +168,14 @@ np.array( [ [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], + [18.757534, 15.55517, 12.352802], + [27.420181, 24.217815, 21.01545], + [36.08283, 32.880463, 29.678097], ] ] ) ), - "seg": MetaTensor(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + "seg": MetaTensor(np.array([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), }, ] ) @@ -202,14 +202,14 @@ torch.tensor( [ [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], + [18.757534, 15.55517, 12.352802], + [27.420181, 24.217815, 21.01545], + [36.08283, 32.880463, 29.678097], ] ] ) ), - "seg": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + "seg": MetaTensor(torch.tensor([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), }, ] )