Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def convert_to_channel_last(
data = data[..., 0, :]
# if desired, remove trailing singleton dimensions
while squeeze_end_dims and data.shape[-1] == 1:
data = np.squeeze(data, -1)
data = data.squeeze(-1)
if contiguous:
data = ascontiguousarray(data)
return data
Expand Down
11 changes: 9 additions & 2 deletions tests/data/test_itk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import torch

from monai.data import ITKWriter
from monai.data import ITKWriter, MetaTensor
from monai.utils import optional_import

itk, has_itk = optional_import("itk")
Expand All @@ -45,7 +45,7 @@ def test_rgb(self):
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "testing.png")
writer = ITKWriter(output_dtype=np.uint8)
writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0)
writer.set_data_array(torch.arange(48).reshape(3, 4, 4), channel_dim=0)
writer.set_metadata({"spatial_shape": (5, 5)})
writer.write(fname)

Expand All @@ -64,6 +64,13 @@ def test_no_channel(self):
np.testing.assert_allclose(output.shape, (4, 4, 3))
np.testing.assert_allclose(output[1, 1], (5, 21, 37))

def test_metatensor_preserved(self):
data = MetaTensor(np.arange(48).reshape(3, 4, 4, 1), meta={"test_key": "test_value"})
writer = ITKWriter()
writer.set_data_array(data, channel_dim=-1, squeeze_end_dims=True)
self.assertIsInstance(writer.data_obj, MetaTensor)
self.assertEqual(writer.data_obj.meta.get("test_key"), "test_value")
Comment on lines +67 to +72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add docstring and consider additional test coverage.

Missing docstring violates coding guidelines. Additionally, consider testing more scenarios:

  • Different channel_dim values (e.g., 0, 1)
  • squeeze_end_dims=False case
  • Verifying the squeezed shape

Apply this diff to add a docstring:

     def test_metatensor_preserved(self):
+        """Test that MetaTensor and its metadata are preserved when squeeze_end_dims=True."""
         data = MetaTensor(np.arange(48).reshape(3, 4, 4, 1), meta={"test_key": "test_value"})

Consider expanding the test:

def test_metatensor_preserved(self):
    """Test that MetaTensor and its metadata are preserved when squeeze_end_dims=True."""
    data = MetaTensor(np.arange(48).reshape(3, 4, 4, 1), meta={"test_key": "test_value"})
    writer = ITKWriter()
    writer.set_data_array(data, channel_dim=-1, squeeze_end_dims=True)
    self.assertIsInstance(writer.data_obj, MetaTensor)
    self.assertEqual(writer.data_obj.meta.get("test_key"), "test_value")
    # Verify shape after squeeze
    self.assertEqual(writer.data_obj.shape, (3, 4, 4))
🤖 Prompt for AI Agents
In tests/data/test_itk_writer.py around lines 67 to 72, add a docstring to the
test_metatensor_preserved method and expand assertions to cover more scenarios:
keep the existing MetaTensor and metadata checks, assert the shape after squeeze
for the squeeze_end_dims=True case, add at least one additional case testing a
different channel_dim (e.g., 0 or 1) and assert expected shape/metadata, and add
a case with squeeze_end_dims=False to verify the shape remains unchanged;
implement these as additional assertions or small helper calls within the same
test (or split into small focused tests) so the test verifies preservation of
MetaTensor, metadata, and correct squeezing behavior for multiple parameter
combinations.



if __name__ == "__main__":
unittest.main()
Loading