From 4f0a90b984b6803a1df81879767e038506bcfea5 Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Tue, 17 Mar 2026 18:23:55 +0000 Subject: [PATCH 1/9] Add support for configurable GD-enhancing tumor label in ConvertToMultiChannelBasedOnBratsClasses Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3dc7897feb..72c96a9525 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1050,18 +1050,27 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion, label 2 is the peritumoral edema, which is counted only under WT subregion, label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. + + Args: + et_label: the label used for the GD-enhancing tumor (ET). + - Use 4 for BraTS 2018-2022. + - Use 3 for BraTS 2023. + Defaults to 4. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, et_label: int = 4) -> None: + self.et_label = et_label + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = img.squeeze(0) - result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - # label 4 is ET + result = [(img == 1) | (img == self.et_label), (img == 1) | (img == self.et_label) | (img == 2), img == self.et_label] + # merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT + # self.et_label is ET (4 or 3) return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) From 0a7e9c2f82c459037d41aaa3e87e8b66be7e8452 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:35:29 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 72c96a9525..775eed2eff 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1052,7 +1052,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. Args: - et_label: the label used for the GD-enhancing tumor (ET). + et_label: the label used for the GD-enhancing tumor (ET). - Use 4 for BraTS 2018-2022. - Use 3 for BraTS 2023. Defaults to 4. From 40f812969691ffc19ac50e63dc9839f91c5cdb26 Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Tue, 17 Mar 2026 21:04:36 +0000 Subject: [PATCH 3/9] Add validation for et_label in ConvertToMultiChannelBasedOnBratsClasses and extend tests Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 2 ++ .../test_convert_to_multi_channel.py | 32 ++++++++++++++++++- .../test_convert_to_multi_channeld.py | 13 +++++++- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 775eed2eff..d8dd149af8 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1061,6 +1061,8 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, et_label: int = 4) -> None: + if et_label in (1, 2): + raise ValueError(f"et_label cannot be 1 or 2, these are reserved.Got {et_label}.") self.et_label = et_label def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: diff --git a/tests/transforms/test_convert_to_multi_channel.py b/tests/transforms/test_convert_to_multi_channel.py index dfa324b6b9..03e87eb5ec 100644 --- a/tests/transforms/test_convert_to_multi_channel.py +++ b/tests/transforms/test_convert_to_multi_channel.py @@ -20,6 +20,9 @@ from tests.test_utils import TEST_NDARRAYS, assert_allclose TESTS = [] +TESTS_ET_LABEL_3 = [] + +# Tests for default et_label = 4 for p in TEST_NDARRAYS: TESTS.extend( [ @@ -46,6 +49,22 @@ ] ) +# Tests for et_label = 3 +for p in TEST_NDARRAYS: + TESTS_ET_LABEL_3.extend( + [ + [ + p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]), + p( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ], + ] + ) class TestConvertToMultiChannel(unittest.TestCase): @parameterized.expand(TESTS) @@ -54,6 +73,17 @@ def test_type_shape(self, data, expected_result): assert_allclose(result, expected_result) self.assertTrue(result.dtype in (bool, torch.bool)) + @parameterized.expand(TESTS_ET_LABEL_3) + def test_type_shape_et_label_3(self, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data) + assert_allclose(result, expected_result) + self.assertTrue(result.dtype in (bool, torch.bool)) + + def test_invalid_et_label(self): + with self.assertRaises(ValueError): + ConvertToMultiChannelBasedOnBratsClasses(et_label=1) + with self.assertRaises(ValueError): + ConvertToMultiChannelBasedOnBratsClasses(et_label=2) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/transforms/test_convert_to_multi_channeld.py b/tests/transforms/test_convert_to_multi_channeld.py index e482770497..784d5832f8 100644 --- a/tests/transforms/test_convert_to_multi_channeld.py +++ b/tests/transforms/test_convert_to_multi_channeld.py @@ -24,6 +24,12 @@ np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), ] +TEST_CASE_ET_LABEL_3 = [ + {"keys": "label", "et_label": 3}, + {"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])}, + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + class TestConvertToMultiChanneld(unittest.TestCase): @@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) np.testing.assert_equal(result["label"], expected_result) + @parameterized.expand([TEST_CASE_ET_LABEL_3]) + def test_et_label_3(self, keys, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) + np.testing.assert_equal(result["label"], expected_result) + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From c586d46bd8d686651ea62a77ae211d9613e31ffb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:05:09 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Abdessamad --- tests/transforms/test_convert_to_multi_channel.py | 2 +- tests/transforms/test_convert_to_multi_channeld.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transforms/test_convert_to_multi_channel.py b/tests/transforms/test_convert_to_multi_channel.py index 03e87eb5ec..ce6e588f66 100644 --- a/tests/transforms/test_convert_to_multi_channel.py +++ b/tests/transforms/test_convert_to_multi_channel.py @@ -86,4 +86,4 @@ def test_invalid_et_label(self): ConvertToMultiChannelBasedOnBratsClasses(et_label=2) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/transforms/test_convert_to_multi_channeld.py b/tests/transforms/test_convert_to_multi_channeld.py index 784d5832f8..1cea3a6919 100644 --- a/tests/transforms/test_convert_to_multi_channeld.py +++ b/tests/transforms/test_convert_to_multi_channeld.py @@ -45,4 +45,4 @@ def test_et_label_3(self, keys, data, expected_result): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 9726c9263b6f16429d08a2e75f3cf005979d6c3e Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Tue, 17 Mar 2026 21:29:28 +0000 Subject: [PATCH 5/9] Update dictionary wrapper to accept et_label Signed-off-by: Abdessamad --- monai/transforms/utility/dictionary.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 95c59e07bc..44292333ea 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]): class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. - Convert labels to multi channels based on brats18 classes: + Convert labels to multi channels based on brats classes: label 1 is the necrotic and non-enhancing tumor core label 2 is the peritumoral edema label 4 is the GD-enhancing tumor The possible classes are TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor). + + Args: + keys: keys of the corresponding items to be transformed. + et_label: the label used for the GD-enhancing tumor (ET). + - Use 4 for BraTS 2018-2022. + - Use 3 for BraTS 2023. + Defaults to 4. + allow_missing_keys: don't raise exception if key is missing. """ backend = ConvertToMultiChannelBasedOnBratsClasses.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + def __init__(self, keys: KeysCollection, et_label: int = 4, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) - self.converter = ConvertToMultiChannelBasedOnBratsClasses() + self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) From 85bb3123c3a77c246c1bcadbe0fc04f8c50cbee8 Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Tue, 17 Mar 2026 21:43:49 +0000 Subject: [PATCH 6/9] Clarify documentation for GD-enhancing tumor label in ConvertToMultiChannelBasedOnBratsClasses Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 2 +- monai/transforms/utility/dictionary.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d8dd149af8..2361d56554 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1049,7 +1049,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor): label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion, label 2 is the peritumoral edema, which is counted only under WT subregion, - label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. + the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. Args: et_label: the label used for the GD-enhancing tumor (ET). diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 44292333ea..71b61d924d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1300,7 +1300,7 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): Convert labels to multi channels based on brats classes: label 1 is the necrotic and non-enhancing tumor core label 2 is the peritumoral edema - label 4 is the GD-enhancing tumor + the specified `et_label` (default 4) is the GD-enhancing tumor The possible classes are TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor). From 48c00c075bbb9bb0840af9aad0b05f716ab00d4f Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Tue, 17 Mar 2026 21:57:57 +0000 Subject: [PATCH 7/9] Fix typo in ValueError message for reserved et_label values in ConvertToMultiChannelBasedOnBratsClasses Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2361d56554..9e4144e2bf 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1062,7 +1062,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): def __init__(self, et_label: int = 4) -> None: if et_label in (1, 2): - raise ValueError(f"et_label cannot be 1 or 2, these are reserved.Got {et_label}.") + raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.") self.et_label = et_label def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: From ce817c168ff20bdda3f588ccd9759f87319c7c66 Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Wed, 18 Mar 2026 13:37:02 +0000 Subject: [PATCH 8/9] Fix formatting in array.py Signed-off-by: Abdessamad --- monai/transforms/utility/array.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9e4144e2bf..a13535b5b9 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1070,7 +1070,11 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if img.ndim == 4 and img.shape[0] == 1: img = img.squeeze(0) - result = [(img == 1) | (img == self.et_label), (img == 1) | (img == self.et_label) | (img == 2), img == self.et_label] + result = [ + (img == 1) | (img == self.et_label), + (img == 1) | (img == self.et_label) | (img == 2), + img == self.et_label, + ] # merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT # self.et_label is ET (4 or 3) return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) From ef36a588be19dd9fd6cdc997846f528d124629c4 Mon Sep 17 00:00:00 2001 From: Abdessamad Date: Wed, 18 Mar 2026 15:27:57 +0000 Subject: [PATCH 9/9] Fix formatting in test_convert_to_multi_channel.py Signed-off-by: Abdessamad --- tests/transforms/test_convert_to_multi_channel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/transforms/test_convert_to_multi_channel.py b/tests/transforms/test_convert_to_multi_channel.py index ce6e588f66..ed9271842e 100644 --- a/tests/transforms/test_convert_to_multi_channel.py +++ b/tests/transforms/test_convert_to_multi_channel.py @@ -66,6 +66,7 @@ ] ) + class TestConvertToMultiChannel(unittest.TestCase): @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): @@ -85,5 +86,6 @@ def test_invalid_et_label(self): with self.assertRaises(ValueError): ConvertToMultiChannelBasedOnBratsClasses(et_label=2) + if __name__ == "__main__": unittest.main()