From 3a22af77a67eff399bc911165062e94a48d797c0 Mon Sep 17 00:00:00 2001 From: Dimitris Karkalousos Date: Wed, 11 Jun 2025 23:40:29 +0000 Subject: [PATCH] Implements logic for the muscle_adipose_tissue model & fixes DICOM reading --- bin/C2C | 7 ++- comp2comp/io/io.py | 2 +- comp2comp/muscle_adipose_tissue/data.py | 2 +- .../muscle_adipose_tissue.py | 63 +++++++++++++++++++ 4 files changed, 70 insertions(+), 4 deletions(-) diff --git a/bin/C2C b/bin/C2C index 3b8ac892..2b203484 100755 --- a/bin/C2C +++ b/bin/C2C @@ -51,15 +51,16 @@ def MuscleAdiposeTissuePipelineBuilder(args): muscle_adipose_tissue.MuscleAdiposeTissueComputeMetrics(), muscle_adipose_tissue_visualization.MuscleAdiposeTissueVisualizer(), muscle_adipose_tissue.MuscleAdiposeTissueH5Saver(), + muscle_adipose_tissue.MuscleAdiposeTissueNiftiSaver(), muscle_adipose_tissue.MuscleAdiposeTissueMetricsSaver(), ] ) return pipeline -def MuscleAdiposeTissueFullPipelineBuilder(args): +def MuscleAdiposeTissueFullPipelineBuilder(path, args): pipeline = InferencePipeline( - [io.DicomFinder(args.input_path), MuscleAdiposeTissuePipelineBuilder(args)] + [io.DicomFinder(path), MuscleAdiposeTissuePipelineBuilder(args)] ) return pipeline @@ -269,6 +270,8 @@ def main(): args = argument_parser().parse_args() if args.pipeline == "spine_muscle_adipose_tissue": process_3d(args, SpineMuscleAdiposeTissuePipelineBuilder) + elif args.pipeline == "muscle_adipose_tissue": + process_3d(args, MuscleAdiposeTissueFullPipelineBuilder) elif args.pipeline == "spine": process_3d(args, SpinePipelineBuilder) elif args.pipeline == "contrast_phase": diff --git a/comp2comp/io/io.py b/comp2comp/io/io.py index 9870e41d..484dce14 100644 --- a/comp2comp/io/io.py +++ b/comp2comp/io/io.py @@ -96,7 +96,7 @@ def __call__(self, inference_pipeline): if self.input_path.is_dir(): # store a dcm object for retrieving dicom tags dcm_files = [d for d in os.listdir(self.input_path) if d.endswith('.dcm')] - inference_pipeline.dcm = pydicom.read_file(os.path.join(self.input_path, dcm_files[0])) + inference_pipeline.dcm = pydicom.dcmread(os.path.join(self.input_path, dcm_files[0])) ds = dicom_series_to_nifti( self.input_path, diff --git a/comp2comp/muscle_adipose_tissue/data.py b/comp2comp/muscle_adipose_tissue/data.py index 54843bf8..39cc7830 100644 --- a/comp2comp/muscle_adipose_tissue/data.py +++ b/comp2comp/muscle_adipose_tissue/data.py @@ -84,7 +84,7 @@ def __len__(self): def __getitem__(self, idx): files = self._files[idx * self._batch_size : (idx + 1) * self._batch_size] - dcms = [pydicom.read_file(f, force=True) for f in files] + dcms = [pydicom.dcmread(f, force=True) for f in files] xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms] diff --git a/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py b/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py index 6ec41cb1..b149af9a 100644 --- a/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py +++ b/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py @@ -1,4 +1,5 @@ import os +import re import zipfile from pathlib import Path from time import perf_counter @@ -9,6 +10,7 @@ import nibabel as nib import numpy as np import pandas as pd +import SimpleITK as sitk import wget from keras import backend as K from tqdm import tqdm @@ -347,6 +349,11 @@ def compute_metrics(self, x, mask, spacing): hu_vals = np.nan_to_num(hu_vals) csa_vals = np.nan_to_num(csa_vals) + if mask.shape[-1] != len(categories): + # TODO: Handle this properly. This is a hard fix removing the BG class, + # which is added by the abCT_v0.0.1 model in the end. + mask = mask[..., :-1] + assert mask.shape[-1] == len( categories ), "{} categories found in mask, " "but only {} categories specified".format( @@ -397,6 +404,62 @@ def save_results(self, results): f.create_dataset(name=cat, data=np.array(mask, dtype=np.uint8)) +def natural_sort_key(s): + """ + Create a key for sorting strings in a 'natural' order. e.g., 'slice10' comes after 'slice2'. + """ + return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] + + +class MuscleAdiposeTissueNiftiSaver(InferenceClass): + """ + Saves the multi-class muscle and adipose tissue segmentations as a single multi-labeled NIfTI file. + """ + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, images, results): + """Orchestrates the entire saving and assembly process.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.output_dir = inference_pipeline.output_dir + self.nifti_output_dir = os.path.join(self.output_dir, "segmentations") + self.dicom_file_names = inference_pipeline.dicom_file_names + os.makedirs(self.nifti_output_dir, exist_ok=True) + self.spacings = getattr(inference_pipeline, 'spacings', None) + self.save_results(results) + return {"results": results} + + def save_results(self, results): + """Saves NIfTI file.""" + categories = self.model_type.categories + + slices = {} + for i, result in enumerate(results): + file_name = self.dicom_file_names[i] + first_cat_name = list(categories.keys())[0] + if first_cat_name not in result: + continue + mask_shape = result[first_cat_name]["mask"].shape + multi_label_slice = np.zeros(mask_shape, dtype=np.uint8) + for class_name, label in categories.items(): + if class_name in result: + class_mask = result[class_name]["mask"] + multi_label_slice[class_mask > 0] = label + 1 + slices[file_name] = multi_label_slice + slices = [slices[fname] for fname in sorted(slices.keys(), key=natural_sort_key)] + + final_image = sitk.GetImageFromArray(np.stack(slices, axis=0)[::-1, :, :]) + if self.spacings and len(self.spacings) > 0: + # Assumes spacing is (x, y, z) + final_spacing = tuple(float(s) for s in self.spacings[0]) + final_image.SetSpacing(final_spacing) + else: + final_image.SetSpacing((1.0, 1.0, 1.0)) + + sitk.WriteImage(final_image, os.path.join(self.nifti_output_dir, "muscle_adipose_tissue_seg.nii.gz")) + + class MuscleAdiposeTissueMetricsSaver(InferenceClass): """Save metrics to a CSV file."""