Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,14 +952,15 @@ def create_shear(

Args:
spatial_dims: spatial rank
coefs: shearing factors, a tuple of 2 floats for 2D, a tuple of 6 floats for 3D),
take a 3D affine as example::
coefs: shearing factors (a tuple of 2 floats for 2D, a tuple of 6 floats for 3D).
Individual single-axis shear matrices are composed (multiplied) in
coefficient order so that the result is a proper shear with determinant 1.
For 2D with coefs ``(Sx, Sy)`` the composed matrix is::

[
[1.0, coefs[0], coefs[1], 0.0],
[coefs[2], 1.0, coefs[3], 0.0],
[coefs[4], coefs[5], 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
[1.0, Sx, 0.0],
[Sy, 1.0 + Sx*Sy, 0.0],
[0.0, 0.0, 1.0],
]

device: device to compute and store the output (when the backend is "torch").
Expand All @@ -982,17 +983,30 @@ def create_shear(
def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np.eye) -> NdarrayOrTensor:
if spatial_dims == 2:
coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0)
out = eye_func(3)
out[0, 1], out[1, 0] = coefs[0], coefs[1]
return out # type: ignore
if spatial_dims == 3:
rank = 3
shear_indices = [(0, 1, coefs[0]), (1, 0, coefs[1])]
elif spatial_dims == 3:
coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0)
out = eye_func(4)
out[0, 1], out[0, 2] = coefs[0], coefs[1]
out[1, 0], out[1, 2] = coefs[2], coefs[3]
out[2, 0], out[2, 1] = coefs[4], coefs[5]
return out # type: ignore
raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.")
rank = 4
shear_indices = [
(0, 1, coefs[0]),
(0, 2, coefs[1]),
(1, 0, coefs[2]),
(1, 2, coefs[3]),
(2, 0, coefs[4]),
(2, 1, coefs[5]),
]
else:
raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.")
# Compose individual single-axis shear matrices so that the result is a
# proper (area/volume-preserving) shear with determinant 1. Each elementary
# shear is pre-multiplied, so the first coefficient is applied first.
out = eye_func(rank)
for i, j, c in shear_indices:
s = eye_func(rank)
s[i, j] = c
out = s @ out
return out # type: ignore


def create_scale(
Expand Down
20 changes: 19 additions & 1 deletion tests/transforms/test_create_grid_and_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,31 @@ def test_create_rotate(self):

def test_create_shear(self):
test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 7.0, 0.0], [0.0, 0.0, 1.0]]))
test_assert(
create_shear,
(3, 1.0),
np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
)

def test_create_shear_determinant(self):
"""Composed shear must be area/volume-preserving (determinant == 1)."""
for coefs in [(0.3, 0.5), (1.0, 2.0), (-0.5, 0.7)]:
m = create_shear(2, coefs, backend="numpy")
assert_allclose(np.linalg.det(m), 1.0, atol=1e-10)
for coefs in [(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)]:
m = create_shear(3, coefs, backend="numpy")
assert_allclose(np.linalg.det(m), 1.0, atol=1e-10)

def test_create_shear_sequential_equivalence(self):
"""Composing single-axis shears must equal a single multi-axis shear."""
sx, sy = 0.3, 0.5
shear_x = create_shear(2, (sx, 0.0), backend="numpy")
shear_y = create_shear(2, (0.0, sy), backend="numpy")
combined = create_shear(2, (sx, sy), backend="numpy")
# shear_y applied after shear_x
assert_allclose(shear_y @ shear_x, combined, atol=1e-10)

def test_create_scale(self):
test_assert(create_scale, (2, 2), np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
test_assert(create_scale, (2, [2, 2, 2]), np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]))
Expand Down
4 changes: 2 additions & 2 deletions tests/transforms/test_rand_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
{"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)},
p(
torch.tensor(
[[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]
[[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]]
)
),
]
Expand All @@ -122,7 +122,7 @@
{"img": p(torch.arange(64).reshape((1, 8, 8)))},
p(
torch.tensor(
[[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]
[[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]]
)
),
]
Expand Down
26 changes: 13 additions & 13 deletions tests/transforms/test_rand_affined.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"img": MetaTensor(torch.arange(64).reshape((1, 8, 8))),
"seg": MetaTensor(torch.arange(64).reshape((1, 8, 8))),
},
torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]),
torch.tensor([[[18.7575, 15.5552, 12.3528], [27.4202, 24.2178, 21.0154], [36.0828, 32.8805, 29.6781]]]),
]
)
TESTS.append(
Expand All @@ -118,14 +118,14 @@
torch.tensor(
[
[
[18.736153, 15.581954, 12.4277525],
[27.398798, 24.244598, 21.090399],
[36.061443, 32.90724, 29.753046],
[18.757534, 15.55517, 12.352802],
[27.420181, 24.217815, 21.01545],
[36.08283, 32.880463, 29.678097],
]
]
)
),
"seg": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
"seg": MetaTensor(torch.tensor([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
},
]
)
Expand Down Expand Up @@ -168,14 +168,14 @@
np.array(
[
[
[18.736153, 15.581954, 12.4277525],
[27.398798, 24.244598, 21.090399],
[36.061443, 32.90724, 29.753046],
[18.757534, 15.55517, 12.352802],
[27.420181, 24.217815, 21.01545],
[36.08283, 32.880463, 29.678097],
]
]
)
),
"seg": MetaTensor(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
"seg": MetaTensor(np.array([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
},
]
)
Expand All @@ -202,14 +202,14 @@
torch.tensor(
[
[
[18.736153, 15.581954, 12.4277525],
[27.398798, 24.244598, 21.090399],
[36.061443, 32.90724, 29.753046],
[18.757534, 15.55517, 12.352802],
[27.420181, 24.217815, 21.01545],
[36.08283, 32.880463, 29.678097],
]
]
)
),
"seg": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
"seg": MetaTensor(torch.tensor([[[19.0, 12.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),
},
]
)
Expand Down
Loading