diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3dc7897feb..a13535b5b9 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor): label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion, label 2 is the peritumoral edema, which is counted only under WT subregion, - label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. + the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. + + Args: + et_label: the label used for the GD-enhancing tumor (ET). + - Use 4 for BraTS 2018-2022. + - Use 3 for BraTS 2023. + Defaults to 4. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, et_label: int = 4) -> None: + if et_label in (1, 2): + raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.") + self.et_label = et_label + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = img.squeeze(0) - result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - # label 4 is ET + result = [ + (img == 1) | (img == self.et_label), + (img == 1) | (img == self.et_label) | (img == 2), + img == self.et_label, + ] + # merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT + # self.et_label is ET (4 or 3) return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 95c59e07bc..71b61d924d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]): class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. - Convert labels to multi channels based on brats18 classes: + Convert labels to multi channels based on brats classes: label 1 is the necrotic and non-enhancing tumor core label 2 is the peritumoral edema - label 4 is the GD-enhancing tumor + the specified `et_label` (default 4) is the GD-enhancing tumor The possible classes are TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor). + + Args: + keys: keys of the corresponding items to be transformed. + et_label: the label used for the GD-enhancing tumor (ET). + - Use 4 for BraTS 2018-2022. + - Use 3 for BraTS 2023. + Defaults to 4. + allow_missing_keys: don't raise exception if key is missing. """ backend = ConvertToMultiChannelBasedOnBratsClasses.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + def __init__(self, keys: KeysCollection, et_label: int = 4, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) - self.converter = ConvertToMultiChannelBasedOnBratsClasses() + self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/transforms/test_convert_to_multi_channel.py b/tests/transforms/test_convert_to_multi_channel.py index dfa324b6b9..ed9271842e 100644 --- a/tests/transforms/test_convert_to_multi_channel.py +++ b/tests/transforms/test_convert_to_multi_channel.py @@ -20,6 +20,9 @@ from tests.test_utils import TEST_NDARRAYS, assert_allclose TESTS = [] +TESTS_ET_LABEL_3 = [] + +# Tests for default et_label = 4 for p in TEST_NDARRAYS: TESTS.extend( [ @@ -46,6 +49,23 @@ ] ) +# Tests for et_label = 3 +for p in TEST_NDARRAYS: + TESTS_ET_LABEL_3.extend( + [ + [ + p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]), + p( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ], + ] + ) + class TestConvertToMultiChannel(unittest.TestCase): @parameterized.expand(TESTS) @@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result): assert_allclose(result, expected_result) self.assertTrue(result.dtype in (bool, torch.bool)) + @parameterized.expand(TESTS_ET_LABEL_3) + def test_type_shape_et_label_3(self, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data) + assert_allclose(result, expected_result) + self.assertTrue(result.dtype in (bool, torch.bool)) + + def test_invalid_et_label(self): + with self.assertRaises(ValueError): + ConvertToMultiChannelBasedOnBratsClasses(et_label=1) + with self.assertRaises(ValueError): + ConvertToMultiChannelBasedOnBratsClasses(et_label=2) + if __name__ == "__main__": unittest.main() diff --git a/tests/transforms/test_convert_to_multi_channeld.py b/tests/transforms/test_convert_to_multi_channeld.py index e482770497..1cea3a6919 100644 --- a/tests/transforms/test_convert_to_multi_channeld.py +++ b/tests/transforms/test_convert_to_multi_channeld.py @@ -24,6 +24,12 @@ np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), ] +TEST_CASE_ET_LABEL_3 = [ + {"keys": "label", "et_label": 3}, + {"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])}, + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + class TestConvertToMultiChanneld(unittest.TestCase): @@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) np.testing.assert_equal(result["label"], expected_result) + @parameterized.expand([TEST_CASE_ET_LABEL_3]) + def test_et_label_3(self, keys, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) + np.testing.assert_equal(result["label"], expected_result) + if __name__ == "__main__": unittest.main()