Skip to content
Open
23 changes: 19 additions & 4 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
label 2 is the peritumoral edema, which is counted only under WT subregion,
label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.

Args:
et_label: the label used for the GD-enhancing tumor (ET).
- Use 4 for BraTS 2018-2022.
- Use 3 for BraTS 2023.
Defaults to 4.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, et_label: int = 4) -> None:
if et_label in (1, 2):
raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.")
self.et_label = et_label

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
# label 4 is ET
result = [
(img == 1) | (img == self.et_label),
(img == 1) | (img == self.et_label) | (img == 2),
img == self.et_label,
]
# merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
# self.et_label is ET (4 or 3)
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


Expand Down
16 changes: 12 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]):
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
Convert labels to multi channels based on brats18 classes:
Convert labels to multi channels based on brats classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 4 is the GD-enhancing tumor
the specified `et_label` (default 4) is the GD-enhancing tumor
The possible classes are TC (Tumor core), WT (Whole tumor)
and ET (Enhancing tumor).

Args:
keys: keys of the corresponding items to be transformed.
et_label: the label used for the GD-enhancing tumor (ET).
- Use 4 for BraTS 2018-2022.
- Use 3 for BraTS 2023.
Defaults to 4.
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ConvertToMultiChannelBasedOnBratsClasses.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
def __init__(self, keys: KeysCollection, et_label: int = 4, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBratsClasses()
self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
32 changes: 32 additions & 0 deletions tests/transforms/test_convert_to_multi_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TESTS = []
TESTS_ET_LABEL_3 = []

# Tests for default et_label = 4
for p in TEST_NDARRAYS:
TESTS.extend(
[
Expand All @@ -46,6 +49,23 @@
]
)

# Tests for et_label = 3
for p in TEST_NDARRAYS:
TESTS_ET_LABEL_3.extend(
[
[
p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]),
p(
[
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
]
),
],
]
)


class TestConvertToMultiChannel(unittest.TestCase):
@parameterized.expand(TESTS)
Expand All @@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result):
assert_allclose(result, expected_result)
self.assertTrue(result.dtype in (bool, torch.bool))

@parameterized.expand(TESTS_ET_LABEL_3)
def test_type_shape_et_label_3(self, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data)
assert_allclose(result, expected_result)
self.assertTrue(result.dtype in (bool, torch.bool))

def test_invalid_et_label(self):
with self.assertRaises(ValueError):
ConvertToMultiChannelBasedOnBratsClasses(et_label=1)
with self.assertRaises(ValueError):
ConvertToMultiChannelBasedOnBratsClasses(et_label=2)


if __name__ == "__main__":
unittest.main()
11 changes: 11 additions & 0 deletions tests/transforms/test_convert_to_multi_channeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
]

TEST_CASE_ET_LABEL_3 = [
{"keys": "label", "et_label": 3},
{"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])},
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
]


class TestConvertToMultiChanneld(unittest.TestCase):

Expand All @@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
np.testing.assert_equal(result["label"], expected_result)

@parameterized.expand([TEST_CASE_ET_LABEL_3])
def test_et_label_3(self, keys, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
np.testing.assert_equal(result["label"], expected_result)


if __name__ == "__main__":
unittest.main()