From 128312517ab49bdaf5084d0ca37d86846830ae3b Mon Sep 17 00:00:00 2001 From: Davis Vigneault Date: Thu, 4 Dec 2025 09:22:05 -0800 Subject: [PATCH 1/2] Prevent implicit conversion of metatensor to numpy array Signed-off-by: Davis Vigneault --- monai/data/image_writer.py | 2 +- tests/data/test_itk_writer.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index b9e8b9e68e..b881406f02 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -324,7 +324,7 @@ def convert_to_channel_last( data = data[..., 0, :] # if desired, remove trailing singleton dimensions while squeeze_end_dims and data.shape[-1] == 1: - data = np.squeeze(data, -1) + data = data.squeeze(-1) if contiguous: data = ascontiguousarray(data) return data diff --git a/tests/data/test_itk_writer.py b/tests/data/test_itk_writer.py index 6625339dd0..beb17f01af 100644 --- a/tests/data/test_itk_writer.py +++ b/tests/data/test_itk_writer.py @@ -18,7 +18,7 @@ import numpy as np import torch -from monai.data import ITKWriter +from monai.data import ITKWriter, MetaTensor from monai.utils import optional_import itk, has_itk = optional_import("itk") @@ -64,6 +64,13 @@ def test_no_channel(self): np.testing.assert_allclose(output.shape, (4, 4, 3)) np.testing.assert_allclose(output[1, 1], (5, 21, 37)) + def test_metatensor_preserved(self): + data = MetaTensor(np.arange(48).reshape(3, 4, 4, 1), meta={"test_key": "test_value"}) + writer = ITKWriter() + writer.set_data_array(data, channel_dim=-1, squeeze_end_dims=True) + self.assertIsInstance(writer.data_obj, MetaTensor) + self.assertEqual(writer.data_obj.meta.get("test_key"), "test_value") + if __name__ == "__main__": unittest.main() From ced515ba25898fc00449d98612682f014340db10 Mon Sep 17 00:00:00 2001 From: Davis Vigneault Date: Thu, 4 Dec 2025 11:11:09 -0800 Subject: [PATCH 2/2] Address CodeRabbitAI Feedback Signed-off-by: Davis Vigneault --- tests/data/test_itk_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_itk_writer.py b/tests/data/test_itk_writer.py index beb17f01af..796fe74eef 100644 --- a/tests/data/test_itk_writer.py +++ b/tests/data/test_itk_writer.py @@ -45,7 +45,7 @@ def test_rgb(self): with tempfile.TemporaryDirectory() as tempdir: fname = os.path.join(tempdir, "testing.png") writer = ITKWriter(output_dtype=np.uint8) - writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0) + writer.set_data_array(torch.arange(48).reshape(3, 4, 4), channel_dim=0) writer.set_metadata({"spatial_shape": (5, 5)}) writer.write(fname)