diff --git a/bin/C2C b/bin/C2C
index 3b32ddde..be26e24a 100755
--- a/bin/C2C
+++ b/bin/C2C
@@ -13,21 +13,61 @@ from comp2comp.aortic_calcium import (
 from comp2comp.contrast_phase.contrast_phase import ContrastPhaseDetection
 from comp2comp.hip import hip
 from comp2comp.inference_pipeline import InferencePipeline
-from comp2comp.io import io
+from comp2comp.io import fda_io, io
 from comp2comp.liver_spleen_pancreas import (
     liver_spleen_pancreas,
     liver_spleen_pancreas_visualization,
 )
 from comp2comp.muscle_adipose_tissue import (
+    fda_muscle_adipose_tissue,
     muscle_adipose_tissue,
     muscle_adipose_tissue_visualization,
 )
-from comp2comp.spine import spine
+from comp2comp.spine import fda_spine, spine
 from comp2comp.utils import orientation
 from comp2comp.utils.process import process_3d
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
+# FDA approved BunkeHill BMD algorithm
+def FdaBmdPipelineBuilder(path, args):
+    pipeline = InferencePipeline(
+        [
+            FdaSpinePipelineBuilder(path, args),
+            fda_spine.SpineFindDicoms(),
+            FdaMuscleAdiposeTissuePipelineBuilder(args),
+        ]
+    )
+    return pipeline
+
+def FdaSpinePipelineBuilder(path, args):
+    pipeline = InferencePipeline(
+        [
+            fda_io.DicomToNifti(path),
+            fda_spine.SpineSegmentation(args.spine_model, save=True),
+            orientation.ToCanonical(),
+            fda_spine.SpineComputeROIs(args.spine_model),
+            fda_spine.SpineMetricsSaver(),
+        ]
+    )
+    return pipeline
+
+def FdaMuscleAdiposeTissuePipelineBuilder(args):
+    pipeline = InferencePipeline(
+        [
+            fda_muscle_adipose_tissue.MuscleAdiposeTissueSegmentation(
+                16, args.muscle_fat_model
+            ),
+            fda_muscle_adipose_tissue.MuscleAdiposeTissuePostProcessing(),
+            fda_muscle_adipose_tissue.MuscleAdiposeTissueComputeMetrics(),
+            # fda_muscle_adipose_tissue_visualization.MuscleAdiposeTissueVisualizer(),
+            fda_muscle_adipose_tissue.MuscleAdiposeTissueH5Saver(),
+            fda_muscle_adipose_tissue.MuscleAdiposeTissueMetricsSaver(),
+        ]
+    )
+    return pipeline
+
+
 ### AAA Pipeline
 def AAAPipelineBuilder(path, args):
     pipeline = InferencePipeline(
@@ -197,6 +237,17 @@ def argument_parser():
         "--spine_model", default="stanford_spine_v0.0.1", type=str
     )
+    # FDA approved BMD algorithm (Spine + muscle + fat)
+    fda_spine_muscle_adipose_tissue_parser = subparsers.add_parser(
+        "fda_bmd", parents=[base_parser]
+    )
+    fda_spine_muscle_adipose_tissue_parser.add_argument(
+        "--muscle_fat_model", default="stanford_v0.0.2", type=str
+    )
+    fda_spine_muscle_adipose_tissue_parser.add_argument(
+        "--spine_model", default="ts_spine", type=str
+    )
+
     # Liver spleen pancreas
     liver_spleen_pancreas = subparsers.add_parser(
         "liver_spleen_pancreas", parents=[base_parser]
     )
@@ -266,7 +317,9 @@ def argument_parser():
 def main():
     args = argument_parser().parse_args()
 
-    if args.pipeline == "spine_muscle_adipose_tissue":
+    if args.pipeline == "fda_bmd":
+        process_3d(args, FdaBmdPipelineBuilder)
+    elif args.pipeline == "spine_muscle_adipose_tissue":
         process_3d(args, SpineMuscleAdiposeTissuePipelineBuilder)
     elif args.pipeline == "spine":
         process_3d(args, SpinePipelineBuilder)
@@ -283,7 +336,8 @@ def main():
     elif args.pipeline == "all":
         process_3d(args, AllPipelineBuilder)
     else:
-        raise AssertionError("{} command not supported".format(args.action))
+        # raise AssertionError("{} command not supported".format(args.action))
+        raise AssertionError("{} command not supported".format(args.pipeline))
 
 
 if __name__ == "__main__":
diff --git a/comp2comp/aaa/aaa.py b/comp2comp/aaa/aaa.py
index 41e74142..2390d819 100644
--- a/comp2comp/aaa/aaa.py
+++ b/comp2comp/aaa/aaa.py
@@ -1,7 +1,7 @@
 import
math import operator import os -import traceback +import traceback import zipfile from pathlib import Path from time import time @@ -397,10 +397,10 @@ def __call__(self, inference_pipeline): ) try: clip.write_videofile(output_dir_summary + "aaa.mp4") - except Exception as e: - print('Error encountered in video generation:\n') + except Exception: + print("Error encountered in video generation:\n") traceback.print_exc() - + return {} diff --git a/comp2comp/aortic_calcium/aortic_calcium.py b/comp2comp/aortic_calcium/aortic_calcium.py index a555a58b..8777f60d 100644 --- a/comp2comp/aortic_calcium/aortic_calcium.py +++ b/comp2comp/aortic_calcium/aortic_calcium.py @@ -12,8 +12,8 @@ import matplotlib.pyplot as plt import numpy as np +import pydicom from scipy import ndimage -import pydicom # from totalsegmentator.libs import ( # download_pretrained_weights, @@ -34,9 +34,9 @@ def __init__(self): def __call__(self, inference_pipeline): # check if kernels are allowed if agatston is used - if inference_pipeline.args.threshold == 'agatston': + if inference_pipeline.args.threshold == "agatston": self.reconKernelChecker(inference_pipeline.dcm) - + # inference_pipeline.dicom_series_path = self.input_path self.output_dir = inference_pipeline.output_dir self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") @@ -115,27 +115,63 @@ def reconKernelChecker(self, dcm): ge_kernels = ["standard", "md stnd"] philips_kernels = ["a", "b", "c", "sa", "sb"] canon_kernels = ["fc08", "fc18"] - siemens_kernels = ["b20s", "b20f", "b30f", "b31s", "b31f", "br34f", "b35f", "bf37f", "br38f", "b41f", - "qr40", "qr40d", "br36f", "br40", "b40f", "br40d", "i30f", "i31f", "i26f", "i31s", - "i40f", "b30s", "br36d", "bf39f", "b41s", "br40f"] + siemens_kernels = [ + "b20s", + "b20f", + "b30f", + "b31s", + "b31f", + "br34f", + "b35f", + "bf37f", + "br38f", + "b41f", + "qr40", + "qr40d", + "br36f", + "br40", + "b40f", + "br40d", + "i30f", + "i31f", + "i26f", + "i31s", + "i40f", + "b30s", + "br36d", + "bf39f", + "b41s", + "br40f", + ] toshiba_kernels = ["fc01", "fc02", "fc07", "fc08", "fc13", "fc18"] - all_kernels = ge_kernels+philips_kernels+canon_kernels+siemens_kernels+toshiba_kernels - - conv_kernel_raw = dcm['ConvolutionKernel'].value - + all_kernels = ( + ge_kernels + + philips_kernels + + canon_kernels + + siemens_kernels + + toshiba_kernels + ) + + conv_kernel_raw = dcm["ConvolutionKernel"].value + if isinstance(conv_kernel_raw, pydicom.multival.MultiValue): conv_kernel = conv_kernel_raw[0].lower() - recon_kernel_extra = str(conv_kernel_raw) + str(conv_kernel_raw) else: conv_kernel = conv_kernel_raw.lower() - recon_kernel_extra = 'n/a' - + if conv_kernel in all_kernels: return True - else: - raise ValueError('Reconstruction kernel not allowed, found: ' + conv_kernel +'\n' - + 'Allowed kernels are: ' + str(all_kernels)) + else: + raise ValueError( + "Reconstruction kernel not allowed, found: " + + conv_kernel + + "\n" + + "Allowed kernels are: " + + str(all_kernels) + ) + class AorticCalciumSegmentation(InferenceClass): """Segmentaiton of aortic calcium""" @@ -168,10 +204,10 @@ def __init__(self): 47: "vertebrae_C4", 48: "vertebrae_C3", 49: "vertebrae_C2", - 50: "vertebrae_C1"} - - self.vertebrae_name = {v: k for k, v in self.vertebrae_num.items()} + 50: "vertebrae_C1", + } + self.vertebrae_name = {v: k for k, v in self.vertebrae_num.items()} def __call__(self, inference_pipeline): @@ -194,15 +230,21 @@ def __call__(self, inference_pipeline): os.makedirs(os.path.join(self.output_dir, "metrics/")) 
inference_pipeline.ct = inference_pipeline.medical_volume.get_fdata() - inference_pipeline.aorta_mask = (inference_pipeline.segmentation.get_fdata().round().astype(np.int8) == 52) - inference_pipeline.spine_mask = inference_pipeline.spine_segmentation.get_fdata().round().astype(np.uint8) - + inference_pipeline.aorta_mask = ( + inference_pipeline.segmentation.get_fdata().round().astype(np.int8) == 52 + ) + inference_pipeline.spine_mask = ( + inference_pipeline.spine_segmentation.get_fdata().round().astype(np.uint8) + ) + # convert to the index of TotalSegmentator - if inference_pipeline.spine_model_name == 'stanford_spine_v0.0.1': + if inference_pipeline.spine_model_name == "stanford_spine_v0.0.1": tmp_mask = inference_pipeline.spine_mask > 0 - inference_pipeline.spine_mask[tmp_mask] = inference_pipeline.spine_mask[tmp_mask] + 11 + inference_pipeline.spine_mask[tmp_mask] = ( + inference_pipeline.spine_mask[tmp_mask] + 11 + ) del tmp_mask - + spine_mask_bin = inference_pipeline.spine_mask > 0 # Determine the target number of pixels @@ -220,7 +262,7 @@ def __call__(self, inference_pipeline): inference_pipeline.aorta_mask, exclude_mask=spine_mask_bin, remove_size=3, - return_dilated_mask=True, + return_dilated_mask=True, return_eroded_aorta=True, threshold=inference_pipeline.args.threshold, dilation_iteration=target_aorta_dil, @@ -240,7 +282,7 @@ def __call__(self, inference_pipeline): "calcium_segmentations.nii.gz", ), ) - + inference_pipeline.saveArrToNifti( calcification_results["dilated_mask"], os.path.join( @@ -248,7 +290,7 @@ def __call__(self, inference_pipeline): "dilated_aorta_mask.nii.gz", ), ) - + inference_pipeline.saveArrToNifti( calcification_results["aorta_eroded"], os.path.join( @@ -263,7 +305,7 @@ def __call__(self, inference_pipeline): inference_pipeline.output_dir_segmentation_masks, "spine_mask.nii.gz" ), ) - + inference_pipeline.saveArrToNifti( inference_pipeline.aorta_mask, os.path.join( @@ -647,7 +689,7 @@ def getSmallestArraySlice(self, input_mask, margin=0): ) return (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end)) - + class AorticCalciumMetrics(InferenceClass): """Calculate metrics for the aortic calcifications""" @@ -659,7 +701,7 @@ def __call__(self, inference_pipeline): calc_mask = inference_pipeline.calc_mask spine_mask = inference_pipeline.spine_mask aorta_mask = inference_pipeline.aorta_mask - + t12_level = np.where((spine_mask == 32).sum(axis=(0, 1)))[0] l1_level = np.where((spine_mask == 31).sum(axis=(0, 1)))[0] @@ -686,7 +728,7 @@ def __call__(self, inference_pipeline): ), ) inference_pipeline.t12_plane = planes - + inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[ "pixdim" ][1:4] @@ -741,10 +783,11 @@ def __call__(self, inference_pipeline): metrics["volume_total"] = calc_vol metrics["num_calc"] = num_lesions - + # percent of the aorta calcificed - metrics['perc_calcified'] = (calc_mask_region.sum() / aorta_mask_region.sum()) * 100 - + metrics["perc_calcified"] = ( + calc_mask_region.sum() / aorta_mask_region.sum() + ) * 100 if inference_pipeline.args.threshold == "agatston": if num_lesions == 0: diff --git a/comp2comp/aortic_calcium/aortic_calcium_visualization.py b/comp2comp/aortic_calcium/aortic_calcium_visualization.py index 351a8891..7a796fa1 100644 --- a/comp2comp/aortic_calcium/aortic_calcium_visualization.py +++ b/comp2comp/aortic_calcium/aortic_calcium_visualization.py @@ -2,8 +2,13 @@ import numpy as np +from comp2comp.aortic_calcium.visualization_utils import ( + createCalciumMosaic, + createMipPlot, + 
mergeMipAndMosaic, +) from comp2comp.inference_class_base import InferenceClass -from comp2comp.aortic_calcium.visualization_utils import createMipPlot, createCalciumMosaic, mergeMipAndMosaic + class AorticCalciumVisualizer(InferenceClass): def __init__(self): @@ -16,10 +21,10 @@ def __call__(self, inference_pipeline): if not os.path.exists(self.output_dir_images_organs): os.makedirs(self.output_dir_images_organs) - + # Create MIP part of the overview plot createMipPlot( - inference_pipeline.ct, + inference_pipeline.ct, inference_pipeline.calc_mask, inference_pipeline.aorta_mask, inference_pipeline.t12_plane == 1, @@ -28,26 +33,24 @@ def __call__(self, inference_pipeline): inference_pipeline.metrics, self.output_dir_images_organs, ) - - ab_num = inference_pipeline.metrics['Abdominal']['num_calc'] - th_num = inference_pipeline.metrics['Thoracic']['num_calc'] + + ab_num = inference_pipeline.metrics["Abdominal"]["num_calc"] + th_num = inference_pipeline.metrics["Thoracic"]["num_calc"] # Create mosaic part of the overview plot - if not (ab_num == 0 and th_num == 0): + if not (ab_num == 0 and th_num == 0): createCalciumMosaic( - inference_pipeline.ct, + inference_pipeline.ct, inference_pipeline.calc_mask, - inference_pipeline.dilated_aorta_mask, # the dilated mask is used here + inference_pipeline.dilated_aorta_mask, # the dilated mask is used here inference_pipeline.spine_mask, inference_pipeline.pix_dims, self.output_dir_images_organs, inference_pipeline.args.mosaic_type, ) - - # Merge the two images created above for the final report - mergeMipAndMosaic( - self.output_dir_images_organs - ) - + + # Merge the two images created above for the final report + mergeMipAndMosaic(self.output_dir_images_organs) + return {} @@ -145,7 +148,9 @@ def __call__(self, inference_pipeline): "{},{:.3f}\n".format("Min volume (cm³):", np.min(metrics["volume"])) ) f.write( - "{},{:.3f}\n".format("% Calcified aorta:", metrics["perc_calcified"]) + "{},{:.3f}\n".format( + "% Calcified aorta:", metrics["perc_calcified"] + ) ) if inference_pipeline.args.threshold == "agatston": @@ -228,7 +233,7 @@ def __call__(self, inference_pipeline): metrics["perc_calcified"], ) ) - + if inference_pipeline.args.threshold == "agatston": print( "{:<{}}{:.1f}".format( diff --git a/comp2comp/aortic_calcium/visualization_utils.py b/comp2comp/aortic_calcium/visualization_utils.py index b8b1f143..b19781ba 100644 --- a/comp2comp/aortic_calcium/visualization_utils.py +++ b/comp2comp/aortic_calcium/visualization_utils.py @@ -2,19 +2,18 @@ import os import shutil -import matplotlib.patches as patches import matplotlib.pyplot as plt +import numpy as np from matplotlib.colors import ListedColormap from matplotlib.transforms import Bbox -import nibabel as nib -import numpy as np from numpy.typing import NDArray from PIL import Image # color map used for segmnetations -color_array = plt.get_cmap('Set1')(range(10)) -color_array = np.concatenate( (np.array([[0,0,0,0]]),color_array[:-1,:]), axis=0) -map_object_seg = ListedColormap(name='segmentation_cmap',colors=color_array) +color_array = plt.get_cmap("Set1")(range(10)) +color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:-1, :]), axis=0) +map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array) + def createMipPlot( ct: NDArray, @@ -24,27 +23,26 @@ def createMipPlot( HU_val: int, pix_size: NDArray, metrics: dict, - save_root: str - ) -> None: - ''' + save_root: str, +) -> None: + """ Create a MIP projection in the frontal and side plane with - the calcication 
overlayed. The text box is generated seperately + the calcication overlayed. The text box is generated seperately and then resampled to the MIP - ''' + """ - - ''' + """ Generate MIP image - ''' + """ # Create transparent hot cmap - hot = plt.get_cmap('hot', 256) + hot = plt.get_cmap("hot", 256) hot_colors = hot(np.linspace(0, 1, 256)) - hot_colors[0, -1] = 0 + hot_colors[0, -1] = 0 hot_transparent = ListedColormap(hot_colors) - fig, axx = plt.subplots(figsize=(12,12), dpi=300) - fig.patch.set_facecolor('black') - axx.set_facecolor('black') + fig, axx = plt.subplots(figsize=(12, 12), dpi=300) + fig.patch.set_facecolor("black") + axx.set_facecolor("black") # Create the frontal projection thres = 300 @@ -55,10 +53,12 @@ def createMipPlot( calc_mask_proj = np.flip(np.transpose(calc_mask.sum(axis=1)), axis=0) # normalize both views for the heat map if not calc_mask_proj.max() == 0: - calc_mask_proj = calc_mask_proj/calc_mask_proj.max() + calc_mask_proj = calc_mask_proj / calc_mask_proj.max() - aorta_mask_proj = np.flip(np.transpose(aorta_mask.max(axis=1)), axis=0)*2 - plane_mask_proj = np.where(np.flip(np.transpose( (plane_mask == 1).max(axis=(0,1))), axis=0))[0][0] + aorta_mask_proj = np.flip(np.transpose(aorta_mask.max(axis=1)), axis=0) * 2 + plane_mask_proj = np.where( + np.flip(np.transpose((plane_mask == 1).max(axis=(0, 1))), axis=0) + )[0][0] # Create the side projection ct_proj_side = np.flip(np.transpose(ct.max(axis=0)), axis=0) @@ -67,9 +67,9 @@ def createMipPlot( calc_mask_proj_side = np.flip(np.transpose(calc_mask.sum(axis=0)), axis=0) # normalize both views for the heat map if not calc_mask_proj_side.max() == 0: - calc_mask_proj_side = calc_mask_proj_side/calc_mask_proj_side.max() - - aorta_mask_proj_side = np.flip(np.transpose(aorta_mask.max(axis=0)), axis=0)*2 + calc_mask_proj_side = calc_mask_proj_side / calc_mask_proj_side.max() + + aorta_mask_proj_side = np.flip(np.transpose(aorta_mask.max(axis=0)), axis=0) * 2 # Concatenate together ct_proj_all = np.hstack([ct_proj, ct_proj_side]) @@ -77,79 +77,171 @@ def createMipPlot( aorta_mask_proj_all = np.hstack([aorta_mask_proj, aorta_mask_proj_side]) # Plot the results - axx.imshow(ct_proj_all, cmap='gray', vmin=thres, vmax=1600, aspect=pix_size[2]/pix_size[0], alpha=1) + axx.imshow( + ct_proj_all, + cmap="gray", + vmin=thres, + vmax=1600, + aspect=pix_size[2] / pix_size[0], + alpha=1, + ) # Aorta mask and calcification - aorta_im = axx.imshow(aorta_mask_proj_all, cmap=map_object_seg, zorder=1, - vmin=0, vmax=10, aspect=pix_size[2]/pix_size[0], alpha=0.6, interpolation='nearest') - calc_im = axx.imshow(calc_mask_proj_all, cmap=hot_transparent, aspect=pix_size[2]/pix_size[0], alpha=1, - interpolation='nearest', zorder=2) + aorta_im = axx.imshow( + aorta_mask_proj_all, + cmap=map_object_seg, + zorder=1, + vmin=0, + vmax=10, + aspect=pix_size[2] / pix_size[0], + alpha=0.6, + interpolation="nearest", + ) + calc_im = axx.imshow( + calc_mask_proj_all, + cmap=hot_transparent, + aspect=pix_size[2] / pix_size[0], + alpha=1, + interpolation="nearest", + zorder=2, + ) # Ab and Th separating plane - axx.plot([0, ct_proj_all.shape[1]], [plane_mask_proj, plane_mask_proj], - color=map_object_seg(3), lw=0.8, alpha=0.8, zorder=0) - axx.text(30, plane_mask_proj - 8, 'Thoracic', color=map_object_seg(3), - va='center', ha='left', alpha=0.8, fontsize=10) - axx.text(30, plane_mask_proj + 8, 'Abdominal', color=map_object_seg(3), - va='center', ha='left', alpha=0.8, fontsize=10) + axx.plot( + [0, ct_proj_all.shape[1]], + [plane_mask_proj, plane_mask_proj], + 
color=map_object_seg(3), + lw=0.8, + alpha=0.8, + zorder=0, + ) + axx.text( + 30, + plane_mask_proj - 8, + "Thoracic", + color=map_object_seg(3), + va="center", + ha="left", + alpha=0.8, + fontsize=10, + ) + axx.text( + 30, + plane_mask_proj + 8, + "Abdominal", + color=map_object_seg(3), + va="center", + ha="left", + alpha=0.8, + fontsize=10, + ) axx.set_xticks([]) - axx.set_ylabel('Slice number', color='white', fontsize=10) - axx.tick_params(axis='y', colors='white', labelsize=10) + axx.set_ylabel("Slice number", color="white", fontsize=10) + axx.tick_params(axis="y", colors="white", labelsize=10) - # extend black background + # extend black background axx.set_xlim(0, ct_proj_all.shape[1]) # wrap plot in Image tight_bbox = fig.get_tightbbox(fig.canvas.get_renderer()) # Create a new bounding box with padding only on the left # The bbox coordinates are [left, bottom, right, top] - custom_bbox = Bbox([[tight_bbox.x0 - 0.05, tight_bbox.y0-0.07], # Add 0.5 inches to the left only - [tight_bbox.x1, tight_bbox.y1]]) + custom_bbox = Bbox( + [ + [ + tight_bbox.x0 - 0.05, + tight_bbox.y0 - 0.07, + ], # Add 0.5 inches to the left only + [tight_bbox.x1, tight_bbox.y1], + ] + ) buf_mip = io.BytesIO() - fig.savefig(buf_mip, bbox_inches=custom_bbox, pad_inches=0, dpi=300, format='png') + fig.savefig(buf_mip, bbox_inches=custom_bbox, pad_inches=0, dpi=300, format="png") plt.close(fig) buf_mip.seek(0) image_mip = Image.open(buf_mip) - ''' + """ Generate the text box - ''' + """ spacing = 23 indent = 1 - text_box_x_offset = 20 report_text = [] - report_text.append(r'$\bf{Calcification\ Report}$') - + report_text.append(r"$\bf{Calcification\ Report}$") + for i, (region, region_metrics) in enumerate(metrics.items()): - report_text.append(region) - report_text.append('{:<{}}{:<{}}{}'.format('',indent, 'Total number:', spacing, region_metrics['num_calc'])) - report_text.append('{:<{}}{:<{}}{:.3f}'.format('',indent,'Total volume (cm³):', spacing, region_metrics['volume_total'])) - report_text.append('{:<{}}{:<{}}{:.3f}{}{:.3f}'.format('',indent,'Mean volume (cm³):', spacing, - np.mean(region_metrics["volume"]),r'$\pm$',np.std(region_metrics["volume"]))) - report_text.append('{:<{}}{:<{}}{:.1f}{}{:.1f}'.format('',indent,'Median HU:', spacing, - np.mean(region_metrics["median_hu"]),r'$\pm$',np.std(region_metrics["median_hu"]))) - report_text.append('{:<{}}{:<{}}{:.3f}'.format('',indent,'% Volume calcified:', spacing, - np.mean(region_metrics["perc_calcified"]))) - - if 'agatston_score' in region_metrics: - report_text.append('{:<{}}{:<{}}{:.0f}'.format('',indent,'Agatston:', spacing, region_metrics['agatston_score'])) - - report_text.append('\n') - - report_text.append('{:<{}}{:<{}}{}'.format('',indent, 'Threshold (HU):', spacing, HU_val)) - - fig_t, axx_t = plt.subplots(figsize=(5.85,5.85), dpi=300) - fig_t.patch.set_facecolor('black') - axx_t.set_facecolor('black') - - axx_t.imshow(np.ones((100,65)), cmap='gray') + report_text.append(region) + report_text.append( + "{:<{}}{:<{}}{}".format( + "", indent, "Total number:", spacing, region_metrics["num_calc"] + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}".format( + "", + indent, + "Total volume (cm³):", + spacing, + region_metrics["volume_total"], + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}{}{:.3f}".format( + "", + indent, + "Mean volume (cm³):", + spacing, + np.mean(region_metrics["volume"]), + r"$\pm$", + np.std(region_metrics["volume"]), + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.1f}{}{:.1f}".format( + "", + indent, + "Median HU:", + spacing, 
+ np.mean(region_metrics["median_hu"]), + r"$\pm$", + np.std(region_metrics["median_hu"]), + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}".format( + "", + indent, + "% Volume calcified:", + spacing, + np.mean(region_metrics["perc_calcified"]), + ) + ) + + if "agatston_score" in region_metrics: + report_text.append( + "{:<{}}{:<{}}{:.0f}".format( + "", indent, "Agatston:", spacing, region_metrics["agatston_score"] + ) + ) + + report_text.append("\n") + + report_text.append( + "{:<{}}{:<{}}{}".format("", indent, "Threshold (HU):", spacing, HU_val) + ) + + fig_t, axx_t = plt.subplots(figsize=(5.85, 5.85), dpi=300) + fig_t.patch.set_facecolor("black") + axx_t.set_facecolor("black") + + axx_t.imshow(np.ones((100, 65)), cmap="gray") bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5) text_obj = axx_t.text( x=0.3, y=1, - s='\n'.join(report_text), + s="\n".join(report_text), color="white", # fontsize=12.12 * font_scale, fontsize=15.7, @@ -159,17 +251,26 @@ def createMipPlot( ha="left", ) - axx_t.set_aspect(0.69) - axx_t.axis('off') + axx_t.set_aspect(0.69) + axx_t.axis("off") fig.canvas.draw_idle() plt.tight_layout() - + # wrap text box in Image tight_bbox_text = fig_t.get_tightbbox(fig_t.canvas.get_renderer()) - custom_bbox_text = Bbox([[tight_bbox_text.x0 - 0.05, tight_bbox_text.y0-0.14], # Add 0.5 inches to the left only - [tight_bbox_text.x1+0.15, tight_bbox_text.y1+0.05]]) + custom_bbox_text = Bbox( + [ + [ + tight_bbox_text.x0 - 0.05, + tight_bbox_text.y0 - 0.14, + ], # Add 0.5 inches to the left only + [tight_bbox_text.x1 + 0.15, tight_bbox_text.y1 + 0.05], + ] + ) buf_text = io.BytesIO() - fig_t.savefig(buf_text, bbox_inches=custom_bbox_text, pad_inches=0, dpi=300, format='png') + fig_t.savefig( + buf_text, bbox_inches=custom_bbox_text, pad_inches=0, dpi=300, format="png" + ) plt.close(fig_t) buf_text.seek(0) image_text = Image.open(buf_text) @@ -180,30 +281,34 @@ def createMipPlot( aspect_ratio = image_text.height / image_text.width adjusted_width = int(image_mip.height / aspect_ratio) - image_text_resample = image_text.resize([adjusted_width, image_mip.height], Image.LANCZOS) - + image_text_resample = image_text.resize( + [adjusted_width, image_mip.height], Image.LANCZOS + ) + # Merge into one image - result = Image.new("RGB", (image_mip.width + image_text_resample.width, image_mip.height)) + result = Image.new( + "RGB", (image_mip.width + image_text_resample.width, image_mip.height) + ) result.paste(im=image_mip, box=(0, 0)) result.paste(im=image_text_resample, box=(image_mip.width, 0)) # create path and save - path = os.path.join(save_root, 'sub_figures') + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) - result.save(os.path.join(path,'projection.png'), dpi=(300, 300)) - + result.save(os.path.join(path, "projection.png"), dpi=(300, 300)) + def crop_and_pad_image(image, pad_percent=0.025, pad_color=(0, 0, 0, 255)): # Ensure image has alpha channel - image = image.convert('RGBA') - + image = image.convert("RGBA") + # Create a binary mask: 255 for non-black, 0 for black - mask = image.convert('RGB').point(lambda p: 255 if p != 0 else 0).convert('L') + mask = image.convert("RGB").point(lambda p: 255 if p != 0 else 0).convert("L") bbox = mask.getbbox() if not bbox: return image # No content found; return original - + # Crop the image cropped = image.crop(bbox) @@ -212,7 +317,7 @@ def crop_and_pad_image(image, pad_percent=0.025, pad_color=(0, 0, 0, 255)): pad = int(width * pad_percent) # Add padding - padded = Image.new('RGBA', (width + pad * 
2, height + pad * 2), pad_color) + padded = Image.new("RGBA", (width + pad * 2, height + pad * 2), pad_color) padded.paste(cropped, (pad, pad)) return padded @@ -225,13 +330,13 @@ def createCalciumMosaic( spine_mask: NDArray, pix_size: NDArray, save_root: str, - mosaic_type: str = 'all' - ) -> None: - ''' - Wrapper function that calls different functions for creating the mosaic + mosaic_type: str = "all", +) -> None: + """ + Wrapper function that calls different functions for creating the mosaic depending on the "mosaic_type". - ''' - if mosaic_type == 'all': + """ + if mosaic_type == "all": createCalciumMosaicAll( ct, calc_mask, @@ -239,7 +344,7 @@ def createCalciumMosaic( pix_size, save_root, ) - elif mosaic_type == 'vertebrae': + elif mosaic_type == "vertebrae": createCalciumMosaicVertebrae( ct, calc_mask, @@ -249,7 +354,8 @@ def createCalciumMosaic( save_root, ) else: - raise ValueError('mosaic_type not recognized, got: ' + str(mosaic_type)) + raise ValueError("mosaic_type not recognized, got: " + str(mosaic_type)) + def createCalciumMosaicAll( ct: NDArray, @@ -257,9 +363,9 @@ def createCalciumMosaicAll( aorta_mask_dil: NDArray, pix_size: NDArray, save_root: str, - ) -> None: - - calc_idx = np.where( calc_mask.sum(axis=(0,1)) )[0] +) -> None: + + calc_idx = np.where(calc_mask.sum(axis=(0, 1)))[0] per_row = 15 @@ -268,7 +374,7 @@ def createCalciumMosaicAll( # target size of 60 mm crop size rounded to nearest multiple of 2 crop_size = round(60 / pix_size[0]) - crop_size = 2*round(crop_size/2) + crop_size = 2 * round(crop_size / 2) ct_crops = [] mask_crops = [] @@ -279,32 +385,38 @@ def createCalciumMosaicAll( aorta_dil_tmp = [] for i, idx in enumerate(calc_idx[::-1]): - ct_slice = np.flip(np.transpose(ct[:,:,idx]), axis=(0,1)) - mask_slice = np.flip(np.transpose(calc_mask[:,:,idx]), axis=(0,1)) - aorta_slice = np.flip(np.transpose(aorta_mask_dil[:,:,idx]), axis=(0,1)) - + ct_slice = np.flip(np.transpose(ct[:, :, idx]), axis=(0, 1)) + mask_slice = np.flip(np.transpose(calc_mask[:, :, idx]), axis=(0, 1)) + aorta_slice = np.flip(np.transpose(aorta_mask_dil[:, :, idx]), axis=(0, 1)) + x_center = np.where(aorta_slice.sum(axis=1))[0] - x_center = x_center[len(x_center)//2] - + x_center = x_center[len(x_center) // 2] + y_center = np.where(aorta_slice.sum(axis=0))[0] - y_center = y_center[len(y_center)//2] - - ct_tmp.append(ct_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - mask_tmp.append(mask_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - aorta_dil_tmp.append(aorta_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - if (i+1) % per_row == 0: + y_center = y_center[len(y_center) // 2] + + ct_tmp.append( + ct_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + mask_tmp.append( + mask_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + aorta_dil_tmp.append( + aorta_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + if (i + 1) % per_row == 0: # print('got here') ct_crops.append(np.hstack(ct_tmp)) mask_crops.append(np.hstack(mask_tmp)) @@ -315,13 +427,33 @@ def createCalciumMosaicAll( aorta_dil_tmp = [] if len(ct_tmp) > 0: - pad_len = per_row*crop_size - len(ct_tmp)*crop_size - - 
ct_crops.append(np.pad(np.hstack(ct_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=-400)) - mask_crops.append(np.pad(np.hstack(mask_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) - aorta_dil_crops.append(np.pad(np.hstack(aorta_dil_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) + pad_len = per_row * crop_size - len(ct_tmp) * crop_size + + ct_crops.append( + np.pad( + np.hstack(ct_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=-400, + ) + ) + mask_crops.append( + np.pad( + np.hstack(mask_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) + aorta_dil_crops.append( + np.pad( + np.hstack(aorta_dil_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) - ct_crops = np.vstack(ct_crops) mask_crops = np.vstack(mask_crops) aorta_dil_crops = np.vstack(aorta_dil_crops) @@ -330,29 +462,47 @@ def createCalciumMosaicAll( combined_crops[aorta_dil_crops == 1] = 2 combined_crops[mask_crops == 1] = 1 - fig, axx = plt.subplots(1, figsize=(20,20)) + fig, axx = plt.subplots(1, figsize=(20, 20)) + + axx.imshow(ct_crops, cmap="gray", vmin=-150, vmax=250) + axx.imshow( + combined_crops, + vmin=0, + vmax=10, + cmap=map_object_seg, + interpolation="nearest", + alpha=0.4, + ) - axx.imshow(ct_crops, cmap='gray', vmin=-150, vmax=250) - axx.imshow(combined_crops, vmin=0, vmax=10, cmap=map_object_seg, interpolation='nearest', alpha=0.4) - for i, idx_ in enumerate(calc_idx[::-1]): # account for how the MIP is displayed idx = calc_mask.shape[2] - idx_ - x_ = crop_size*(i % per_row) + 4 # crop_size//2 - y_ = crop_size*(i // per_row) + 4 - axx.text(x_, y_, s=str(idx), color = 'white', fontsize=10, va='top', ha='left', - bbox=dict(facecolor='black', edgecolor='black', boxstyle='round,pad=0.1')) + x_ = crop_size * (i % per_row) + 4 # crop_size//2 + y_ = crop_size * (i // per_row) + 4 + axx.text( + x_, + y_, + s=str(idx), + color="white", + fontsize=10, + va="top", + ha="left", + bbox=dict(facecolor="black", edgecolor="black", boxstyle="round,pad=0.1"), + ) axx.set_xticks([]) axx.set_yticks([]) - - path = os.path.join(save_root, 'sub_figures') + + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) # Save with the custom bounding box - fig.savefig(os.path.join(path,'mosaic.png'),bbox_inches="tight", pad_inches=0, dpi=300) + fig.savefig( + os.path.join(path, "mosaic.png"), bbox_inches="tight", pad_inches=0, dpi=300 + ) plt.close(fig) + def createCalciumMosaicVertebrae( ct: NDArray, calc_mask: NDArray, @@ -360,8 +510,8 @@ def createCalciumMosaicVertebrae( spine_mask: NDArray, pix_size: NDArray, save_root: str, - ) -> None: - +) -> None: + vertebrae_num = { 26: "vertebrae_S1", 27: "vertebrae_L5", @@ -387,11 +537,12 @@ def createCalciumMosaicVertebrae( 47: "vertebrae_C4", 48: "vertebrae_C3", 49: "vertebrae_C2", - 50: "vertebrae_C1"} + 50: "vertebrae_C1", + } vertebrae_name = {v: k for k, v in vertebrae_num.items()} - calc_idx = np.where( calc_mask.sum(axis=(0,1)) )[0] + calc_idx = np.where(calc_mask.sum(axis=(0, 1)))[0] per_row = 5 found_vertebras = np.unique(spine_mask) @@ -400,7 +551,7 @@ def createCalciumMosaicVertebrae( # target size of 120 mm crop size rounded to nearest multiple of 2 crop_size = round(120 / pix_size[0]) - crop_size = 2*round(crop_size/2) + crop_size = 2 * round(crop_size / 2) ct_crops = [] mask_crops = [] @@ -416,42 +567,47 @@ def createCalciumMosaicVertebrae( for vert_num in found_vertebras[::-1]: # only vertebraes if vert_num < 26: - continue - + continue + 
# Find middle of the vertebra tmp_spine_mask = spine_mask == vert_num - tmp_spine_idx = np.where(tmp_spine_mask.sum(axis=(0,1)))[0] - idx = tmp_spine_idx[len(tmp_spine_idx)//2] + tmp_spine_idx = np.where(tmp_spine_mask.sum(axis=(0, 1)))[0] + idx = tmp_spine_idx[len(tmp_spine_idx) // 2] + + ct_slice = np.flip(np.transpose(ct[:, :, idx]), axis=(0, 1)) + mask_slice = np.flip(np.transpose(calc_mask[:, :, idx]), axis=(0, 1)) + aorta_slice = np.flip(np.transpose(aorta_mask_dil[:, :, idx]), axis=(0, 1)) - ct_slice = np.flip(np.transpose(ct[:,:,idx]), axis=(0,1)) - mask_slice = np.flip(np.transpose(calc_mask[:,:,idx]), axis=(0,1)) - aorta_slice = np.flip(np.transpose(aorta_mask_dil[:,:,idx]), axis=(0,1)) - x_center = np.where(aorta_slice.sum(axis=1))[0] # skip if no aorta is present if len(x_center) == 0: continue - x_center = x_center[len(x_center)//2] - + x_center = x_center[len(x_center) // 2] + y_center = np.where(aorta_slice.sum(axis=0))[0] - y_center = y_center[len(y_center)//2] - - ct_tmp.append(ct_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ] - ) - mask_tmp.append(mask_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - aorta_dil_tmp.append(aorta_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - if (i+1) % per_row == 0: + y_center = y_center[len(y_center) // 2] + + ct_tmp.append( + ct_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + mask_tmp.append( + mask_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + aorta_dil_tmp.append( + aorta_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + if (i + 1) % per_row == 0: # print('got here') ct_crops.append(np.hstack(ct_tmp)) mask_crops.append(np.hstack(mask_tmp)) @@ -460,18 +616,38 @@ def createCalciumMosaicVertebrae( ct_tmp = [] mask_tmp = [] aorta_dil_tmp = [] - - vertebra_names.append(vertebrae_num[vert_num].split('_')[1]) + + vertebra_names.append(vertebrae_num[vert_num].split("_")[1]) i += 1 if len(ct_tmp) > 0: - pad_len = per_row*crop_size - len(ct_tmp)*crop_size - - ct_crops.append(np.pad(np.hstack(ct_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=-400)) - mask_crops.append(np.pad(np.hstack(mask_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) - aorta_dil_crops.append(np.pad(np.hstack(aorta_dil_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) + pad_len = per_row * crop_size - len(ct_tmp) * crop_size + + ct_crops.append( + np.pad( + np.hstack(ct_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=-400, + ) + ) + mask_crops.append( + np.pad( + np.hstack(mask_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) + aorta_dil_crops.append( + np.pad( + np.hstack(aorta_dil_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) - ct_crops = np.vstack(ct_crops) mask_crops = np.vstack(mask_crops) aorta_dil_crops = np.vstack(aorta_dil_crops) @@ -480,51 +656,74 @@ def createCalciumMosaicVertebrae( combined_crops[aorta_dil_crops == 1] = 2 combined_crops[mask_crops == 1] = 1 - fig, axx = plt.subplots(1, figsize=(20,20)) + fig, axx = plt.subplots(1, figsize=(20, 20)) + + axx.imshow(ct_crops, cmap="gray", vmin=-150, vmax=250) + 
axx.imshow( + combined_crops, + vmin=0, + vmax=10, + cmap=map_object_seg, + interpolation="nearest", + alpha=0.4, + ) - axx.imshow(ct_crops, cmap='gray', vmin=-150, vmax=250) - axx.imshow(combined_crops, vmin=0, vmax=10, cmap=map_object_seg, interpolation='nearest', alpha=0.4) - # for i, idx_ in enumerate(calc_idx[::-1]): for i, name in enumerate(vertebra_names): - x_ = crop_size*(i % per_row) + 4 # crop_size//2 - y_ = crop_size*(i // per_row) + 4 - axx.text(x_, y_, s=name, color = 'white', fontsize=18, va='top', ha='left', - bbox=dict(facecolor='black', edgecolor='black', boxstyle='round,pad=0.1')) + x_ = crop_size * (i % per_row) + 4 # crop_size//2 + y_ = crop_size * (i // per_row) + 4 + axx.text( + x_, + y_, + s=name, + color="white", + fontsize=18, + va="top", + ha="left", + bbox=dict(facecolor="black", edgecolor="black", boxstyle="round,pad=0.1"), + ) axx.set_xticks([]) axx.set_yticks([]) - - path = os.path.join(save_root, 'sub_figures') + + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) # Save with the custom bounding box - fig.savefig(os.path.join(path,'mosaic.png'),bbox_inches="tight", pad_inches=0, dpi=300) + fig.savefig( + os.path.join(path, "mosaic.png"), bbox_inches="tight", pad_inches=0, dpi=300 + ) plt.close(fig) - - + + def mergeMipAndMosaic(save_root: str) -> None: - ''' + """ This function loads the MIP and mosaic images and merges them into a single final image. - ''' - - if os.path.isfile(os.path.join(save_root, 'sub_figures/mosaic.png')): - img_proj = Image.open(os.path.join(save_root, 'sub_figures/projection.png')) - img_mosaic = Image.open(os.path.join(save_root, 'sub_figures/mosaic.png')) - + """ + + if os.path.isfile(os.path.join(save_root, "sub_figures/mosaic.png")): + img_proj = Image.open(os.path.join(save_root, "sub_figures/projection.png")) + img_mosaic = Image.open(os.path.join(save_root, "sub_figures/mosaic.png")) + # Match the width to the projection image aspect_ratio = img_mosaic.height / img_mosaic.width new_height = int(img_proj.width * aspect_ratio) - img_mosaic_resample = img_mosaic.resize([img_proj.width, new_height], Image.LANCZOS) + img_mosaic_resample = img_mosaic.resize( + [img_proj.width, new_height], Image.LANCZOS + ) - result = Image.new("RGB", (img_proj.width, img_proj.height + img_mosaic_resample.height)) + result = Image.new( + "RGB", (img_proj.width, img_proj.height + img_mosaic_resample.height) + ) result.paste(im=img_proj, box=(0, 0)) result.paste(im=img_mosaic_resample, box=(0, img_proj.height)) - result.save(os.path.join(save_root,'overview.png'), dpi=(300, 300)) + result.save(os.path.join(save_root, "overview.png"), dpi=(300, 300)) else: - shutil.copy2(os.path.join(save_root, 'sub_figures/projection.png'), os.path.join(save_root,'overview.png')) - \ No newline at end of file + shutil.copy2( + os.path.join(save_root, "sub_figures/projection.png"), + os.path.join(save_root, "overview.png"), + ) diff --git a/comp2comp/contrast_phase/contrast_inf.py b/comp2comp/contrast_phase/contrast_inf.py index 860d1477..a768b8e5 100644 --- a/comp2comp/contrast_phase/contrast_inf.py +++ b/comp2comp/contrast_phase/contrast_inf.py @@ -8,8 +8,10 @@ import scipy import SimpleITK as sitk from scipy import ndimage as ndi + # import xgboost + def loadNiiToArray(path): NiImg = nib.load(path) array = np.array(NiImg.dataobj) @@ -176,7 +178,7 @@ def getFeatures(TSArray, scanArray): kidneyLMask = getClassBinaryMask(TSArray, 3) kidneyRMask = getClassBinaryMask(TSArray, 2) adRMask = getClassBinaryMask(TSArray, 11) - + # aortaMask = 
getClassBinaryMask(TSArray, 52) # IVCMask = getClassBinaryMask(TSArray, 63) # portalMask = getClassBinaryMask(TSArray, 64) @@ -421,16 +423,11 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False): model = loadModel() # TS_array, image_array = loadNiftis(TS_output_nifti_path, image_nifti_path) featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array) - + y_pred_proba = model.predict_proba([featureArray])[0] y_pred = np.argmax(y_pred_proba) - - phase_dict = { - 0: "non-contrast", - 1: "arterial", - 2: "venous", - 3: "delayed" - } + + phase_dict = {0: "non-contrast", 1: "arterial", 2: "venous", 3: "delayed"} pred_phase = phase_dict[y_pred] @@ -447,16 +444,16 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False): if not os.path.exists(output_path_metrics): os.makedirs(output_path_metrics) outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt") - + with open(outputTxt, "w") as text_file: - text_file.write('phase,'+pred_phase + '\n') + text_file.write("phase," + pred_phase + "\n") for i in range(len(y_pred_proba)): - text_file.write('{},{:.3f}\n'.format(phase_dict[i], y_pred_proba[i])) + text_file.write("{},{:.3f}\n".format(phase_dict[i], y_pred_proba[i])) - print('Predicted phase: ' + pred_phase) - print('\nProbabilities:') + print("Predicted phase: " + pred_phase) + print("\nProbabilities:") for i in range(len(y_pred_proba)): - print('{:<20}{:.3f}'.format(phase_dict[i], y_pred_proba[i])) + print("{:<20}{:.3f}".format(phase_dict[i], y_pred_proba[i])) output_path_images = os.path.join(outputPath, "images") if not os.path.exists(output_path_images): diff --git a/comp2comp/contrast_phase/contrast_phase.py b/comp2comp/contrast_phase/contrast_phase.py index 524b1b83..54df1455 100644 --- a/comp2comp/contrast_phase/contrast_phase.py +++ b/comp2comp/contrast_phase/contrast_phase.py @@ -1,20 +1,21 @@ import os +import subprocess +import zipfile from pathlib import Path from time import time from typing import Union -import subprocess -import zipfile -from totalsegmentator.libs import ( - # download_pretrained_weights, +from totalsegmentator.libs import ( # download_pretrained_weights, nostdout, setup_nnunet, ) -# from totalsegmentatorv2.python_api import totalsegmentator from comp2comp.contrast_phase.contrast_inf import predict_phase from comp2comp.inference_class_base import InferenceClass +# from totalsegmentatorv2.python_api import totalsegmentator + + class ContrastPhaseDetection(InferenceClass): """Contrast Phase Detection.""" @@ -74,7 +75,7 @@ def run_segmentation( # download with weight for id 251 self.download_pretrained_weights_updated(task_id[0]) - + from totalsegmentator.nnunet import nnUNet_predict_image with nostdout(): @@ -99,7 +100,7 @@ def run_segmentation( verbose=False, test=0, ) - + # seg = totalsegmentator( # input = os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), # output = os.path.join(self.output_dir_segmentations, "segmentation.nii"), @@ -135,40 +136,36 @@ def run_segmentation( # return seg, img return seg, img - + def download_pretrained_weights_updated(self, task_id): - ''' + """ Download the weights with curl to resolve problems with downloading from Zenodo - ''' + """ home_path = Path(os.environ["SCRATCH"]) config_dir = home_path / ".totalsegmentator/nnunet/results/nnUNet" (config_dir / "3d_fullres").mkdir(exist_ok=True, parents=True) (config_dir / "2d").mkdir(exist_ok=True, parents=True) - + url = 
"https://zenodo.org/records/6802342/files/Task251_TotalSegmentator_part1_organs_1139subj.zip?download=1" config_dir = config_dir / "3d_fullres" weights_path = config_dir / "Task251_TotalSegmentator_part1_organs_1139subj" tempfile = config_dir / "tmp_download_file.zip" - + if not weights_path.exists(): - print('Downloading weights..') - subprocess.run( - ["curl", "-L", url, "-o", tempfile], - check=True - ) - - print('Unzipping..') - with zipfile.ZipFile(config_dir / "tmp_download_file.zip", 'r') as zip_f: + print("Downloading weights..") + subprocess.run(["curl", "-L", url, "-o", tempfile], check=True) + + print("Unzipping..") + with zipfile.ZipFile(config_dir / "tmp_download_file.zip", "r") as zip_f: zip_f.extractall(config_dir) # print(f" downloaded in {time.time()-st:.2f}s") if tempfile.exists(): os.remove(tempfile) - print('Done.') + print("Done.") else: - print('Weights are already downloaded') - - + print("Weights are already downloaded") + def convertNibToNumpy(self, TSNib, ImageNib): """Convert nifti to numpy array. diff --git a/comp2comp/inference_pipeline.py b/comp2comp/inference_pipeline.py index 0aca5daf..60cba8d2 100644 --- a/comp2comp/inference_pipeline.py +++ b/comp2comp/inference_pipeline.py @@ -86,7 +86,8 @@ def __call__(self, inference_pipeline=None, **kwargs): def saveArrToNifti(self, arr, path): """ - Saves an array to nifti using the CT as reference + Saves an array to nifti using the CT as reference. Assumes + that self.medical_volume is not changed Args: arr (ndarray): input array. diff --git a/comp2comp/io/fda_io.py b/comp2comp/io/fda_io.py new file mode 100644 index 00000000..da53260d --- /dev/null +++ b/comp2comp/io/fda_io.py @@ -0,0 +1,184 @@ +""" +@author: louisblankemeier +""" + +import os +import shutil +from pathlib import Path +from typing import Dict, Union + +# import dicom2nifti +import dosma as dm +import pydicom +import SimpleITK as sitk + +from comp2comp.inference_class_base import InferenceClass + + +class DicomLoader(InferenceClass): + """Load a single dicom series.""" + + def __init__(self, input_path: Union[str, Path]): + super().__init__() + self.dicom_dir = Path(input_path) + self.dr = dm.DicomReader() + + def __call__(self, inference_pipeline) -> Dict: + medical_volume = self.dr.load( + self.dicom_dir, group_by=None, sort_by="InstanceNumber" + )[0] + return {"medical_volume": medical_volume} + + +class NiftiSaver(InferenceClass): + """Save dosma medical volume object to NIfTI file.""" + + def __init__(self): + super().__init__() + # self.output_dir = Path(output_path) + self.nw = dm.NiftiWriter() + + def __call__( + self, inference_pipeline, medical_volume: dm.MedicalVolume + ) -> Dict[str, Path]: + nifti_file = inference_pipeline.output_dir + self.nw.write(medical_volume, nifti_file) + return {"nifti_file": nifti_file} + + +class DicomFinder(InferenceClass): + """Find dicom files in a directory.""" + + def __init__(self, input_path: Union[str, Path]) -> Dict[str, Path]: + super().__init__() + self.input_path = Path(input_path) + + def __call__(self, inference_pipeline) -> Dict[str, Path]: + """Find dicom files in a directory. + + Args: + inference_pipeline (InferencePipeline): Inference pipeline. + + Returns: + Dict[str, Path]: Dictionary containing dicom files. 
+ """ + dicom_files = [] + for file in self.input_path.glob("**/*.dcm"): + dicom_files.append(file) + inference_pipeline.dicom_file_paths = dicom_files + return {} + + +class DicomToNifti(InferenceClass): + """Convert dicom files to NIfTI files.""" + + def __init__(self, input_path: Union[str, Path], pipeline_name=None, save=True): + super().__init__() + self.input_path = Path(input_path) + self.save = save + self.pipeline_name = pipeline_name + + def __call__(self, inference_pipeline): + if os.path.exists( + os.path.join( + inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz" + ) + ): + return {} + if hasattr(inference_pipeline, "medical_volume"): + return {} + output_dir = inference_pipeline.output_dir + segmentations_output_dir = os.path.join(output_dir, "segmentations") + os.makedirs(segmentations_output_dir, exist_ok=True) + + # if self.input_path is a folder + if self.input_path.is_dir(): + ds = dicom_series_to_nifti( + self.input_path, + output_file=os.path.join( + segmentations_output_dir, "converted_dcm.nii.gz" + ), + reorient_nifti=False, + pipeline_name=self.pipeline_name, + ) + inference_pipeline.dicom_series_path = str(self.input_path) + inference_pipeline.dicom_ds = ds + # try: + # age = np.array(ds.loc[ds['Keyword'] == 'PatientAge']['Value'])[0] + # inference_pipeline.age = int(age.replace('Y', '')) + try: + age = ds.PatientAge + inference_pipeline.age = int(age.replace("Y", "")) + except: + inference_pipeline.age = 65 + elif str(self.input_path).endswith(".nii"): + shutil.copy( + self.input_path, + os.path.join(segmentations_output_dir, "converted_dcm.nii"), + ) + elif str(self.input_path).endswith(".nii.gz"): + shutil.copy( + self.input_path, + os.path.join(segmentations_output_dir, "converted_dcm.nii.gz"), + ) + + return {} + + +def series_selector(dicom_path, pipeline_name=None): + ds = pydicom.filereader.dcmread(dicom_path) + list(ds.ImageType) + # if pipeline_name != "aaa": + # if not any("primary" in s.lower() for s in image_type_list): + # raise ValueError("Not primary image type") + # if not any("original" in s.lower() for s in image_type_list): + # raise ValueError("Not original image type") + # if ds.ImageOrientationPatient != [1, 0, 0, 0, 1, 0]: + # raise ValueError("Image orientation is not axial") + # else: + # print(f"Skipping primary, original, and orientation image type check for the {pipeline_name} pipeline.") + # if any("gsi" in s.lower() for s in image_type_list): + # raise ValueError("GSI image type") + return ds + + +# def dicom_series_to_nifti(input_path, output_file, reorient_nifti, pipeline_name=None): +# reader = sitk.ImageSeriesReader() +# dicom_names = reader.GetGDCMSeriesFileNames(str(input_path)) +# ds = series_selector(dicom_names[0], pipeline_name=pipeline_name) +# reader.SetFileNames(dicom_names) +# image = reader.Execute() +# sitk.WriteImage(image, output_file) +# return ds + +# study_names = ['1.2.840.113619.2.25.4.1260288.1278509664.191', '1.2.840.113619.2.327.3.4137921600.124.1407058488.371', '1.2.840.113619.2.334.3.4137913664.424.1466463538.272', '1.2.840.113619.2.55.3.4137898048.540.1264740411.654', '1.2.840.113619.2.55.3.4137913664.869.1301053239.794', '1.2.840.113970.3.57.1.3219726.660248536.20100321.1084131', '1.2.840.113970.3.57.1.3614815.677188018.20100415.1162942', '1.2.840.113970.3.57.1.3763611.696273942.20100514.1122931', '1.2.840.113970.3.57.1.4052749.730108233.20100705.1185125', '1.2.840.113970.3.57.1.4056415.730680166.20100706.1145200', '1.2.840.113970.3.57.1.4405910.770169583.20100903.1163114', 
'1.2.840.113970.3.57.1.4655468.746054459.20101013.1082134', '1.2.840.113970.3.57.1.49348828.20140501.1152855', '1.2.840.113970.3.57.1.4969094.825298235.20101121.1203449', '1.2.840.113970.3.57.1.49751506.20140729.1100309', '1.2.840.113970.3.57.1.5047382.829836439.20101201.1104716', '1.2.840.113970.3.57.1.5090606.836995318.20101208.1163557', '1.2.840.113970.3.57.1.5124951.841221198.20101214.1121828', '1.2.840.113970.3.57.1.51357670.20150625.1090927', '1.2.840.113970.3.57.1.52787137.20160406.1151608', '1.2.840.113970.3.57.1.55223913.20170707.1104429', '1.2.840.113970.3.57.1.59404158.20190525.1', '1.2.840.113970.3.57.1.62404600.20200917.1092505', '1.2.840.113970.3.57.1.63796608.20210406.1154212', '1.2.840.114350.2.629.2.798268.2.214049946.1', '1.2.840.113619.2.334.3.4137913664.316.1416947176.728', '1.2.840.113619.2.334.3.4137913664.358.1493407736.930', '1.2.840.113619.2.334.3.4137913664.664.1426504723.10', '1.2.840.113619.2.340.3.4137913664.233.1617226394.758', '1.2.840.113619.2.55.3.4137913920.105.1281440647.285', '1.2.840.113619.2.55.3.4137919064.81.1308744113.151', '1.2.840.113970.3.57.1.4340491.763145748.20100824.1160430', '1.2.840.113970.3.57.1.4533761.763631267.20100923.1092253', '1.2.840.113970.3.57.1.4607101.790945843.20101004.1173319', '1.2.840.113970.3.57.1.47689543.20130427.1094829', '1.2.840.113970.3.57.1.48824001.20140107.1201514', '1.2.840.113970.3.57.1.4894117.820442914.20101115.1', '1.2.840.113970.3.57.1.50245280.20141109.1125217', '1.2.840.113970.3.57.1.5047802.831585137.20101201.1121119', '1.2.840.113970.3.57.1.61153789.20200207.1073528', '1.2.840.113619.2.327.3.4137937730.76.1480974569.157', '1.2.840.113619.2.452.3.4137937730.934.1620642878.238', '1.2.840.113619.2.25.1.1762890133.1277399821.101', '1.2.840.113619.2.452.3.4137937730.292.1577705619.326', '1.2.840.113619.2.327.3.4137921600.777.1446791046.827', '1.2.840.113970.3.57.1.2090058.640972004.20100221.1112206', '1.2.840.113970.3.57.1.3697867.688014421.20100502.1095814', '1.2.840.113970.3.57.1.4184890.744374984.20100727.1102449', '1.2.840.113970.3.57.1.4474984.778257809.20100916.1', '1.2.840.113970.3.57.1.4580173.739977092.20100930.1101118', '1.2.840.113970.3.57.1.4767311.809154599.20101029.1223326', '1.2.840.113970.3.57.1.47734138.20130507.1102958', '1.2.840.113970.3.57.1.5198629.847511491.20101222.1115204', '1.2.840.113970.3.57.1.53528987.20160829.1092825', '1.2.840.113970.3.57.1.55414302.20170810.1150932', '1.2.840.113970.3.57.1.58854015.20190305.1124432', '1.2.840.113970.3.57.1.62723936.20201102.1122439', '1.2.840.113970.3.57.1.8483035.20110615.1105311', '1.2.840.113619.2.284.3.4137913664.33.1368186291.228', '1.2.840.113619.2.278.3.4137880131.89.1468934834.80', '1.2.840.113619.2.55.3.4137898048.450.1264404041.826', '1.2.840.113619.2.55.3.4137898048.658.1262581298.908', '1.2.840.113619.2.55.3.4137937730.201.1283370977.950', '1.2.840.113619.2.55.3.4137913920.378.1268400299.589', '1.2.840.113619.2.278.3.4137906246.531.1468000584.70', '1.2.840.113619.2.340.3.4137913664.890.1519920186.497', '1.2.840.113619.2.278.3.4137913920.339.1386073031.685', '1.2.840.113619.2.55.3.4137913664.930.1327929181.89', '1.2.840.113619.2.55.3.4137898048.456.1263960016.452', '1.2.840.113619.2.55.3.4137898048.573.1265688084.446', '1.2.840.113619.2.55.3.4137913920.859.1306152382.176', '1.2.840.113619.2.55.3.4137913920.49.1285760775.796', '1.2.840.113970.3.57.1.3535246.672672494.20100408.1160013', '1.2.840.113970.3.57.1.3881610.709332709.20100603.1185025', '1.2.840.113970.3.57.1.4031281.727355108.20100630.1202446', 
'1.2.840.113970.3.57.1.4603293.790535781.20101004.1', '1.2.840.113970.3.57.1.47158524.20121128.1144305', '1.2.840.113970.3.57.1.48785022.20131229.1122812', '1.2.840.113970.3.57.1.5050436.831736828.20101201.1145817', '1.2.840.113970.3.57.1.53148032.20160617.1', '1.2.840.113970.3.57.1.56566535.20180228.1132735', '1.2.840.113970.3.57.1.57378602.20180713.1133233', '1.2.840.113970.3.57.1.59456663.20190603.1094141', '1.2.840.113970.3.57.1.64169035.20210527.1142142', '1.2.840.113970.3.57.1.18312951.20111028.1130906', '1.2.840.113619.2.55.1.1762890133.2241.1265487737.860', '1.2.840.113970.3.57.1.3650772.681803841.20100422.1160130', '1.2.840.113970.3.57.1.4363092.765584160.20100827.1204844', '1.2.840.113970.3.57.1.10968887.20110727.1141411', '1.2.840.113970.3.57.1.3981576.721961237.20100623.1', '1.2.840.113970.3.57.1.28751236.20120209.1182132', '1.2.840.113619.2.55.3.4137913920.96.1277898152.206', '1.2.840.113619.2.278.3.4137906246.578.1438645489.621', '1.2.840.113970.3.57.1.41221444.20120422.1140346', '1.2.840.113619.2.327.3.4137937730.712.1485897973.485', '1.2.840.113970.3.57.1.4376710.767124611.20100830.1164402', '1.2.840.113619.2.327.3.4137921600.683.1373108320.118', '1.2.840.113970.3.57.1.4347176.763712103.20100825.1120652', '1.2.840.113619.2.55.3.4137913664.893.1304508347.450', '1.2.840.113970.3.57.1.4317264.760070005.20100819.1185108', '1.2.840.113970.3.57.1.3801003.701358650.20100521.1225024', '1.2.840.113619.2.55.3.4137913920.314.1290601390.805', '1.2.840.113970.3.57.1.3930271.714756367.20100611.1201120', '1.2.840.113619.2.452.3.4137914282.359.1581059411.51', '1.2.840.113970.3.57.1.3606510.676389442.20100414.1153128', '1.2.840.113619.2.452.3.4137937730.626.1613672833.22', '1.2.840.113970.3.57.1.5144643.842690502.20101216.1', '1.2.840.113970.3.57.1.5209269.848820319.20101224.1135258', '1.2.840.113970.3.57.1.5209289.848822640.20101224.1140532', '1.2.840.113970.3.57.1.52319692.20160106.1155357', '1.2.840.113970.3.57.1.53426441.20160809.1165104', '1.2.840.113970.3.57.1.56759505.20180402.1', '1.2.840.113970.3.57.1.56876931.20180420.1111722', '1.2.840.113970.3.57.1.57175539.20180608.1151445', '1.2.840.113970.3.57.1.57442215.20180724.1155158', '1.2.840.113970.3.57.1.63592362.20210309.1151914', '1.2.840.113970.3.57.1.63962688.20210429.1085910', '1.2.840.113619.2.55.3.4137937730.188.1272884760.195', '1.2.840.113619.2.55.3.4137937730.843.1268670391.855', '1.2.840.113619.2.55.1.1762890133.2211.1263495204.338', '1.2.840.113619.2.55.3.4137913920.35.1287056772.543', '1.2.840.113619.2.55.3.4137937730.263.1268327252.380', '1.2.840.113970.3.57.1.4289530.755926672.20100813.1124157', '1.2.840.113970.3.57.1.4063977.731696453.20100707.1230008', '1.2.840.113619.2.55.3.4137898048.323.1263128734.513', '1.2.840.113619.2.25.4.2436394.1351275462.47', '1.2.840.113619.2.55.3.4137937730.383.1346756337.198', '1.2.840.113619.2.452.3.4137937730.880.1597833614.392', '1.2.840.113970.3.57.1.4539852.778443510.20100924.1092158', '1.2.840.113970.3.57.1.47295241.20130108.1095420', '1.2.840.113970.3.57.1.3618662.677151174.20100416.1093839', '1.2.840.113619.2.55.3.4137913664.931.1332848521.360', '1.2.840.113619.2.55.3.4137913920.667.1286544113.306', '1.2.840.113619.2.55.3.4137937730.147.1278416606.583', '1.2.840.113970.3.57.1.49759160.20140730.1150210', '1.2.840.113970.3.57.1.4304392.758425809.20100817.1152934', '1.2.840.113970.3.57.1.5121397.840461285.20101213.1140806', '1.2.840.113970.3.57.1.58163955.20181116.1133327', '1.2.840.113970.3.57.1.61473486.20200329.1094314', '1.2.840.113970.3.57.1.4593289.789341784.20101001.1210500', 
'1.2.840.113619.2.284.3.4137913664.228.1343342467.42', '1.2.840.113970.3.57.1.4826969.814282415.20101106.1', '1.2.840.113970.3.57.1.3935354.715563740.20100613.1115925', '1.2.840.113970.3.57.1.49261025.20140414.1', '1.2.840.113619.2.452.3.4137919146.647.1598351842.945', '1.2.840.113970.3.57.1.4773560.809560722.20101030.1175440', '1.2.840.113970.3.57.1.49670558.20140711.1091510', '1.2.840.113970.3.57.1.49408784.20140514.1151454', '1.2.840.113970.3.57.1.48169266.20130814.1180728', '1.2.840.113619.2.327.3.4137921600.444.1461646005.589', '1.2.840.113970.3.57.1.4169740.742624341.20100724.1104442', '1.2.840.113970.3.57.1.49246916.20140410.1112156', '1.2.840.113619.2.55.3.4137913920.134.1285676575.643', '1.2.840.113970.3.57.1.3737498.693550900.20100510.1224704', '1.2.840.113619.2.55.3.4137937730.357.1273688089.806', '1.2.840.113970.3.57.1.55140765.20170621.1143629', '1.2.840.113970.3.57.1.5574721.883536255.20110210.1', '1.2.840.113970.3.57.1.58917438.20190314.1130243', '1.2.840.113970.3.57.1.60344320.20191014.1093349', '1.2.840.113970.3.57.1.64865302.20210903.1102916', '1.2.840.113970.3.57.1.53951437.20161115.1', '1.2.840.113970.3.57.1.12847237.20110820.1164653', '1.2.840.113970.3.57.1.4184010.744289415.20100727.1', '1.2.840.113970.3.57.1.4317421.760163336.20100819.1211321', '1.2.840.113619.2.278.3.4137913920.64.1347622042.682', '1.2.840.113619.2.327.3.4137921600.37.1376025372.795', '1.2.840.113970.3.57.1.4875594.818249654.20101111.1160125', '1.2.840.113970.3.57.1.5004044.828173307.20101126.1', '1.2.840.113970.3.57.1.60366936.20191016.1141224', '1.2.840.113970.3.57.1.4346830.757692441.20100825.1110826', '1.2.840.113970.3.57.1.2010120.633457675.20100210.1082137', '1.2.840.113619.2.55.3.4137898048.323.1263128740.296', '1.2.840.113619.2.55.1.1762890133.2242.1271248808.986', '1.2.840.113619.2.55.3.4137898048.663.1262842370.331', '1.2.840.113970.3.57.1.5122579.840755650.20101213.1210058', '1.2.840.113970.3.57.1.62320559.20200904.1123949', '1.2.840.113619.2.55.3.4137898048.450.1264404042.26', '1.2.840.113970.3.57.1.50969675.20150407.1160416', '1.2.840.113619.2.55.3.4137937730.95.1274890075.331'] + + +def dicom_series_to_nifti(input_path, output_file, reorient_nifti, pipeline_name=None): + # split input path to get study name + print("Input path") + print(str(input_path)) + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(str(input_path)) + + expected_size = [512, 512, 1] + valid_dicom_names = [] + + for dicom_name in dicom_names: + ds = pydicom.dcmread(dicom_name) + rows, cols = ds.Rows, ds.Columns + + if rows == expected_size[0] and cols == expected_size[1]: + valid_dicom_names.append(dicom_name) + else: + print(f"Skipping {dicom_name} due to size mismatch: {rows}x{cols}") + + if not valid_dicom_names: + raise RuntimeError("No valid DICOM files found with the expected size.") + + reader.SetFileNames(valid_dicom_names) + image = reader.Execute() + sitk.WriteImage(image, output_file) + + ds = series_selector(valid_dicom_names[0], pipeline_name=pipeline_name) + return ds diff --git a/comp2comp/io/io.py b/comp2comp/io/io.py index 9870e41d..843a1640 100644 --- a/comp2comp/io/io.py +++ b/comp2comp/io/io.py @@ -95,8 +95,10 @@ def __call__(self, inference_pipeline): # if self.input_path is a folder if self.input_path.is_dir(): # store a dcm object for retrieving dicom tags - dcm_files = [d for d in os.listdir(self.input_path) if d.endswith('.dcm')] - inference_pipeline.dcm = pydicom.read_file(os.path.join(self.input_path, dcm_files[0])) + dcm_files = [d for d in 
os.listdir(self.input_path) if d.endswith(".dcm")] + inference_pipeline.dcm = pydicom.read_file( + os.path.join(self.input_path, dcm_files[0]) + ) ds = dicom_series_to_nifti( self.input_path, diff --git a/comp2comp/io/io_utils.py b/comp2comp/io/io_utils.py index d062a598..4d154bd8 100644 --- a/comp2comp/io/io_utils.py +++ b/comp2comp/io/io_utils.py @@ -58,11 +58,17 @@ def get_dicom_or_nifti_paths_and_num(path): with open(path, "r") as f: for dicom_folder_path in f: dicom_folder_path = dicom_folder_path.strip() - if dicom_folder_path.endswith(".nii") or dicom_folder_path.endswith(".nii.gz"): - dicom_nifti_paths.append( (dicom_folder_path, getNumSlicesNifti(dicom_folder_path))) + if dicom_folder_path.endswith(".nii") or dicom_folder_path.endswith( + ".nii.gz" + ): + dicom_nifti_paths.append( + (dicom_folder_path, getNumSlicesNifti(dicom_folder_path)) + ) else: - dicom_nifti_paths.append( (dicom_folder_path, len(os.listdir(dicom_folder_path)))) - else: + dicom_nifti_paths.append( + (dicom_folder_path, len(os.listdir(dicom_folder_path))) + ) + else: for root, dirs, files in os.walk(path): if len(files) > 0: # if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files): diff --git a/comp2comp/muscle_adipose_tissue/fda_muscle_adipose_tissue.py b/comp2comp/muscle_adipose_tissue/fda_muscle_adipose_tissue.py new file mode 100644 index 00000000..45dd8101 --- /dev/null +++ b/comp2comp/muscle_adipose_tissue/fda_muscle_adipose_tissue.py @@ -0,0 +1,636 @@ +import os +import zipfile +from pathlib import Path +from time import perf_counter +from typing import List, Union + +import cv2 +import h5py +import nibabel as nib +import numpy as np +import pandas as pd +import wget +from keras import backend as K +from tqdm import tqdm + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.metrics.metrics import CrossSectionalArea, HounsfieldUnits +from comp2comp.models.models import Models + +# from comp2comp.muscle_adipose_tissue.data import Dataset, predict + +NORMATIVE_VALUES = { + "12-15": {"mean": 0.9708984375000025, "number": 128, "std": 0.14374630102575722}, + "16-19": {"mean": 1.043302985074628, "number": 134, "std": 0.1171623603196859}, + "20-29": {"mean": 1.0714148148148168, "number": 162, "std": 0.12245757380508455}, + "30-39": {"mean": 1.0732317460317455, "number": 126, "std": 0.11887298110191502}, + "40-49": {"mean": 1.0724762430939225, "number": 181, "std": 0.1565666675667262}, + "50-59": {"mean": 1.0113555555555536, "number": 162, "std": 0.15325632970875053}, + "60-69": {"mean": 0.9780291390728507, "number": 151, "std": 0.15834423166478903}, + "8-11": {"mean": 0.7280215686274502, "number": 102, "std": 0.1002736605470211}, +} + +CT_DENSITY_TO_BMD_DENSITY_SLOPE = 0.002533809257151385 +CT_DENSITY_TO_BMD_DENSITY_INTERCEPT = 0.2298020483820905 + + +def CALCULATE_DXA_SCORES(spine_bmd: float, age: int) -> dict: + # Map the patient's age onto the corresponding normative bin in NORMATIVE_VALUES. + if age < 12: + age_group = "8-11" + elif age < 16: + age_group = "12-15" + elif age < 20: + age_group = "16-19" + elif age > 69: + age_group = "60-69" + else: + age_group = f"{(age // 10) * 10}-{((age // 10) * 10) + 9}" + + t_score_mean = ( + NORMATIVE_VALUES["20-29"]["mean"] + NORMATIVE_VALUES["30-39"]["mean"] + ) / 2 + t_score_std = ( + NORMATIVE_VALUES["20-29"]["std"] + NORMATIVE_VALUES["30-39"]["std"] + ) / 2 + t_score = (spine_bmd - t_score_mean) / t_score_std + + z_score_mean = NORMATIVE_VALUES[age_group]["mean"] + z_score_std = NORMATIVE_VALUES[age_group]["std"] + z_score = (spine_bmd - z_score_mean) / z_score_std + + return {"T-score": t_score, "Z-score": z_score} + + +class MuscleAdiposeTissueSegmentation(InferenceClass): + 
"""Muscle adipose tissue segmentation class.""" + + def __init__(self, batch_size: int, model_name: str, model_dir: str = None): + super().__init__() + self.batch_size = batch_size + self.model_name = model_name + self.model_type = Models.model_from_name(model_name) + + def forward_pass_2d(self, files): + dataset = Dataset(files, windows=self.model_type.windows) + num_workers = 1 + + print("Computing segmentation masks using {}...".format(self.model_name)) + start_time = perf_counter() + _, preds, results = predict( + self.model, + dataset, + num_workers=num_workers, + use_multiprocessing=num_workers > 1, + batch_size=self.batch_size, + ) + K.clear_session() + print( + f"Completed {len(files)} segmentations in {(perf_counter() - start_time):.2f} seconds." + ) + for i in range(len(results)): + results[i]["preds"] = preds[i] + return results + + def download_muscle_adipose_tissue_model(self, model_dir: Union[str, Path]): + download_dir = Path( + os.path.join( + model_dir, + ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1", + ) + ) + all_path = download_dir / "all" + if not os.path.exists(all_path): + download_dir.mkdir(parents=True, exist_ok=True) + wget.download( + "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/all.zip", + out=os.path.join(download_dir, "all.zip"), + ) + with zipfile.ZipFile(os.path.join(download_dir, "all.zip"), "r") as zip_ref: + zip_ref.extractall(download_dir) + os.remove(os.path.join(download_dir, "all.zip")) + wget.download( + "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/plans.pkl", + out=os.path.join(download_dir, "plans.pkl"), + ) + print("Muscle and adipose tissue model downloaded.") + else: + print("Muscle and adipose tissue model already downloaded.") + + def __call__(self, inference_pipeline): + inference_pipeline.muscle_adipose_tissue_model_type = self.model_type + inference_pipeline.muscle_adipose_tissue_model_name = self.model_name + + if self.model_name == "stanford_v0.0.2": + self.download_muscle_adipose_tissue_model(inference_pipeline.model_dir) + nifti_path = os.path.join( + inference_pipeline.output_dir, + "segmentations", + "converted_dcm_multilevel.nii.gz", + ) + output_path = os.path.join( + inference_pipeline.output_dir, + "segmentations", + "multilevel_muscle_fat_seg.nii.gz", + ) + + from nnunet.inference import predict + + predict.predict_cases( + model=os.path.join( + inference_pipeline.model_dir, + ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1", + ), + list_of_lists=[[nifti_path]], + output_filenames=[output_path], + folds="all", + save_npz=False, + num_threads_preprocessing=8, + num_threads_nifti_save=8, + segs_from_prev_stage=None, + do_tta=False, + mixed_precision=True, + overwrite_existing=False, + all_in_gpu=False, + step_size=0.5, + checkpoint_name="model_final_checkpoint", + segmentation_export_kwargs=None, + ) + + image_nib = nib.load(nifti_path) + image_nib = nib.as_closest_canonical(image_nib) + image = image_nib.get_fdata() + pred = nib.load(output_path) + pred = nib.as_closest_canonical(pred) + pred = pred.get_fdata() + + images = [image[:, :, i] for i in range(image.shape[-1])] + preds = [pred[:, :, i] for i in range(pred.shape[-1])] + + # flip both axes and transpose + images = [np.flip(np.flip(image, axis=0), axis=1).T for image in images] + preds = [np.flip(np.flip(pred, axis=0), axis=1).T for pred in preds] + + spacings = [ + image_nib.header.get_zooms()[0:2] for 
i in range(image.shape[-1]) + ] + + categories = self.model_type.categories + + # for each image in images, convert to one hot encoding + masks = [] + for pred in preds: + mask = np.zeros((pred.shape[0], pred.shape[1], 4)) + for i, category in enumerate(categories): + mask[:, :, i] = pred == categories[category] + mask = mask.astype(np.uint8) + masks.append(mask) + return {"images": images, "preds": masks, "spacings": spacings} + + else: + dicom_file_paths = inference_pipeline.dicom_file_paths + # if dicom_file_names not an attribute of inference_pipeline, add it + if not hasattr(inference_pipeline, "dicom_file_names"): + inference_pipeline.dicom_file_names = [ + dicom_file_path.stem for dicom_file_path in dicom_file_paths + ] + self.model = self.model_type.load_model(inference_pipeline.model_dir) + + results = self.forward_pass_2d(dicom_file_paths) + images = [] + for result in results: + images.append(result["image"]) + preds = [] + for result in results: + preds.append(result["preds"]) + spacings = [] + for result in results: + spacings.append(result["spacing"]) + + return {"images": images, "preds": preds, "spacings": spacings} + + +class MuscleAdiposeTissuePostProcessing(InferenceClass): + """Post-process muscle and adipose tissue segmentation.""" + + def __init__(self): + super().__init__() + + def preds_to_mask(self, preds): + """Convert model predictions to a mask. + + Args: + preds (np.ndarray): Model predictions. + + Returns: + np.ndarray: Mask. + """ + if self.use_softmax: + # softmax + labels = np.zeros_like(preds, dtype=np.uint8) + l_argmax = np.argmax(preds, axis=-1) + for c in range(labels.shape[-1]): + labels[l_argmax == c, c] = 1 + return labels.astype(bool) + else: + # sigmoid + return preds >= 0.5 + + def __call__(self, inference_pipeline, images, preds, spacings): + """Post-process muscle and adipose tissue segmentation.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.use_softmax = self.model_type.use_softmax + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + return self.post_process(images, preds, spacings) + + def remove_small_objects(self, mask, min_size=10): + mask = mask.astype(np.uint8) + components, output, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + sizes = stats[1:, -1] + mask = np.zeros((output.shape)) + for i in range(0, components - 1): + if sizes[i] >= min_size: + mask[output == i + 1] = 1 + return mask + + def post_process( + self, + images, + preds, + spacings, + ): + categories = self.model_type.categories + + start_time = perf_counter() + + if self.model_name == "stanford_v0.0.2": + masks = preds + else: + masks = [self.preds_to_mask(p) for p in preds] + + for i, _ in enumerate(masks): + # Keep only channels from the model_type categories dict + masks[i] = np.squeeze(masks[i]) + + masks = self.fill_holes(masks) + + cats = list(categories.keys()) + + # Muscle voxels with HU between -190 and -30 are reclassified as IMAT below. + file_idx = 0 + for mask, image in tqdm(zip(masks, images), total=len(masks)): + muscle_mask = mask[..., cats.index("muscle")] + imat_mask = mask[..., cats.index("imat")] + imat_mask = ( + np.logical_and( + (image * muscle_mask) <= -30, (image * muscle_mask) >= -190 + ) + ).astype(int) + imat_mask = self.remove_small_objects(imat_mask) + mask[..., cats.index("imat")] += imat_mask + mask[..., cats.index("muscle")][imat_mask == 1] = 0 + masks[file_idx] = mask + images[file_idx] = image + file_idx += 1 + + print( + f"Completed post-processing in {(perf_counter() - start_time):.2f} seconds." 
+ ) + + return {"images": images, "masks": masks, "spacings": spacings} + + # function that fills in holes in a segmentation mask + def _fill_holes(self, mask: np.ndarray, mask_id: int): + """Fill in holes in a segmentation mask. + + Args: + mask (ndarray): NxHxW + mask_id (int): Label of the mask. + + Returns: + ndarray: Filled mask. + """ + int_mask = ((1 - mask) > 0.5).astype(np.int8) + components, output, stats, _ = cv2.connectedComponentsWithStats( + int_mask, connectivity=8 + ) + sizes = stats[1:, -1] + components = components - 1 + # Larger threshold for SAT + # TODO make this configurable / parameter + if mask_id == 2: + min_size = 50 + # min_size = 0 + else: + min_size = 5 + # min_size = 0 + img_out = np.ones_like(mask) + for i in range(0, components): + if sizes[i] > min_size: + img_out[output == i + 1] = 0 + return img_out + + def fill_holes(self, ys: List): + """Take an array of size NxHxWxC and for each channel fill in holes. + + Args: + ys (list): List of segmentation masks. + """ + segs = [] + for n in range(len(ys)): + ys_out = [ + self._fill_holes(ys[n][..., i], i) for i in range(ys[n].shape[-1]) + ] + segs.append(np.stack(ys_out, axis=2).astype(float)) + + return segs + + +class MuscleAdiposeTissueComputeMetrics(InferenceClass): + """Compute muscle and adipose tissue metrics.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, images, masks, spacings): + """Compute muscle and adipose tissue metrics.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + metrics = self.compute_metrics_all(images, masks, spacings) + return metrics + + def compute_metrics_all(self, images, masks, spacings): + """Compute metrics for all images and masks. + + Args: + images (List[np.ndarray]): Images. + masks (List[np.ndarray]): Masks. + + Returns: + Dict: Results. 
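+                The "results" value is a list of per-slice dicts keyed by tissue category, each holding the mask, Hounsfield unit, and cross-sectional area values.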
+ """ + results = [] + for image, mask, spacing in zip(images, masks, spacings): + results.append(self.compute_metrics(image, mask, spacing)) + # return {"images": images, "results": results} + return {"results": results} + + def compute_metrics(self, x, mask, spacing): + """Compute results for a given segmentation.""" + categories = self.model_type.categories + + hu = HounsfieldUnits() + csa_units = "cm^2" if spacing else "" + csa = CrossSectionalArea(csa_units) + + hu_vals = hu(mask, x, category_dim=-1) + csa_vals = csa(mask=mask, spacing=spacing, category_dim=-1) + + # check if any values are nan and replace with 0 + hu_vals = np.nan_to_num(hu_vals) + csa_vals = np.nan_to_num(csa_vals) + + assert mask.shape[-1] == len( + categories + ), "{} categories found in mask, " "but only {} categories specified".format( + mask.shape[-1], len(categories) + ) + + results = { + cat: { + "mask": mask[..., idx], + hu.name(): hu_vals[idx], + csa.name(): csa_vals[idx], + } + for idx, cat in enumerate(categories.keys()) + } + return results + + +class MuscleAdiposeTissueH5Saver(InferenceClass): + """Save results to an HDF5 file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, results): + """Save results to an HDF5 file.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + self.output_dir = inference_pipeline.output_dir + self.h5_output_dir = os.path.join(self.output_dir, "segmentations") + os.makedirs(self.h5_output_dir, exist_ok=True) + self.dicom_file_paths = inference_pipeline.dicom_file_paths + self.dicom_file_names = inference_pipeline.dicom_file_names + self.save_results(results) + return {"results": results} + + def save_results(self, results): + """Save results to an HDF5 file.""" + categories = self.model_type.categories + cats = list(categories.keys()) + + for i, result in enumerate(results): + file_name = self.dicom_file_names[i] + with h5py.File( + os.path.join(self.h5_output_dir, file_name + ".h5"), "w" + ) as f: + for cat in cats: + mask = result[cat]["mask"] + f.create_dataset(name=cat, data=np.array(mask, dtype=np.uint8)) + + +class MuscleAdiposeTissueMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, results): + """Save metrics to a CSV file.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + self.output_dir = inference_pipeline.output_dir + self.age = inference_pipeline.age + self.csv_output_dir = os.path.join(self.output_dir, "metrics") + self.segmentation_output_dir = os.path.join(self.output_dir, "segmentations") + os.makedirs(self.csv_output_dir, exist_ok=True) + self.dicom_file_paths = inference_pipeline.dicom_file_paths + self.dicom_file_names = inference_pipeline.dicom_file_names + self.save_results(results) + return {} + + def save_results(self, results): + """Save results to a CSV file.""" + self.model_type.categories + df = pd.DataFrame( + columns=[ + "Level", + "Index", + "Muscle HU", + "Muscle CSA (cm^2)", + "SAT HU", + "SAT CSA (cm^2)", + "VAT HU", + "VAT CSA (cm^2)", + "IMAT HU", + "IMAT CSA (cm^2)", + ] + ) + + for i, result in enumerate(results): + row = [] + row.append(self.dicom_file_names[i]) + row.append(self.dicom_file_paths[i]) + for cat in result: + row.append(result[cat]["Hounsfield Unit"]) + 
row.append(result[cat]["Cross-sectional Area (cm^2)"]) + df.loc[i] = row + df = df.iloc[::-1] + df.to_csv( + os.path.join(self.csv_output_dir, "muscle_adipose_tissue_metrics.csv"), + index=False, + ) + + metrics_data = pd.read_csv( + os.path.join(self.csv_output_dir, "muscle_adipose_tissue_metrics.csv") + ) + + # print(metrics_data) + + levels_to_include = ["L1", "L2", "L3", "L4"] + + # Sanity check: VAT HU values must be present for L1-L4 before calibration. + exam_per_level_densities = {} + try: + for level in levels_to_include: + exam_per_level_densities[level] = metrics_data.loc[ + metrics_data["Level"] == level + ]["VAT HU"].values[0] + except Exception: + print("No VAT") + return + + image_path = os.path.join( + self.segmentation_output_dir, "converted_dcm_multilevel.nii.gz" + ) + seg_path = os.path.join( + self.segmentation_output_dir, "multilevel_muscle_fat_seg.nii.gz" + ) + + image = nib.load(image_path) + image = nib.as_closest_canonical(image) + pixel_spacing = image.header["pixdim"][1:4] + image = image.get_fdata() + seg = nib.load(seg_path) + seg = nib.as_closest_canonical(seg) + seg = seg.get_fdata() + sat_seg = (seg == 1).astype(int) + vat_seg = (seg == 2).astype(int) + + try: + position_anterior_most_sat_pixel = np.max(np.where(sat_seg == 1)[1]) + except Exception: + print("No SAT") + return + # Air ROI: a 100 mm (L-R) x 20 mm (P-A) box placed 20 mm anterior to the most anterior SAT voxel, i.e. in air outside the patient. + air_seg = np.zeros_like(sat_seg) + left_right_pixels = 50 / pixel_spacing[0] + middle_left_right_pixel = image.shape[0] / 2 + anterior_posterior_pixels = 20 / pixel_spacing[1] + air_seg[ + int(middle_left_right_pixel - left_right_pixels) : int( + middle_left_right_pixel + left_right_pixels + ), + int(position_anterior_most_sat_pixel + anterior_posterior_pixels) : int( + position_anterior_most_sat_pixel + (2 * anterior_posterior_pixels) + ), + :, + ] = 1 + + air_measured = np.mean(image[air_seg == 1]) + # Abort if the measured air ROI is not between -1050 and -950 HU. + if air_measured < -1050 or air_measured > -950: + print("Error, air ROI is not between -1050 and -950 HU, aborting scan.") + print(f"Air ROI: {air_measured:.1f} HU") + return + + vat_measured = np.mean(image[vat_seg == 1]) + print(f"Air ROI: {air_measured:.1f} HU") + print(f"VAT ROI: {vat_measured:.1f} HU") + + # Two-point linear HU calibration: map the measured air and VAT means onto reference values of -1000 HU and -95 HU. + hu_vat = -95 + hu_air = -1000 + slope = (hu_vat - hu_air) / (vat_measured - air_measured) + intercept = hu_air - slope * air_measured + # calibrated_value = slope * value_measured + intercept + + spine_metrics = pd.read_csv( + os.path.join(self.csv_output_dir, "spine_metrics.csv") + ) + spine_metrics["Calibrated Seg HU"] = spine_metrics["Seg HU"].apply( + lambda x: slope * x + intercept + ) + spine_metrics.to_csv( + os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False + ) + + self.PredictDXAScores() + + def PredictDXAScores(self): + metrics_data = pd.read_csv( + os.path.join(self.csv_output_dir, "muscle_adipose_tissue_metrics.csv") + ) + spine_metrics = pd.read_csv( + os.path.join(self.csv_output_dir, "spine_metrics.csv") + ) + + levels_to_include = ["L1", "L2", "L3", "L4"] + average_calibrated_hu = spine_metrics.loc[ + spine_metrics["Level"].isin(levels_to_include) + ]["Calibrated Seg HU"].mean() + + binary_prediction = average_calibrated_hu < 300 + dxa_density_prediction = ( + CT_DENSITY_TO_BMD_DENSITY_SLOPE * average_calibrated_hu + + CT_DENSITY_TO_BMD_DENSITY_INTERCEPT + ) + normative_scores = CALCULATE_DXA_SCORES(dxa_density_prediction, self.age) + dxa_t_score_prediction = normative_scores["T-score"] + dxa_z_score_prediction = normative_scores["Z-score"] + + dxa_scores = pd.DataFrame( + columns=[ + "Mean L1-L4 
Calibrated HU", + "Binary Prediction", + "Density Prediction", + "T-score Prediction", + "Z-score Prediction", + "Age", + ] + ) + dxa_scores.loc[0] = [ + average_calibrated_hu, + int(binary_prediction), + dxa_density_prediction, + dxa_t_score_prediction, + dxa_z_score_prediction, + self.age, + ] + dxa_scores.to_csv( + os.path.join(self.csv_output_dir, "dxa_predictions.csv"), index=False + ) + + print("\n" * 3) + print("#" * 80) + print("BMD Report:") + print( + "BMD score prediction: {} (T-score {} -1)\n".format( + "ABNORMAL" if binary_prediction else "NORMAL", + "<" if binary_prediction else "≥", + ) + ) + print("Predicted Z and T-score (NOT FDA approved):") + print(f"Predicted Z-score: {dxa_z_score_prediction:.1f}") + print(f"Predicted T-score: {dxa_t_score_prediction:.1f}") + + print("#" * 80) + print("\n" * 3) + + return spine_metrics diff --git a/comp2comp/spine/fda_spine.py b/comp2comp/spine/fda_spine.py new file mode 100644 index 00000000..a52f8132 --- /dev/null +++ b/comp2comp/spine/fda_spine.py @@ -0,0 +1,391 @@ +""" +@author: louisblankemeier +""" + +import os +import zipfile +from pathlib import Path +from time import time +from typing import Union + +import nibabel as nib +import numpy as np +import pandas as pd +import wget +from totalsegmentatorv2.python_api import totalsegmentator + +from comp2comp.inference_class_base import InferenceClass + +# from comp2comp.visualization.dicom import to_dicom +from comp2comp.models.fda_models import Models +from comp2comp.spine import fda_spine_utils + +# from totalsegmentator.libs import ( +# download_pretrained_weights, +# nostdout, +# setup_nnunet, +# ) + + +class SpineSegmentation(InferenceClass): + """Spine segmentation.""" + + def __init__(self, model_name, save=True): + super().__init__() + self.model_name = model_name + self.save_segmentations = save + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + # seg, mv = self.spine_seg( + # os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + # self.output_dir_segmentations + "spine.nii.gz", + # inference_pipeline.model_dir, + # ) + os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir + + seg = totalsegmentator( + input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), + task_ids=[292], + ml=True, + nr_thr_resamp=1, + nr_thr_saving=6, + fast=False, + nora_tag="None", + preview=False, + task="total", + # roi_subset=[ + # "vertebrae_T12", + # "vertebrae_L1", + # "vertebrae_L2", + # "vertebrae_L3", + # "vertebrae_L4", + # "vertebrae_L5", + # ], + roi_subset=None, + statistics=False, + radiomics=False, + crop_path=None, + body_seg=False, + force_split=False, + output_type="nifti", + quiet=False, + verbose=False, + test=0, + skip_saving=True, + device="gpu", + license_number=None, + statistics_exclude_masks_at_border=True, + no_derived_masks=False, + v1_order=False, + ) + mv = nib.load( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz") + ) + + # save the seg + nib.save( + seg, + os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"), + ) + + # inference_pipeline.segmentation = nib.load( + # 
os.path.join(self.output_dir_segmentations, "segmentation.nii") + # ) + inference_pipeline.segmentation = seg + inference_pipeline.medical_volume = mv + inference_pipeline.save_segmentations = self.save_segmentations + return {} + + def setup_nnunet_c2c(self, model_dir: Union[str, Path]): + """Adapted from TotalSegmentator.""" + + model_dir = Path(model_dir) + config_dir = model_dir / Path("." + self.model_name) + (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir( + exist_ok=True, parents=True + ) + (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True) + weights_dir = config_dir / "nnunet/results" + self.weights_dir = weights_dir + + os.environ["nnUNet_raw_data_base"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["nnUNet_preprocessed"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["RESULTS_FOLDER"] = str(weights_dir) + + def download_spine_model(self, model_dir: Union[str, Path]): + download_dir = Path( + os.path.join( + self.weights_dir, + "nnUNet/3d_fullres/Task252_Spine/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1", + ) + ) + fold_0_path = download_dir / "fold_0" + if not os.path.exists(fold_0_path): + download_dir.mkdir(parents=True, exist_ok=True) + wget.download( + "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/fold_0.zip", + out=os.path.join(download_dir, "fold_0.zip"), + ) + with zipfile.ZipFile( + os.path.join(download_dir, "fold_0.zip"), "r" + ) as zip_ref: + zip_ref.extractall(download_dir) + os.remove(os.path.join(download_dir, "fold_0.zip")) + wget.download( + "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/plans.pkl", + out=os.path.join(download_dir, "plans.pkl"), + ) + print("Spine model downloaded.") + else: + print("Spine model already downloaded.") + + def spine_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run spine segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. 
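+            model_dir: Directory containing the nnU-Net spine model weights.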
+ """ + + print("Segmenting spine...") + st = time() + os.environ["SCRATCH"] = self.model_dir + os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [252] + + if self.model_name == "ts_spine": + setup_nnunet() + download_pretrained_weights(task_id[0]) + elif self.model_name == "stanford_spine_v0.0.1": + self.setup_nnunet_c2c(model_dir) + self.download_spine_model(model_dir) + else: + raise ValueError("Invalid model name.") + + if not self.save_segmentations: + output_path = None + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + img, seg = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag="None", + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=False, + test=0, + ) + end = time() + + # Log total time for spine segmentation + print(f"Total time for spine segmentation: {end-st:.2f}s.") + + if self.model_name == "stanford_spine_v0.0.1": + seg_data = seg.get_fdata() + # subtract 17 from seg values except for 0 + seg_data = np.where(seg_data == 0, 0, seg_data - 17) + seg = nib.Nifti1Image(seg_data, seg.affine, seg.header) + + return seg, img + + +class AxialCropper(InferenceClass): + """Crop the CT image (medical_volume) and segmentation based on user-specified + lower and upper levels of the spine. + """ + + def __init__(self, lower_level: str = "L5", upper_level: str = "L1", save=True): + """ + Args: + lower_level (str, optional): Lower level of the spine. Defaults to "L5". + upper_level (str, optional): Upper level of the spine. Defaults to "L1". + save (bool, optional): Save cropped image and segmentation. Defaults to True. + + Raises: + ValueError: If lower_level or upper_level is not a valid spine level. + """ + super().__init__() + self.lower_level = lower_level + self.upper_level = upper_level + ts_spine_full_model = Models.model_from_name("ts_spine") + categories = ts_spine_full_model.categories + try: + self.lower_level_index = categories[self.lower_level] + self.upper_level_index = categories[self.upper_level] + except KeyError: + raise ValueError("Invalid spine level.") from None + self.save = save + + def __call__(self, inference_pipeline): + """ + First dim goes from L to R. + Second dim goes from P to A. + Third dim goes from I to S. 
+ """ + segmentation = inference_pipeline.segmentation + segmentation_data = segmentation.get_fdata() + try: + upper_level_index = np.where(segmentation_data == self.upper_level_index)[ + 2 + ].max() + except: + upper_level_index = segmentation_data.shape[2] + try: + lower_level_index = np.where(segmentation_data == self.lower_level_index)[ + 2 + ].min() + except: + lower_level_index = 0 + segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index] + inference_pipeline.segmentation = segmentation + + medical_volume = inference_pipeline.medical_volume + medical_volume = medical_volume.slicer[ + :, :, lower_level_index:upper_level_index + ] + inference_pipeline.medical_volume = medical_volume + + if self.save: + nib.save( + segmentation, + os.path.join( + inference_pipeline.output_dir, "segmentations", "spine.nii.gz" + ), + ) + nib.save( + medical_volume, + os.path.join( + inference_pipeline.output_dir, + "segmentations", + "converted_dcm.nii.gz", + ), + ) + return {} + + +class SpineComputeROIs(InferenceClass): + def __init__(self, spine_model): + super().__init__() + self.spine_model_name = spine_model + self.spine_model_type = Models.model_from_name(self.spine_model_name) + + def __call__(self, inference_pipeline): + # Compute ROIs + inference_pipeline.spine_model_type = self.spine_model_type + + (_, rois, segmentation_hus, centroids_3d, _) = fda_spine_utils.compute_rois( + inference_pipeline.segmentation, + inference_pipeline.medical_volume, + self.spine_model_type, + ) + + inference_pipeline.segmentation_hus = segmentation_hus + inference_pipeline.centroids_3d = centroids_3d + inference_pipeline.rois = rois + + return {} + + +class SpineMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + """Save metrics to a CSV file.""" + # self.spine_hus = inference_pipeline.spine_hus + self.seg_hus = inference_pipeline.segmentation_hus + self.output_dir = inference_pipeline.output_dir + # self.bounds = inference_pipeline.bounds + self.csv_output_dir = os.path.join(self.output_dir, "metrics") + if not os.path.exists(self.csv_output_dir): + os.makedirs(self.csv_output_dir, exist_ok=True) + self.save_results() + # if hasattr(inference_pipeline, "dicom_ds"): + # if not os.path.exists(os.path.join(self.output_dir, "dicom_metadata.csv")): + # io_utils.write_dicom_metadata_to_csv( + # inference_pipeline.dicom_ds, + # os.path.join(self.output_dir, "dicom_metadata.csv"), + # ) + + return {} + + def save_results(self): + """Save results to a CSV file.""" + # df = pd.DataFrame(columns=["Level", "ROI HU", "Seg HU"]) + # for i, level in enumerate(self.spine_hus): + # hu = self.spine_hus[level] + # seg_hu = self.seg_hus[level] + # row = [level, hu, seg_hu] + # df.loc[i] = row + # df = df.iloc[::-1] + # df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False) + df = pd.DataFrame( + columns=["Level", "Seg HU"] + ) # , "Lower Bound", "Upper Bound"]) + for i, level in enumerate(self.seg_hus): + # hu = self.spine_hus[level] + seg_hu = self.seg_hus[level] + # bounds = self.bounds[level] + row = [level, seg_hu] # , bounds[0], bounds[1]] + df.loc[i] = row + df = df.iloc[::-1] + df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False) + + +class SpineFindDicoms(InferenceClass): + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + inferior_superior_centers = fda_spine_utils.find_spine_dicoms( + 
inference_pipeline.centroids_3d, + ) + + fda_spine_utils.save_nifti_select_slices( + inference_pipeline.output_dir, inferior_superior_centers + ) + inference_pipeline.dicom_file_paths = [ + str(center) for center in inferior_superior_centers + ] + inference_pipeline.names = list(inference_pipeline.rois.keys()) + inference_pipeline.dicom_file_names = list(inference_pipeline.rois.keys()) + inference_pipeline.inferior_superior_centers = inferior_superior_centers + + return {} diff --git a/comp2comp/spine/fda_spine_utils.py b/comp2comp/spine/fda_spine_utils.py new file mode 100644 index 00000000..df89a1f6 --- /dev/null +++ b/comp2comp/spine/fda_spine_utils.py @@ -0,0 +1,847 @@ +""" +@author: louisblankemeier +""" + +import logging +import math +import os +import time +from typing import Dict, List + +import cv2 +import nibabel as nib +import numpy as np +import scipy +from scipy.ndimage import zoom + +# from comp2comp.spine import spine_visualization + + +def find_spine_dicoms(centroids: Dict): # , path: str, levels): + """Find the dicom files corresponding to the spine T12 - L5 levels.""" + + vertical_positions = [] + for level in centroids: + centroid = centroids[level] + vertical_positions.append(round(centroid[2])) + + # dicom_files = [] + # ipps = [] + # for dicom_path in glob(path + "/*.dcm"): + # ipp = dcmread(dicom_path).ImagePositionPatient + # ipps.append(ipp[2]) + # dicom_files.append(dicom_path) + + # dicom_files = [x for _, x in sorted(zip(ipps, dicom_files))] + # dicom_files = list(np.array(dicom_files)[vertical_positions]) + + # return (dicom_files, levels, vertical_positions) + return vertical_positions + + +def save_nifti_select_slices(output_dir: str, vertical_positions): + nifti_path = os.path.join(output_dir, "segmentations", "converted_dcm.nii.gz") + nifti_in = nib.load(nifti_path) + nifti_np = nifti_in.get_fdata() + nifti_np = nifti_np[:, :, vertical_positions] + nifti_out = nib.Nifti1Image(nifti_np, nifti_in.affine, nifti_in.header) + # save the nifti + nifti_output_path = os.path.join( + output_dir, "segmentations", "converted_dcm_multilevel.nii.gz" + ) + nib.save(nifti_out, nifti_output_path) + + +# Function that takes a numpy array as input, computes the +# sagittal centroid of each label and returns a list of the +# centroids +def compute_centroids(seg: np.ndarray, spine_model_type): + """Compute the centroids of the labels. + + Args: + seg (np.ndarray): Segmentation volume. + spine_model_type (str): Model type. + + Returns: + List[int]: List of centroids. + """ + # take values of spine_model_type.categories dictionary + # and convert to list + centroids = {} + for level in spine_model_type.categories: + label_idx = spine_model_type.categories[level] + try: + pos = compute_centroid(seg, "sagittal", label_idx) + centroids[level] = pos + except Exception: + logging.warning(f"Label {level} not found in segmentation volume.") + return centroids + + +# Function that takes a numpy array as input, as well as a list of centroids, +# takes a slice through the centroid on axis = 1 for each centroid +# and returns a list of the slices +def get_slices(seg: np.ndarray, centroids: Dict, spine_model_type): + """Get the slices corresponding to the centroids. + + Args: + seg (np.ndarray): Segmentation volume. + centroids (List[int]): List of centroids. + spine_model_type (str): Model type. + + Returns: + List[np.ndarray]: List of slices. 
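+        Levels whose binary slice contains 200 or fewer positive pixels are skipped.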
+ """ + seg = seg.astype(np.uint8) + slices = {} + for level in centroids: + label_idx = spine_model_type.categories[level] + binary_seg = (seg[centroids[level], :, :] == label_idx).astype(int) + if ( + np.sum(binary_seg) > 200 + ): # heuristic to make sure enough of the body is showing + slices[level] = binary_seg + return slices + + +# Function that takes a mask and for each deletes the right most +# connected component. Returns the mask with the right most +# connected component deleted +def delete_right_most_connected_component(mask: np.ndarray): + """Delete the right most connected component corresponding to spinous processes. + + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Mask volume. + """ + mask = mask.astype(np.uint8) + _, labels, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8) + right_most_connected_component = np.argmin(centroids[1:, 1]) + 1 + mask[labels == right_most_connected_component] = 0 + return mask + + +# compute center of mass of 2d mask +def compute_center_of_mass(mask: np.ndarray): + """Compute the center of mass of a 2D mask. + + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Center of mass. + """ + mask = mask.astype(np.uint8) + _, _, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8) + center_of_mass = np.mean(centroids[1:, :], axis=0) + return center_of_mass + + +# Function that takes a 3d centroid and retruns a binary mask with a 3d +# roi around the centroid +def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray): + """Compute a 3D ROI from a 3D mask. + + Args: + img (np.ndarray): Image volume. + centroid (np.ndarray): Centroid. + + Returns: + np.ndarray: ROI volume. + """ + roi = np.zeros(img.shape) + + img_np = img.get_fdata() + + pixel_spacing = img.header.get_zooms() + length_i = 5.0 / pixel_spacing[0] + length_j = 5.0 / pixel_spacing[1] + length_k = 5.0 / pixel_spacing[2] + + print( + f"Computing ROI with centroid {centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f} " + f"and pixel spacing " + f"{pixel_spacing[0]:.3f}mm, {pixel_spacing[1]:.3f}mm, {pixel_spacing[2]:.3f}mm..." 
+ ) + + # cubic ROI around centroid + """ + roi[ + int(centroid[0] - length) : int(centroid[0] + length), + int(centroid[1] - length) : int(centroid[1] + length), + int(centroid[2] - length) : int(centroid[2] + length), + ] = 1 + """ + # spherical ROI around centroid + spherical = False + roi = np.zeros(img_np.shape) + + if spherical: + i_lower = math.floor(centroid[0] - length_i) + j_lower = math.floor(centroid[1] - length_j) + k_lower = math.floor(centroid[2] - length_k) + i_lower_idx = 1000 + j_lower_idx = 1000 + k_lower_idx = 1000 + i_upper_idx = 0 + j_upper_idx = 0 + k_upper_idx = 0 + found_pixels = False + for i in range(i_lower, i_lower + 2 * math.ceil(length_i) + 1): + for j in range(j_lower, j_lower + 2 * math.ceil(length_j) + 1): + for k in range(k_lower, k_lower + 2 * math.ceil(length_k) + 1): + if (i - centroid[0]) ** 2 / length_i**2 + ( + j - centroid[1] + ) ** 2 / length_j**2 + ( + k - centroid[2] + ) ** 2 / length_k**2 <= 1: + roi[i, j, k] = 1 + if i < i_lower_idx: + i_lower_idx = i + if j < j_lower_idx: + j_lower_idx = j + if k < k_lower_idx: + k_lower_idx = k + if i > i_upper_idx: + i_upper_idx = i + if j > j_upper_idx: + j_upper_idx = j + if k > k_upper_idx: + k_upper_idx = k + found_pixels = True + if not found_pixels: + print("No pixels in ROI!") + raise ValueError + print( + f"Number of pixels included in i, j, and k directions: {i_upper_idx - i_lower_idx + 1}, " + f"{j_upper_idx - j_lower_idx + 1}, {k_upper_idx - k_lower_idx + 1}" + ) + return roi + else: + roi_start_time = time.time() + + mask = None + inferior_superior_line = seg[int(centroid[0]), int(centroid[1]), :] + # get the center point + updated_z_center = np.mean(np.where(inferior_superior_line == 1)) + lower_z_idx = updated_z_center - ((length_k * 1.5) // 2) + upper_z_idx = updated_z_center + ((length_k * 1.5) // 2) + for idx in range(int(lower_z_idx), int(upper_z_idx) + 1): + # take multiple to increase robustness + posterior_anterior_lines = [ + slice[:, idx], + slice[:, idx + 1], + slice[:, idx - 1], + ] + posterior_anterior_sums = [ + np.sum(posterior_anterior_lines[0]), + np.sum(posterior_anterior_lines[1]), + np.sum(posterior_anterior_lines[2]), + ] + min_idx = np.argmin(posterior_anterior_sums) + + posterior_anterior_line = posterior_anterior_lines[min_idx] + updated_posterior_anterior_center = ( + np.min(np.where(posterior_anterior_line == 1)) + + np.sum(posterior_anterior_line) * 0.58 + ) + posterior_anterior_length = (posterior_anterior_sums[min_idx] * 0.5) // 2 + + left_right_lines = [ + seg[:, int(updated_posterior_anterior_center), idx], + seg[:, int(updated_posterior_anterior_center) + 1, idx], + seg[:, int(updated_posterior_anterior_center) - 1, idx], + ] + + left_right_sums = [ + np.sum(left_right_lines[0]), + np.sum(left_right_lines[1]), + np.sum(left_right_lines[2]), + ] + + min_idx = np.argmin(left_right_sums) + left_right_line = left_right_lines[min_idx] + + updated_left_right_center = np.mean(np.where(left_right_line == 1)) + left_right_length = (left_right_sums[min_idx] * 0.65) // 2 + + roi_2d = np.zeros((img_np.shape[0], img_np.shape[1])) + h = updated_left_right_center + k = updated_posterior_anterior_center + + # Semi-axes lengths (a, b) + a = left_right_length + b = posterior_anterior_length + + # Calculate the min and max indices for x and y + x_min = int(max(h - a, 0)) + x_max = int(min(h + a, img_np.shape[0])) + y_min = int(max(k - b, 0)) + y_max = int(min(k + b, img_np.shape[1])) + + # Create an oval ROI within the bounds + for x in range(x_min - 2, x_max + 2): + for y in 
range(y_min - 2, y_max + 2): + if ((x - h) / a) ** 2 + ((y - k) / b) ** 2 <= 1: + roi_2d[x, y] = 1 + roi[:, :, idx] = roi_2d + if idx == int(centroid[2]): + mask = np.flip(np.flip(np.transpose(roi_2d), axis=0), axis=1) + if idx == updated_z_center: + updated_mask = np.flip(np.flip(np.transpose(roi_2d), axis=0), axis=1) + if mask is None: + mask = updated_mask + + start_time = time.time() + + # Make sure there is no overlap with the cortical bone + num_iteration = 2 + if pixel_spacing[2] >= 3: + num_iteration = 1 + struct = scipy.ndimage.generate_binary_structure(3, 1) + struct = scipy.ndimage.iterate_structure(struct, num_iteration) + seg = scipy.ndimage.binary_erosion(seg, structure=struct).astype(np.int8) + + roi = roi * seg + + end_time = time.time() + roi_end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time for erosion operation: {elapsed_time} seconds") + + elapsed_time = roi_end_time - roi_start_time + print(f"Elapsed time for full ROI computation: {elapsed_time} seconds") + + return roi, mask + + +# Function that takes a 3d image and a 3d binary mask and returns that average +# value of the image inside the mask +def mean_img_mask(img: np.ndarray, mask: np.ndarray, index: int): + """Compute the mean of an image inside a mask. + + Args: + img (np.ndarray): Image volume. + mask (np.ndarray): Mask volume. + rescale_slope (float): Rescale slope. + rescale_intercept (float): Rescale intercept. + + Returns: + float: Mean value. + """ + img = img.astype(np.float32) + mask = mask.astype(np.float32) + img_masked = (img * mask)[mask > 0] + # mean = (rescale_slope * np.mean(img_masked)) + rescale_intercept + # median = (rescale_slope * np.median(img_masked)) + rescale_intercept + mean = np.mean(img_masked) + return mean + + +def compute_rois(seg, img, spine_model_type): + """Compute the ROIs for the spine. + + Args: + seg (np.ndarray): Segmentation volume. + img (np.ndarray): Image volume. + rescale_slope (float): Rescale slope. + rescale_intercept (float): Rescale intercept. + spine_model_type (Models): Model type. + + Returns: + spine_hus (List[float]): List of HU values. + rois (List[np.ndarray]): List of ROIs. + centroids_3d (List[np.ndarray]): List of centroids. + """ + seg_np = seg.get_fdata() + centroids = compute_centroids(seg_np, spine_model_type) + slices = get_slices(seg_np, centroids, spine_model_type) + for level in slices: + slice = slices[level] + # keep only the two largest connected components + two_largest, two = keep_two_largest_connected_components(slice) + if two: + slices[level] = delete_right_most_connected_component(two_largest) + + # Compute ROIs + rois = {} + spine_hus = {} + centroids_3d = {} + segmentation_hus = {} + spine_masks = {} + for i, level in enumerate(slices): + slice = slices[level] + center_of_mass = compute_center_of_mass(slice) + centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]]) + roi, mask_2d = roi_from_mask( + img, + centroid, + (seg_np == spine_model_type.categories[level]).astype(int), + slice, + ) + image_numpy = img.get_fdata() + spine_hus[level] = mean_img_mask(image_numpy, roi, i) + rois[level] = roi + mask = (seg_np == spine_model_type.categories[level]).astype(int) + segmentation_hus[level] = mean_img_mask(image_numpy, mask, i) + centroids_3d[level] = centroid + spine_masks[level] = mask_2d + return (spine_hus, rois, segmentation_hus, centroids_3d, spine_masks) + + +def keep_two_largest_connected_components(mask: Dict): + """Keep the two largest connected components. 
+ + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Mask volume. + """ + mask = mask.astype(np.uint8) + # sort connected components by size + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + stats = stats[1:, 4] + sorted_indices = np.argsort(stats)[::-1] + # keep only the two largest connected components + mask = np.zeros(mask.shape) + mask[labels == sorted_indices[0] + 1] = 1 + two = True + try: + mask[labels == sorted_indices[1] + 1] = 1 + except Exception: + two = False + return (mask, two) + + +def compute_centroid(seg: np.ndarray, plane: str, label: int): + """Compute the centroid of a label in a given plane. + + Args: + seg (np.ndarray): Segmentation volume. + plane (str): Plane. + label (int): Label. + + Returns: + int: Centroid. + """ + if plane == "axial": + sum_out_axes = (0, 1) + sum_axis = 2 + elif plane == "sagittal": + sum_out_axes = (1, 2) + sum_axis = 0 + elif plane == "coronal": + sum_out_axes = (0, 2) + sum_axis = 1 + sums = np.sum(seg == label, axis=sum_out_axes) + normalized_sums = sums / np.sum(sums) + pos = int(np.sum(np.arange(0, seg.shape[sum_axis]) * normalized_sums)) + return pos + + +def to_one_hot(label: np.ndarray, model_type, spine_hus): + """Convert a label to one-hot encoding. + + Args: + label (np.ndarray): Label volume. + model_type (Models): Model type. + + Returns: + np.ndarray: One-hot encoding volume. + """ + levels = list(spine_hus.keys()) + levels.reverse() + one_hot_label = np.zeros((label.shape[0], label.shape[1], len(levels))) + for i, level in enumerate(levels): + label_idx = model_type.categories[level] + one_hot_label[:, :, i] = (label == label_idx).astype(int) + return one_hot_label + + +def visualize_coronal_sagittal_spine( + seg: np.ndarray, + rois: List[np.ndarray], + mvs: np.ndarray, + centroids_3d: np.ndarray, + output_dir: str, + spine_hus=None, + seg_hus=None, + model_type=None, + pixel_spacing=None, + format="png", +): + """Visualize the coronal and sagittal planes of the spine. + + Args: + seg (np.ndarray): Segmentation volume. + rois (List[np.ndarray]): List of ROIs. + mvs (dm.MedicalVolume): Medical volume. + centroids (List[int]): List of centroids. + label_text (List[str]): List of labels. + output_dir (str): Output directory. + spine_hus (List[float], optional): List of HU values. Defaults to None. + model_type (Models, optional): Model type. Defaults to None. 
+ """ + + sagittal_vals, coronal_vals = curved_planar_reformation(mvs, centroids_3d) + zoom_factor = pixel_spacing[2] / pixel_spacing[1] + + sagittal_image = mvs[sagittal_vals, :, range(len(sagittal_vals))] + sagittal_label = seg[sagittal_vals, :, range(len(sagittal_vals))] + sagittal_image = zoom(sagittal_image, (zoom_factor, 1), order=3) + sagittal_label = zoom(sagittal_label, (zoom_factor, 1), order=1).round() + + one_hot_sag_label = to_one_hot(sagittal_label, model_type, spine_hus) + for roi in rois: + one_hot_roi_label = roi[sagittal_vals, :, range(len(sagittal_vals))] + one_hot_roi_label = zoom(one_hot_roi_label, (zoom_factor, 1), order=1).round() + one_hot_sag_label = np.concatenate( + ( + one_hot_sag_label, + one_hot_roi_label.reshape( + (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1) + ), + ), + axis=2, + ) + + coronal_image = mvs[:, coronal_vals, range(len(coronal_vals))] + coronal_label = seg[:, coronal_vals, range(len(coronal_vals))] + coronal_image = zoom(coronal_image, (1, zoom_factor), order=3) + coronal_label = zoom(coronal_label, (1, zoom_factor), order=1).round() + + # coronal_image = zoom(coronal_image, (zoom_factor, 1), order=3) + # coronal_label = zoom(coronal_label, (zoom_factor, 1), order=0).astype(int) + + one_hot_cor_label = to_one_hot(coronal_label, model_type, spine_hus) + for roi in rois: + one_hot_roi_label = roi[:, coronal_vals, range(len(coronal_vals))] + one_hot_roi_label = zoom(one_hot_roi_label, (1, zoom_factor), order=1).round() + one_hot_cor_label = np.concatenate( + ( + one_hot_cor_label, + one_hot_roi_label.reshape( + (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1) + ), + ), + axis=2, + ) + + # flip both axes of coronal image + sagittal_image = np.flip(sagittal_image, axis=0) + sagittal_image = np.flip(sagittal_image, axis=1) + + # flip both axes of coronal label + one_hot_sag_label = np.flip(one_hot_sag_label, axis=0) + one_hot_sag_label = np.flip(one_hot_sag_label, axis=1) + + coronal_image = np.transpose(coronal_image) + one_hot_cor_label = np.transpose(one_hot_cor_label, (1, 0, 2)) + + # flip both axes of sagittal image + coronal_image = np.flip(coronal_image, axis=0) + coronal_image = np.flip(coronal_image, axis=1) + + # flip both axes of sagittal label + one_hot_cor_label = np.flip(one_hot_cor_label, axis=0) + one_hot_cor_label = np.flip(one_hot_cor_label, axis=1) + + if format == "png": + sagittal_name = "spine_sagittal.png" + coronal_name = "spine_coronal.png" + elif format == "dcm": + sagittal_name = "spine_sagittal.dcm" + coronal_name = "spine_coronal.dcm" + else: + raise ValueError("Format must be either png or dcm") + + img_sagittal = spine_visualization.spine_binary_segmentation_overlay( + sagittal_image, + one_hot_sag_label, + output_dir, + sagittal_name, + spine_hus=spine_hus, + seg_hus=seg_hus, + model_type=model_type, + pixel_spacing=pixel_spacing, + ) + img_coronal = spine_visualization.spine_binary_segmentation_overlay( + coronal_image, + one_hot_cor_label, + output_dir, + coronal_name, + spine_hus=spine_hus, + seg_hus=seg_hus, + model_type=model_type, + pixel_spacing=pixel_spacing, + ) + + return img_sagittal, img_coronal + + +def curved_planar_reformation(mvs, centroids): + centroids = sorted(centroids, key=lambda x: x[2]) + centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids] + sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))] + coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))] + axial_centroids = [centroids[i][2] for i in range(0, len(centroids))] 
+ sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0] + coronal_vals = [coronal_centroids[0]] * axial_centroids[0] + + for i in range(1, len(axial_centroids)): + num = axial_centroids[i] - axial_centroids[i - 1] + interp = list( + np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num) + ) + sagittal_vals.extend(interp) + interp = list( + np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num) + ) + coronal_vals.extend(interp) + + sagittal_vals.extend([sagittal_centroids[-1]] * (mvs.shape[2] - len(sagittal_vals))) + coronal_vals.extend([coronal_centroids[-1]] * (mvs.shape[2] - len(coronal_vals))) + sagittal_vals = np.array(sagittal_vals) + coronal_vals = np.array(coronal_vals) + sagittal_vals = sagittal_vals.astype(int) + coronal_vals = coronal_vals.astype(int) + + return (sagittal_vals, coronal_vals) + + +''' +def compare_ts_stanford_centroids(labels_path, pred_centroids): + """Compare the centroids of the Stanford dataset with the centroids of the TS dataset. + + Args: + labels_path (str): Path to the Stanford dataset labels. + """ + t12_diff = [] + l1_diff = [] + l2_diff = [] + l3_diff = [] + l4_diff = [] + l5_diff = [] + num_skipped = 0 + + labels = glob(labels_path + "/*") + for label_path in labels: + # modify label_path to give pred_path + pred_path = label_path.replace("labelsTs", "predTs_TS") + print(label_path.split("/")[-1]) + label_nib = nib.load(label_path) + label = label_nib.get_fdata() + spacing = label_nib.header.get_zooms()[2] + pred_nib = nib.load(pred_path) + pred = pred_nib.get_fdata() + if True: + pred[pred == 18] = 6 + pred[pred == 19] = 5 + pred[pred == 20] = 4 + pred[pred == 21] = 3 + pred[pred == 22] = 2 + pred[pred == 23] = 1 + + for label_idx in range(1, 7): + label_level = label == label_idx + indexes = np.array(range(label.shape[2])) + sums = np.sum(label_level, axis=(0, 1)) + normalized_sums = sums / np.sum(sums) + label_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for label {label_idx}: {label_centroid}") + + if False: + try: + pred_centroid = pred_centroids[6 - label_idx] + except Exception: + # Change this part + print("Something wrong with pred_centroids, skipping!") + num_skipped += 1 + break + + # if revert_to_original: + if True: + pred_level = pred == label_idx + sums = np.sum(pred_level, axis=(0, 1)) + indices = list(range(sums.shape[0])) + groupby_input = zip(indices, list(sums)) + g = groupby(groupby_input, key=lambda x: x[1] > 0.0) + m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1])) + res = list(zip(*m)) + indexes = list(res[0]) + sums = list(res[1]) + normalized_sums = sums / np.sum(sums) + pred_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for prediction {label_idx}: {pred_centroid}") + + diff = np.absolute(pred_centroid - label_centroid) * spacing + + if label_idx == 1: + t12_diff.append(diff) + elif label_idx == 2: + l1_diff.append(diff) + elif label_idx == 3: + l2_diff.append(diff) + elif label_idx == 4: + l3_diff.append(diff) + elif label_idx == 5: + l4_diff.append(diff) + elif label_idx == 6: + l5_diff.append(diff) + + print(f"Skipped {num_skipped}") + print("The final mean differences in mm:") + print( + np.mean(t12_diff), + np.mean(l1_diff), + np.mean(l2_diff), + np.mean(l3_diff), + np.mean(l4_diff), + np.mean(l5_diff), + ) + print("The final median differences in mm:") + print( + np.median(t12_diff), + np.median(l1_diff), + np.median(l2_diff), + np.median(l3_diff), + np.median(l4_diff), + np.median(l5_diff), + ) + + +def 
compare_ts_stanford_roi_hus(image_path): + """Compare the HU values of the Stanford dataset with the HU values of the TS dataset. + + image_path (str): Path to the Stanford dataset images. + """ + img_paths = glob(image_path + "/*") + differences = np.zeros((40, 6)) + ground_truth = np.zeros((40, 6)) + for i, img_path in enumerate(img_paths): + print(f"Image number {i + 1}") + image_path_no_0000 = re.sub(r"_0000", "", img_path) + ts_seg_path = image_path_no_0000.replace("imagesTs", "predTs_TS") + stanford_seg_path = image_path_no_0000.replace("imagesTs", "labelsTs") + img = nib.load(img_path).get_fdata() + img = np.swapaxes(img, 0, 1) + ts_seg = nib.load(ts_seg_path).get_fdata() + ts_seg = np.swapaxes(ts_seg, 0, 1) + stanford_seg = nib.load(stanford_seg_path).get_fdata() + stanford_seg = np.swapaxes(stanford_seg, 0, 1) + ts_model_type = Models.model_from_name("ts_spine") + (spine_hus_ts, rois, centroids_3d) = compute_rois(ts_seg, img, 1, 0, ts_model_type) + stanford_model_type = Models.model_from_name("stanford_spine_v0.0.1") + (spine_hus_stanford, rois, centroids_3d) = compute_rois( + stanford_seg, img, 1, 0, stanford_model_type + ) + difference_vals = np.abs(np.array(spine_hus_ts) - np.array(spine_hus_stanford)) + print(f"Differences {difference_vals}\n") + differences[i, :] = difference_vals + ground_truth[i, :] = spine_hus_stanford + print("\n") + # compute average percent change from ground truth + percent_change = np.divide(differences, ground_truth) * 100 + average_percent_change = np.mean(percent_change, axis=0) + median_percent_change = np.median(percent_change, axis=0) + # print average percent change + print("Average percent change from ground truth:") + print(average_percent_change) + print("Median percent change from ground truth:") + print(median_percent_change) + # print average difference + average_difference = np.mean(differences, axis=0) + median_difference = np.median(differences, axis=0) + print("Average difference from ground truth:") + print(average_difference) + print("Median difference from ground truth:") + print(median_difference) + + +def process_post_hoc(pred_path): + """Apply post-hoc heuristics for improving Stanford spine model vertical centroid predictions. + + Args: + pred_path (str): Path to the prediction. 
+ """ + pred_nib = nib.load(pred_path) + pred = pred_nib.get_fdata() + + pred_bodies = np.logical_and(pred >= 1, pred <= 6) + pred_bodies = pred_bodies.astype(np.int64) + + labels_out, N = cc3d.connected_components(pred_bodies, return_N=True, connectivity=6) + + stats = cc3d.statistics(labels_out) + print(stats) + + labels_out_list = [] + voxel_counts_list = list(stats["voxel_counts"]) + for idx_lab in range(1, N + 2): + labels_out_list.append(labels_out == idx_lab) + + centroids_list = list(stats["centroids"][:, 2]) + + labels = [] + centroids = [] + voxels = [] + + for idx, count in enumerate(voxel_counts_list): + if count > 10000: + labels.append(labels_out_list[idx]) + centroids.append(centroids_list[idx]) + voxels.append(count) + + top_comps = [ + (counts0, labels0, centroids0) + for counts0, labels0, centroids0 in sorted(zip(voxels, labels, centroids), reverse=True) + ] + top_comps = top_comps[1:7] + + # ====== Check whether the connected components are fusing vertebral bodies ====== + revert_to_original = False + + volumes = list(zip(*top_comps))[0] + if volumes[0] > 1.5 * volumes[1]: + revert_to_original = True + print("Reverting to original...") + + labels = list(zip(*top_comps))[1] + centroids = list(zip(*top_comps))[2] + + top_comps = zip(centroids, labels) + pred_centroids = [x for x, _ in sorted(top_comps)] + + for label_idx in range(1, 7): + if not revert_to_original: + try: + pred_centroid = pred_centroids[6 - label_idx] + except: + # Change this part + print( + "Post processing failure, probably < 6 predicted bodies. Reverting to original labels." + ) + revert_to_original = True + + if revert_to_original: + pred_level = pred == label_idx + sums = np.sum(pred_level, axis=(0, 1)) + indices = list(range(sums.shape[0])) + groupby_input = zip(indices, list(sums)) + # sys.exit() + g = groupby(groupby_input, key=lambda x: x[1] > 0.0) + m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1])) + # sys.exit() + # m = max([list(s) for v, s in g], key=lambda np.sum) + res = list(zip(*m)) + indexes = list(res[0]) + sums = list(res[1]) + normalized_sums = sums / np.sum(sums) + pred_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for prediction {label_idx}: {pred_centroid}") +''' diff --git a/comp2comp/spine/spine.py b/comp2comp/spine/spine.py index d46667df..d01f2224 100644 --- a/comp2comp/spine/spine.py +++ b/comp2comp/spine/spine.py @@ -15,6 +15,11 @@ import pandas as pd import wget from PIL import Image +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) from totalsegmentatorv2.python_api import totalsegmentator from comp2comp.inference_class_base import InferenceClass @@ -22,12 +27,6 @@ from comp2comp.spine import spine_utils from comp2comp.visualization.dicom import to_dicom -from totalsegmentator.libs import ( - download_pretrained_weights, - nostdout, - setup_nnunet, -) - class SpineSegmentation(InferenceClass): """Spine segmentation.""" @@ -41,11 +40,15 @@ def __call__(self, inference_pipeline): # inference_pipeline.dicom_series_path = self.input_path self.output_dir = inference_pipeline.output_dir self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + inference_pipeline.output_dir_masks = os.path.join(self.output_dir, "masks/") + if not os.path.exists(self.output_dir_segmentations): os.makedirs(self.output_dir_segmentations) + if not os.path.exists(inference_pipeline.output_dir_masks): + os.makedirs(inference_pipeline.output_dir_masks) self.model_dir = 
diff --git a/comp2comp/spine/spine.py b/comp2comp/spine/spine.py
index d46667df..d01f2224 100644
--- a/comp2comp/spine/spine.py
+++ b/comp2comp/spine/spine.py
@@ -15,6 +15,11 @@
 import pandas as pd
 import wget
 from PIL import Image
+from totalsegmentator.libs import (
+    download_pretrained_weights,
+    nostdout,
+    setup_nnunet,
+)
 from totalsegmentatorv2.python_api import totalsegmentator
 
 from comp2comp.inference_class_base import InferenceClass
@@ -22,12 +27,6 @@
 from comp2comp.spine import spine_utils
 from comp2comp.visualization.dicom import to_dicom
 
-from totalsegmentator.libs import (
-    download_pretrained_weights,
-    nostdout,
-    setup_nnunet,
-)
-
 
 class SpineSegmentation(InferenceClass):
     """Spine segmentation."""
@@ -41,11 +40,15 @@ def __call__(self, inference_pipeline):
         # inference_pipeline.dicom_series_path = self.input_path
         self.output_dir = inference_pipeline.output_dir
         self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+        inference_pipeline.output_dir_masks = os.path.join(self.output_dir, "masks/")
+
         if not os.path.exists(self.output_dir_segmentations):
             os.makedirs(self.output_dir_segmentations)
+        if not os.path.exists(inference_pipeline.output_dir_masks):
+            os.makedirs(inference_pipeline.output_dir_masks)
 
         self.model_dir = inference_pipeline.model_dir
-
+
         inference_pipeline.spine_model_name = self.model_name
 
         seg, mv = self.spine_seg(
@@ -53,7 +56,7 @@
             self.output_dir_segmentations + "spine.nii.gz",
             inference_pipeline.model_dir,
         )
-
+
         os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir
 
         mv = nib.load(
@@ -64,6 +67,13 @@
         nib.save(
             seg,
             os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"),
+            # os.path.join(inference_pipeline.output_dir_masks, "spine_seg.nii.gz"),
+        )
+
+        nib.save(
+            mv,
+            # os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"),
+            os.path.join(inference_pipeline.output_dir_masks, "ct.nii.gz"),
         )
 
         # inference_pipeline.segmentation = nib.load(
@@ -72,7 +82,7 @@
         inference_pipeline.segmentation = seg
         inference_pipeline.medical_volume = mv
         inference_pipeline.save_segmentations = self.save_segmentations
-
+
         return {}
 
     def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
@@ -137,9 +147,11 @@
         os.environ["SCRATCH"] = self.model_dir
         os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir
 
-        if self.model_name == 'ts_spine':
+        if self.model_name == "ts_spine":
             seg = totalsegmentator(
-                input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+                input=os.path.join(
+                    self.output_dir_segmentations, "converted_dcm.nii.gz"
+                ),
                 output=os.path.join(self.output_dir_segmentations, "segmentation.nii"),
                 task_ids=[292],
                 ml=True,
@@ -165,11 +177,11 @@
                 statistics_exclude_masks_at_border=True,
                 no_derived_masks=False,
                 v1_order=False,
-            )
-
+            )
+
             img = None
-        elif self.model_name == 'stanford_spine_v0.0.1':
+        elif self.model_name == "stanford_spine_v0.0.1":
             # Setup nnunet
             model = "3d_fullres"
             folds = [0]
@@ -326,6 +338,38 @@ def __call__(self, inference_pipeline):
         inference_pipeline.centroids_3d = centroids_3d
         inference_pipeline.spine_masks = spine_masks
 
+        # save the ROIs and spine masks as a single array
+        full_rois_array = np.zeros(
+            inference_pipeline.medical_volume.shape, dtype=np.int8
+        )
+        full_spine_masks_array = np.zeros(
+            inference_pipeline.medical_volume.shape, dtype=np.int8
+        )
+
+        for i, level in enumerate(rois.keys()):
+            full_rois_array[rois[level] > 0] = i + 1
+            # full_spine_masks_array[spine_masks[level] > 0] = i + 1
+
+        inference_pipeline.saveArrToNifti(
+            full_rois_array,
+            os.path.join(
+                inference_pipeline.output_dir_masks, "central_rois_mask.nii.gz"
+            ),
+        )
+
+        nib.save(
+            inference_pipeline.segmentation,
+            os.path.join(inference_pipeline.output_dir_masks, "spine_seg.nii.gz"),
+        )
+
+        # inference_pipeline.saveArrToNifti(
+        #     full_spine_masks_array,
+        #     os.path.join(
+        #         inference_pipeline.output_dir_masks,
+        #         "vertebrae_body_mask.nii.gz",
+        #     ),
+        # )
+
         return {}
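The SpineComputeROIs hunk above writes the combined ROI label array through inference_pipeline.saveArrToNifti, which is not defined anywhere in this diff. A standalone sketch of what such a helper could look like, assuming it only needs to reuse the CT's affine so the mask lines up with ct.nii.gz (the name save_arr_to_nifti and the signature below are assumptions, not the pipeline's actual API):

    import nibabel as nib
    import numpy as np

    def save_arr_to_nifti(arr: np.ndarray, reference: nib.Nifti1Image, path: str) -> None:
        # Write the label array on the reference CT's grid by copying its affine.
        nib.save(nib.Nifti1Image(arr.astype(np.int8), reference.affine), path)

    # e.g. save_arr_to_nifti(full_rois_array, mv, ".../masks/central_rois_mask.nii.gz")
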
diff --git a/comp2comp/spine/spine_utils.py b/comp2comp/spine/spine_utils.py
index 975213a8..378d18ba 100644
--- a/comp2comp/spine/spine_utils.py
+++ b/comp2comp/spine/spine_utils.py
@@ -230,23 +230,25 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
     updated_z_center = np.mean(np.where(inferior_superior_line == 1))
     lower_z_idx = updated_z_center - ((length_k * 1.5) // 2)
     upper_z_idx = updated_z_center + ((length_k * 1.5) // 2)
-
+
     for idx in range(int(lower_z_idx), int(upper_z_idx) + 1):
-        print(f'idx: {idx}')
+        print(f"idx: {idx}")
         # take multiple to increase robustness
         posterior_anterior_lines = [
             slice[:, idx],
             slice[:, idx + 1],
             slice[:, idx - 1],
         ]
-        posterior_anterior_sums = np.array([
-            np.sum(posterior_anterior_lines[0]),
-            np.sum(posterior_anterior_lines[1]),
-            np.sum(posterior_anterior_lines[2]),
-        ])
+        posterior_anterior_sums = np.array(
+            [
+                np.sum(posterior_anterior_lines[0]),
+                np.sum(posterior_anterior_lines[1]),
+                np.sum(posterior_anterior_lines[2]),
+            ]
+        )
 
         if posterior_anterior_sums.sum() == 0:
-            print(f'skipped idx {idx} since posterior_anterior_sums where 0')
+            print(f"skipped idx {idx} since posterior_anterior_sums where 0")
             continue
 
         # min_idx = np.argmin(posterior_anterior_sums)
@@ -258,7 +260,7 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
             min_idx = np.where(nonzero_mask)[0][min_idx]
 
         posterior_anterior_line = posterior_anterior_lines[min_idx]
-
+
         updated_posterior_anterior_center = (
             np.min(np.where(posterior_anterior_line == 1))
             + np.sum(posterior_anterior_line) * 0.58
@@ -272,14 +274,16 @@
             seg[:, int(updated_posterior_anterior_center) - 1, idx],
         ]
 
-        left_right_sums = np.array([
-            np.sum(left_right_lines[0]),
-            np.sum(left_right_lines[1]),
-            np.sum(left_right_lines[2]),
-        ])
+        left_right_sums = np.array(
+            [
+                np.sum(left_right_lines[0]),
+                np.sum(left_right_lines[1]),
+                np.sum(left_right_lines[2]),
+            ]
+        )
 
         if left_right_sums.sum() == 0:
-            print(f'skipped idx {idx} since left_right_sums where 0')
+            print(f"skipped idx {idx} since left_right_sums where 0")
             continue
 
         # min_idx = np.argmin(left_right_sums)
@@ -401,19 +405,21 @@ def compute_rois(seg, img, spine_model_type):
     centroids_3d = {}
     segmentation_hus = {}
     spine_masks = {}
+
+    image_numpy = img.get_fdata()
+
     for i, level in enumerate(slices):
         slice = slices[level]
         center_of_mass = compute_center_of_mass(slice)
         centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]])
-        print(f'Processing i={i} at level={level}')
-
+        print(f"Processing i={i} at level={level}")
+
         roi, mask_2d = roi_from_mask(
             img,
             centroid,
             (seg_np == spine_model_type.categories[level]).astype(int),
             slice,
         )
-        image_numpy = img.get_fdata()
         spine_hus[level] = mean_img_mask(image_numpy, roi, i)
         rois[level] = roi
         mask = (seg_np == spine_model_type.categories[level]).astype(int)
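Two behaviour notes on the spine_utils.py hunks: compute_rois now calls img.get_fdata() once before the per-level loop instead of once per level, and roi_from_mask keeps skipping indices whose neighbouring mask lines are all empty (only the formatting of those checks changes here). The context line min_idx = np.where(nonzero_mask)[0][min_idx] maps an argmin taken over the non-empty lines back to the original line index; a toy illustration (the line_sums values are made up, and the preceding argmin-over-nonzero step is inferred from context rather than shown in the hunk):

    import numpy as np

    line_sums = np.array([0, 14, 11])             # sums of three adjacent mask lines
    nonzero_mask = line_sums > 0                  # ignore empty lines entirely
    min_idx = np.argmin(line_sums[nonzero_mask])  # position within the non-empty subset
    min_idx = np.where(nonzero_mask)[0][min_idx]  # map back to the original index -> 2
    print(min_idx)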