diff --git a/.gitignore b/.gitignore index 1ba05edc..72a20a13 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ docs/_build **/test_data **/testing_data **/sample_data +**/ct_scan_paths **/test_results # Ignore images diff --git a/bin/C2C b/bin/C2C index 3b8ac892..e3578c36 100755 --- a/bin/C2C +++ b/bin/C2C @@ -31,15 +31,15 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" ### AAA Pipeline def AAAPipelineBuilder(path, args): - pipeline = InferencePipeline( - [ - AxialCropperPipelineBuilder(path, args), + # build the “crop” pipeline and pull out its step‐classes + crop_pipeline = AxialCropperPipelineBuilder(path, args) + # now concatenate its classes with the AAA steps + classes = crop_pipeline.inference_classes + [ aaa.AortaSegmentation(), aaa.AortaDiameter(), - aaa.AortaMetricsSaver() - ] - ) - return pipeline + aaa.AortaMetricsSaver(), + ] + return InferencePipeline(classes) def MuscleAdiposeTissuePipelineBuilder(args): pipeline = InferencePipeline( diff --git a/comp2comp/aaa/aaa.py b/comp2comp/aaa/aaa.py index fd82d764..b768acb0 100644 --- a/comp2comp/aaa/aaa.py +++ b/comp2comp/aaa/aaa.py @@ -1,4 +1,3 @@ -import math import operator import os import zipfile @@ -28,8 +27,117 @@ def __init__(self, save=True): self.model_name = "totalsegmentator" self.save_segmentations = save + def _infer_raw_geometry(self, input_path: Union[str, Path]): + """ + Return (raw_shape, raw_affine) for the *raw* CT space. + - If input_path is a DICOM folder: build affine from DICOM headers. + - If input_path is a NIfTI file (.nii/.nii.gz): load with nibabel. + """ + if os.path.isdir(input_path): + # --- DICOM series case --- + files = [ + os.path.join(input_path, f) + for f in os.listdir(input_path) + if not f.startswith(".") + ] + ds_list = [] + for fp in files: + try: + ds = pydicom.dcmread(fp, stop_before_pixels=True) + ds_list.append(ds) + except Exception: + pass + if not ds_list: + raise RuntimeError(f"No readable DICOM slices found in: {input_path}") + + ds0 = ds_list[0] + rows = int(ds0.Rows) + cols = int(ds0.Columns) + + # Direction cosines + iop = [float(x) for x in ds0.ImageOrientationPatient] # 6 values + row_cos = np.array(iop[:3], dtype=float) + col_cos = np.array(iop[3:], dtype=float) + normal = np.cross(row_cos, col_cos) + + # Pixel spacing (mm) + ps = [float(x) for x in ds0.PixelSpacing] # [row_spacing, col_spacing] + row_spacing = ps[0] + col_spacing = ps[1] + + # Sort slices along the normal using ImagePositionPatient or InstanceNumber + def ipp_dot(d): + if hasattr(d, "ImagePositionPatient"): + ipp = np.array( + [float(x) for x in d.ImagePositionPatient], dtype=float + ) + return float(np.dot(ipp, normal)) + return 0.0 + + if all(hasattr(d, "ImagePositionPatient") for d in ds_list): + ds_list.sort(key=ipp_dot) + elif all(hasattr(d, "InstanceNumber") for d in ds_list): + ds_list.sort(key=lambda d: int(d.InstanceNumber)) + else: + # fallback: filename order + ds_list.sort(key=lambda d: getattr(d, "SOPInstanceUID", "")) + + # Slice spacing (use IPP difference if possible) + if ( + len(ds_list) > 1 + and hasattr(ds_list[0], "ImagePositionPatient") + and hasattr(ds_list[1], "ImagePositionPatient") + ): + ipp0 = np.array( + [float(x) for x in ds_list[0].ImagePositionPatient], dtype=float + ) + ipp1 = np.array( + [float(x) for x in ds_list[1].ImagePositionPatient], dtype=float + ) + slice_step = abs(np.dot((ipp1 - ipp0), normal)) + if slice_step == 0: + slice_step = float( + getattr( + ds0, + "SpacingBetweenSlices", + getattr(ds0, "SliceThickness", 1.0), + ) + ) + else: + slice_step = float( + getattr( + ds0, "SpacingBetweenSlices", getattr(ds0, "SliceThickness", 1.0) + ) + ) + + # Origin = IPP of first slice + if hasattr(ds_list[0], "ImagePositionPatient"): + origin = np.array( + [float(x) for x in ds_list[0].ImagePositionPatient], dtype=float + ) + else: + origin = np.zeros(3, dtype=float) + + # Build affine (voxel indices i,j,k -> world x,y,z) + # i: columns (x), j: rows (y), k: slices (z) + aff = np.eye(4, dtype=float) + aff[0:3, 0] = col_cos * col_spacing + aff[0:3, 1] = row_cos * row_spacing + aff[0:3, 2] = normal * slice_step + aff[0:3, 3] = origin + + raw_shape = (rows, cols, len(ds_list)) + raw_affine = aff + return raw_shape, raw_affine + + else: + # --- NIfTI case --- + nii = nib.load(str(input_path)) + return tuple(nii.shape), np.array(nii.affine, dtype=float) + def __call__(self, inference_pipeline): - # inference_pipeline.dicom_series_path = self.input_path + print("TRY NEW: DICOM series path is:", inference_pipeline.input_path) + inference_pipeline.dicom_series_path = inference_pipeline.input_path self.output_dir = inference_pipeline.output_dir self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") if not os.path.exists(self.output_dir_segmentations): @@ -37,12 +145,30 @@ def __call__(self, inference_pipeline): self.model_dir = inference_pipeline.model_dir + # keep the path to the *raw* NIfTI you feed into the model + raw_nii_path = os.path.join( + self.output_dir_segmentations, "converted_dcm.nii.gz" + ) + seg, mv = self.spine_seg( - os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + raw_nii_path, self.output_dir_segmentations + "spine.nii.gz", inference_pipeline.model_dir, ) + # NEW: derive raw geometry from the *input* CT (DICOM folder or NIfTI) + raw_shape, raw_affine = self._infer_raw_geometry(inference_pipeline.input_path) + inference_pipeline.raw_shape = raw_shape + inference_pipeline.raw_affine = raw_affine + + # Keep processed-space metadata for pixel→mm conversion & mapping + inference_pipeline.proc_affine = mv.affine + inference_pipeline.proc_zooms = mv.header.get_zooms() + inference_pipeline.proc_shape = mv.shape + + print("DEBUG raw_shape:", inference_pipeline.raw_shape) + print("DEBUG proc_shape:", inference_pipeline.proc_shape) + seg = seg.get_fdata() medical_volume = mv.get_fdata() @@ -177,224 +303,186 @@ def __init__(self): super().__init__() def normalize_img(self, img: np.ndarray) -> np.ndarray: - """Normalize the image. - Args: - img (np.ndarray): Input image. - Returns: - np.ndarray: Normalized image. - """ return (img - img.min()) / (img.max() - img.min()) def __call__(self, inference_pipeline): axial_masks = ( inference_pipeline.axial_masks - ) # list of 2D numpy arrays of shape (512, 512) - ct_img = ( - inference_pipeline.ct_image - ) # 3D numpy array of shape (512, 512, num_axial_slices) + ) # list of 2D masks (processed space) + ct_img = inference_pipeline.ct_image # list/array of slices (processed space) - # image output directory + # output dirs output_dir = inference_pipeline.output_dir output_dir_slices = os.path.join(output_dir, "images/slices/") - if not os.path.exists(output_dir_slices): - os.makedirs(output_dir_slices) - - output_dir = inference_pipeline.output_dir output_dir_summary = os.path.join(output_dir, "images/summary/") - if not os.path.exists(output_dir_summary): - os.makedirs(output_dir_summary) - - DICOM_PATH = inference_pipeline.dicom_series_path - dicom = pydicom.dcmread(DICOM_PATH + "/" + os.listdir(DICOM_PATH)[0]) - - dicom.PhotometricInterpretation = "YBR_FULL" - pixel_conversion = dicom.PixelSpacing - print("Pixel conversion: " + str(pixel_conversion)) - RATIO_PIXEL_TO_MM = pixel_conversion[0] - - SLICE_COUNT = dicom["InstanceNumber"].value - print(SLICE_COUNT) - - SLICE_COUNT = len(ct_img) - diameterDict = {} - - for i in range(len(ct_img)): - mask = axial_masks[i].astype("uint8") - - img = ct_img[i] - - img = np.clip(img, -300, 1800) - img = self.normalize_img(img) * 255.0 - img = img.reshape((img.shape[0], img.shape[1], 1)) - img = np.tile(img, (1, 1, 3)) - - contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) - - if len(contours) != 0: - areas = [cv2.contourArea(c) for c in contours] - sorted_areas = np.sort(areas) - - areas = [cv2.contourArea(c) for c in contours] - sorted_areas = np.sort(areas) - contours = contours[areas.index(sorted_areas[-1])] - - img.copy() - - back = img.copy() - cv2.drawContours(back, [contours], 0, (0, 255, 0), -1) + os.makedirs(output_dir_slices, exist_ok=True) + os.makedirs(output_dir_summary, exist_ok=True) - alpha = 0.25 - img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0) + # --- Affine-based mapping: processed -> raw --- + # processed image metadata (where we actually measure pixels) + proc_affine = np.array(inference_pipeline.proc_affine) + proc_zooms = tuple(inference_pipeline.proc_zooms) + proc_shape = tuple(inference_pipeline.proc_shape) # (H, W, Z_proc) + Hp, Wp, Zp = proc_shape - ellipse = cv2.fitEllipse(contours) - (xc, yc), (d1, d2), angle = ellipse + # raw CT reference space (target indexing for CSV) + raw_shape = tuple(inference_pipeline.raw_shape) # e.g., (512, 512, 480) + raw_aff = np.array(inference_pipeline.raw_affine) # 4x4 + Z_raw = raw_shape[2] - cv2.ellipse(img, ellipse, (0, 255, 0), 1) + print(f"[DEBUG] raw_shape={raw_shape} -> Z_raw={Z_raw}") + print(f"[DEBUG] proc_shape={proc_shape} -> Zp={Zp}") - xc, yc = ellipse[0] - cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1) + # precompute inverse affine for raw (do this once) + inv_raw_aff = np.linalg.inv(raw_aff) - rmajor = max(d1, d2) / 2 - rminor = min(d1, d2) / 2 + # use in-plane mm-per-pixel from the processed image (safer if x!=y) + RATIO_PIXEL_TO_MM = float((proc_zooms[0] + proc_zooms[1]) / 2.0) - ### Draw major axes + # pick the center of the processed image in XY to define the slice plane + cx = (Wp - 1) / 2.0 + cy = (Hp - 1) / 2.0 - if angle > 90: - angle = angle - 90 - else: - angle = angle + 90 - print(angle) - xtop = xc + math.cos(math.radians(angle)) * rmajor - ytop = yc + math.sin(math.radians(angle)) * rmajor - xbot = xc + math.cos(math.radians(angle + 180)) * rmajor - ybot = yc + math.sin(math.radians(angle + 180)) * rmajor - cv2.line( - img, (int(xtop), int(ytop)), (int(xbot), int(ybot)), (0, 0, 255), 3 - ) - - ### Draw minor axes - - if angle > 90: - angle = angle - 90 - else: - angle = angle + 90 - print(angle) - x1 = xc + math.cos(math.radians(angle)) * rminor - y1 = yc + math.sin(math.radians(angle)) * rminor - x2 = xc + math.cos(math.radians(angle + 180)) * rminor - y2 = yc + math.sin(math.radians(angle + 180)) * rminor - cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 3) - - # pixel_length = math.sqrt( (x1-x2)**2 + (y1-y2)**2 ) - pixel_length = rminor * 2 - - print("Pixel_length_minor: " + str(pixel_length)) - - area_px = cv2.contourArea(contours) - area_mm = round(area_px * RATIO_PIXEL_TO_MM) - area_cm = area_mm / 10 - - diameter_mm = round((pixel_length) * RATIO_PIXEL_TO_MM) - diameter_cm = diameter_mm / 10 - - diameterDict[(SLICE_COUNT - (i))] = diameter_cm - - img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) - - h, w, c = img.shape - lbls = [ - "Area (mm): " + str(area_mm) + "mm", - "Area (cm): " + str(area_cm) + "cm", - "Diameter (mm): " + str(diameter_mm) + "mm", - "Diameter (cm): " + str(diameter_cm) + "cm", - "Slice: " + str(SLICE_COUNT - (i)), - ] - font = cv2.FONT_HERSHEY_SIMPLEX + print("Processed shape (H,W,Z):", proc_shape) + print("Raw shape (H,W,Z):", raw_shape) - scale = 0.03 - fontScale = min(w, h) / (25 / scale) + # maps RAW index -> diameter (cm) + diameter_cm_by_raw_index = {} - cv2.putText(img, lbls[0], (10, 40), font, fontScale, (0, 255, 0), 2) + for i in range(Zp): # i is processed-space slice index + # ensure binary mask + mask = (axial_masks[i] > 0).astype("uint8") + if mask.max() == 0: + continue # no aorta on this processed slice - cv2.putText(img, lbls[1], (10, 70), font, fontScale, (0, 255, 0), 2) + # ---- map processed slice i to raw slice index k_raw ---- + proc_voxel = np.array([cx, cy, i, 1.0]) # center of slice + world_xyz1 = proc_affine @ proc_voxel # world coord + raw_ijk1 = inv_raw_aff @ world_xyz1 # raw voxel coord + k_raw = int(round(raw_ijk1[2])) # raw z-index - cv2.putText(img, lbls[2], (10, 100), font, fontScale, (0, 255, 0), 2) + # clamp into valid range + if k_raw < 0 or k_raw >= Z_raw: + continue - cv2.putText(img, lbls[3], (10, 130), font, fontScale, (0, 255, 0), 2) - - cv2.putText(img, lbls[4], (10, 160), font, fontScale, (0, 255, 0), 2) - - cv2.imwrite( - output_dir_slices + "slice" + str(SLICE_COUNT - (i)) + ".png", img - ) - - plt.bar(list(diameterDict.keys()), diameterDict.values(), color="b") - - plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$") - - plt.xlabel("Slice Number") - - plt.ylabel("Diameter Measurement (cm)") - plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500) - - print(diameterDict) - print(max(diameterDict.items(), key=operator.itemgetter(1))[0]) - print(diameterDict[max(diameterDict.items(), key=operator.itemgetter(1))[0]]) - - inference_pipeline.max_diameter = diameterDict[ - max(diameterDict.items(), key=operator.itemgetter(1))[0] - ] - - img = ct_img[ - SLICE_COUNT - (max(diameterDict.items(), key=operator.itemgetter(1))[0]) - ] - img = np.clip(img, -300, 1800) - img = self.normalize_img(img) * 255.0 - img = img.reshape((img.shape[0], img.shape[1], 1)) - img2 = np.tile(img, (1, 1, 3)) - img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE) - - img1 = cv2.imread( - output_dir_slices - + "slice" - + str(max(diameterDict.items(), key=operator.itemgetter(1))[0]) - + ".png" - ) - - border_size = 3 - img1 = cv2.copyMakeBorder( - img1, - top=border_size, - bottom=border_size, - left=border_size, - right=border_size, - borderType=cv2.BORDER_CONSTANT, - value=[0, 244, 0], - ) - img2 = cv2.copyMakeBorder( - img2, - top=border_size, - bottom=border_size, - left=border_size, - right=border_size, - borderType=cv2.BORDER_CONSTANT, - value=[244, 0, 0], - ) + # ---- measure diameter on the processed slice ---- + img = ct_img[i] + img = np.clip(img, -300, 1800) + img = self.normalize_img(img) * 255.0 + img = img.reshape((img.shape[0], img.shape[1], 1)) + img = np.tile(img, (1, 1, 3)) - vis = np.concatenate((img2, img1), axis=1) - cv2.imwrite(output_dir_summary + "out.png", vis) + contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + if not contours: + continue + + # largest component + areas = [cv2.contourArea(c) for c in contours] + contours = contours[int(np.argmax(areas))] + + # overlay (cosmetic) + back = img.copy() + cv2.drawContours(back, [contours], 0, (0, 255, 0), -1) + img = cv2.addWeighted(img, 0.75, back, 0.25, 0) + + ellipse = cv2.fitEllipse(contours) + (xc, yc), (d1, d2), angle = ellipse + cv2.ellipse(img, ellipse, (0, 255, 0), 1) + cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1) + + max(d1, d2) / 2.0 + rminor = min(d1, d2) / 2.0 + + # diameter from minor axis + pixel_length = rminor * 2.0 + diameter_mm = round(pixel_length * RATIO_PIXEL_TO_MM) + diameter_cm = diameter_mm / 10.0 + + # if multiple processed slices map to the same raw index (rare), keep the max + prev = diameter_cm_by_raw_index.get(k_raw) + if (prev is None) or (diameter_cm > prev): + diameter_cm_by_raw_index[k_raw] = diameter_cm + + # Save an image for quick QC (tagged with RAW index) + img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) + h, w, _ = img.shape + font = cv2.FONT_HERSHEY_SIMPLEX + fontScale = min(w, h) / (25 / 0.03) + cv2.putText( + img, + f"CT raw slice index: {k_raw}", + (10, 40), + font, + fontScale, + (0, 255, 0), + 2, + ) + cv2.putText( + img, + f"Diameter (cm): {diameter_cm}", + (10, 70), + font, + fontScale, + (0, 255, 0), + 2, + ) + cv2.imwrite(os.path.join(output_dir_slices, f"slice_raw_{k_raw}.png"), img) + + # --- Summary plot vs RAW index --- + if diameter_cm_by_raw_index: + xs = sorted(diameter_cm_by_raw_index.keys()) + ys = [diameter_cm_by_raw_index[k] for k in xs] + plt.bar(xs, ys) + plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$") + plt.xlabel("CT raw slice index (0-based)") + plt.ylabel("Diameter (cm)") + plt.savefig(os.path.join(output_dir_summary, "diameter_graph.png"), dpi=500) + plt.close() + + # --- Max diameter (by RAW index) --- + if diameter_cm_by_raw_index: + max_raw_idx = max( + diameter_cm_by_raw_index.items(), key=operator.itemgetter(1) + )[0] + print( + "Max raw index:", + max_raw_idx, + "diameter_cm:", + diameter_cm_by_raw_index[max_raw_idx], + ) + inference_pipeline.max_diameter = diameter_cm_by_raw_index[max_raw_idx] + else: + max_raw_idx = None + inference_pipeline.max_diameter = float("nan") - image_folder = output_dir_slices - fps = 20 + # --- MP4 (frames may be fewer than Z_raw; only saved slices with aorta) --- image_files = [ - os.path.join(image_folder, img) - for img in Tcl().call("lsort", "-dict", os.listdir(image_folder)) - if img.endswith(".png") + os.path.join(output_dir_slices, f) + for f in Tcl().call("lsort", "-dict", os.listdir(output_dir_slices)) + if f.endswith(".png") ] - clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( - image_files, fps=fps - ) - clip.write_videofile(output_dir_summary + "aaa.mp4") + if image_files: + clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( + image_files, fps=20 + ) + clip.write_videofile(os.path.join(output_dir_summary, "aaa.mp4")) + + # --- CSV over ALL raw slices: 0..Z_raw-1 (NaN where no aorta) --- + metrics_dir = os.path.join(inference_pipeline.output_dir, "metrics") + os.makedirs(metrics_dir, exist_ok=True) + + rows = [] + for k in range(Z_raw): + d_cm = diameter_cm_by_raw_index.get(k, np.nan) + rows.append( + { + "ct_slice": k, + "diameter_cm": d_cm, + "diameter_mm": (d_cm * 10.0) if not np.isnan(d_cm) else np.nan, + } + ) + df = pd.DataFrame(rows, columns=["ct_slice", "diameter_cm", "diameter_mm"]) + df.to_csv(os.path.join(metrics_dir, "aorta_diameters.csv"), index=False) return {} diff --git a/comp2comp/aortic_calcium/aortic_calcium.py b/comp2comp/aortic_calcium/aortic_calcium.py index a555a58b..8777f60d 100644 --- a/comp2comp/aortic_calcium/aortic_calcium.py +++ b/comp2comp/aortic_calcium/aortic_calcium.py @@ -12,8 +12,8 @@ import matplotlib.pyplot as plt import numpy as np +import pydicom from scipy import ndimage -import pydicom # from totalsegmentator.libs import ( # download_pretrained_weights, @@ -34,9 +34,9 @@ def __init__(self): def __call__(self, inference_pipeline): # check if kernels are allowed if agatston is used - if inference_pipeline.args.threshold == 'agatston': + if inference_pipeline.args.threshold == "agatston": self.reconKernelChecker(inference_pipeline.dcm) - + # inference_pipeline.dicom_series_path = self.input_path self.output_dir = inference_pipeline.output_dir self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") @@ -115,27 +115,63 @@ def reconKernelChecker(self, dcm): ge_kernels = ["standard", "md stnd"] philips_kernels = ["a", "b", "c", "sa", "sb"] canon_kernels = ["fc08", "fc18"] - siemens_kernels = ["b20s", "b20f", "b30f", "b31s", "b31f", "br34f", "b35f", "bf37f", "br38f", "b41f", - "qr40", "qr40d", "br36f", "br40", "b40f", "br40d", "i30f", "i31f", "i26f", "i31s", - "i40f", "b30s", "br36d", "bf39f", "b41s", "br40f"] + siemens_kernels = [ + "b20s", + "b20f", + "b30f", + "b31s", + "b31f", + "br34f", + "b35f", + "bf37f", + "br38f", + "b41f", + "qr40", + "qr40d", + "br36f", + "br40", + "b40f", + "br40d", + "i30f", + "i31f", + "i26f", + "i31s", + "i40f", + "b30s", + "br36d", + "bf39f", + "b41s", + "br40f", + ] toshiba_kernels = ["fc01", "fc02", "fc07", "fc08", "fc13", "fc18"] - all_kernels = ge_kernels+philips_kernels+canon_kernels+siemens_kernels+toshiba_kernels - - conv_kernel_raw = dcm['ConvolutionKernel'].value - + all_kernels = ( + ge_kernels + + philips_kernels + + canon_kernels + + siemens_kernels + + toshiba_kernels + ) + + conv_kernel_raw = dcm["ConvolutionKernel"].value + if isinstance(conv_kernel_raw, pydicom.multival.MultiValue): conv_kernel = conv_kernel_raw[0].lower() - recon_kernel_extra = str(conv_kernel_raw) + str(conv_kernel_raw) else: conv_kernel = conv_kernel_raw.lower() - recon_kernel_extra = 'n/a' - + if conv_kernel in all_kernels: return True - else: - raise ValueError('Reconstruction kernel not allowed, found: ' + conv_kernel +'\n' - + 'Allowed kernels are: ' + str(all_kernels)) + else: + raise ValueError( + "Reconstruction kernel not allowed, found: " + + conv_kernel + + "\n" + + "Allowed kernels are: " + + str(all_kernels) + ) + class AorticCalciumSegmentation(InferenceClass): """Segmentaiton of aortic calcium""" @@ -168,10 +204,10 @@ def __init__(self): 47: "vertebrae_C4", 48: "vertebrae_C3", 49: "vertebrae_C2", - 50: "vertebrae_C1"} - - self.vertebrae_name = {v: k for k, v in self.vertebrae_num.items()} + 50: "vertebrae_C1", + } + self.vertebrae_name = {v: k for k, v in self.vertebrae_num.items()} def __call__(self, inference_pipeline): @@ -194,15 +230,21 @@ def __call__(self, inference_pipeline): os.makedirs(os.path.join(self.output_dir, "metrics/")) inference_pipeline.ct = inference_pipeline.medical_volume.get_fdata() - inference_pipeline.aorta_mask = (inference_pipeline.segmentation.get_fdata().round().astype(np.int8) == 52) - inference_pipeline.spine_mask = inference_pipeline.spine_segmentation.get_fdata().round().astype(np.uint8) - + inference_pipeline.aorta_mask = ( + inference_pipeline.segmentation.get_fdata().round().astype(np.int8) == 52 + ) + inference_pipeline.spine_mask = ( + inference_pipeline.spine_segmentation.get_fdata().round().astype(np.uint8) + ) + # convert to the index of TotalSegmentator - if inference_pipeline.spine_model_name == 'stanford_spine_v0.0.1': + if inference_pipeline.spine_model_name == "stanford_spine_v0.0.1": tmp_mask = inference_pipeline.spine_mask > 0 - inference_pipeline.spine_mask[tmp_mask] = inference_pipeline.spine_mask[tmp_mask] + 11 + inference_pipeline.spine_mask[tmp_mask] = ( + inference_pipeline.spine_mask[tmp_mask] + 11 + ) del tmp_mask - + spine_mask_bin = inference_pipeline.spine_mask > 0 # Determine the target number of pixels @@ -220,7 +262,7 @@ def __call__(self, inference_pipeline): inference_pipeline.aorta_mask, exclude_mask=spine_mask_bin, remove_size=3, - return_dilated_mask=True, + return_dilated_mask=True, return_eroded_aorta=True, threshold=inference_pipeline.args.threshold, dilation_iteration=target_aorta_dil, @@ -240,7 +282,7 @@ def __call__(self, inference_pipeline): "calcium_segmentations.nii.gz", ), ) - + inference_pipeline.saveArrToNifti( calcification_results["dilated_mask"], os.path.join( @@ -248,7 +290,7 @@ def __call__(self, inference_pipeline): "dilated_aorta_mask.nii.gz", ), ) - + inference_pipeline.saveArrToNifti( calcification_results["aorta_eroded"], os.path.join( @@ -263,7 +305,7 @@ def __call__(self, inference_pipeline): inference_pipeline.output_dir_segmentation_masks, "spine_mask.nii.gz" ), ) - + inference_pipeline.saveArrToNifti( inference_pipeline.aorta_mask, os.path.join( @@ -647,7 +689,7 @@ def getSmallestArraySlice(self, input_mask, margin=0): ) return (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end)) - + class AorticCalciumMetrics(InferenceClass): """Calculate metrics for the aortic calcifications""" @@ -659,7 +701,7 @@ def __call__(self, inference_pipeline): calc_mask = inference_pipeline.calc_mask spine_mask = inference_pipeline.spine_mask aorta_mask = inference_pipeline.aorta_mask - + t12_level = np.where((spine_mask == 32).sum(axis=(0, 1)))[0] l1_level = np.where((spine_mask == 31).sum(axis=(0, 1)))[0] @@ -686,7 +728,7 @@ def __call__(self, inference_pipeline): ), ) inference_pipeline.t12_plane = planes - + inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[ "pixdim" ][1:4] @@ -741,10 +783,11 @@ def __call__(self, inference_pipeline): metrics["volume_total"] = calc_vol metrics["num_calc"] = num_lesions - + # percent of the aorta calcificed - metrics['perc_calcified'] = (calc_mask_region.sum() / aorta_mask_region.sum()) * 100 - + metrics["perc_calcified"] = ( + calc_mask_region.sum() / aorta_mask_region.sum() + ) * 100 if inference_pipeline.args.threshold == "agatston": if num_lesions == 0: diff --git a/comp2comp/aortic_calcium/aortic_calcium_visualization.py b/comp2comp/aortic_calcium/aortic_calcium_visualization.py index 351a8891..7a796fa1 100644 --- a/comp2comp/aortic_calcium/aortic_calcium_visualization.py +++ b/comp2comp/aortic_calcium/aortic_calcium_visualization.py @@ -2,8 +2,13 @@ import numpy as np +from comp2comp.aortic_calcium.visualization_utils import ( + createCalciumMosaic, + createMipPlot, + mergeMipAndMosaic, +) from comp2comp.inference_class_base import InferenceClass -from comp2comp.aortic_calcium.visualization_utils import createMipPlot, createCalciumMosaic, mergeMipAndMosaic + class AorticCalciumVisualizer(InferenceClass): def __init__(self): @@ -16,10 +21,10 @@ def __call__(self, inference_pipeline): if not os.path.exists(self.output_dir_images_organs): os.makedirs(self.output_dir_images_organs) - + # Create MIP part of the overview plot createMipPlot( - inference_pipeline.ct, + inference_pipeline.ct, inference_pipeline.calc_mask, inference_pipeline.aorta_mask, inference_pipeline.t12_plane == 1, @@ -28,26 +33,24 @@ def __call__(self, inference_pipeline): inference_pipeline.metrics, self.output_dir_images_organs, ) - - ab_num = inference_pipeline.metrics['Abdominal']['num_calc'] - th_num = inference_pipeline.metrics['Thoracic']['num_calc'] + + ab_num = inference_pipeline.metrics["Abdominal"]["num_calc"] + th_num = inference_pipeline.metrics["Thoracic"]["num_calc"] # Create mosaic part of the overview plot - if not (ab_num == 0 and th_num == 0): + if not (ab_num == 0 and th_num == 0): createCalciumMosaic( - inference_pipeline.ct, + inference_pipeline.ct, inference_pipeline.calc_mask, - inference_pipeline.dilated_aorta_mask, # the dilated mask is used here + inference_pipeline.dilated_aorta_mask, # the dilated mask is used here inference_pipeline.spine_mask, inference_pipeline.pix_dims, self.output_dir_images_organs, inference_pipeline.args.mosaic_type, ) - - # Merge the two images created above for the final report - mergeMipAndMosaic( - self.output_dir_images_organs - ) - + + # Merge the two images created above for the final report + mergeMipAndMosaic(self.output_dir_images_organs) + return {} @@ -145,7 +148,9 @@ def __call__(self, inference_pipeline): "{},{:.3f}\n".format("Min volume (cm³):", np.min(metrics["volume"])) ) f.write( - "{},{:.3f}\n".format("% Calcified aorta:", metrics["perc_calcified"]) + "{},{:.3f}\n".format( + "% Calcified aorta:", metrics["perc_calcified"] + ) ) if inference_pipeline.args.threshold == "agatston": @@ -228,7 +233,7 @@ def __call__(self, inference_pipeline): metrics["perc_calcified"], ) ) - + if inference_pipeline.args.threshold == "agatston": print( "{:<{}}{:.1f}".format( diff --git a/comp2comp/aortic_calcium/visualization_utils.py b/comp2comp/aortic_calcium/visualization_utils.py index b8b1f143..b19781ba 100644 --- a/comp2comp/aortic_calcium/visualization_utils.py +++ b/comp2comp/aortic_calcium/visualization_utils.py @@ -2,19 +2,18 @@ import os import shutil -import matplotlib.patches as patches import matplotlib.pyplot as plt +import numpy as np from matplotlib.colors import ListedColormap from matplotlib.transforms import Bbox -import nibabel as nib -import numpy as np from numpy.typing import NDArray from PIL import Image # color map used for segmnetations -color_array = plt.get_cmap('Set1')(range(10)) -color_array = np.concatenate( (np.array([[0,0,0,0]]),color_array[:-1,:]), axis=0) -map_object_seg = ListedColormap(name='segmentation_cmap',colors=color_array) +color_array = plt.get_cmap("Set1")(range(10)) +color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:-1, :]), axis=0) +map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array) + def createMipPlot( ct: NDArray, @@ -24,27 +23,26 @@ def createMipPlot( HU_val: int, pix_size: NDArray, metrics: dict, - save_root: str - ) -> None: - ''' + save_root: str, +) -> None: + """ Create a MIP projection in the frontal and side plane with - the calcication overlayed. The text box is generated seperately + the calcication overlayed. The text box is generated seperately and then resampled to the MIP - ''' + """ - - ''' + """ Generate MIP image - ''' + """ # Create transparent hot cmap - hot = plt.get_cmap('hot', 256) + hot = plt.get_cmap("hot", 256) hot_colors = hot(np.linspace(0, 1, 256)) - hot_colors[0, -1] = 0 + hot_colors[0, -1] = 0 hot_transparent = ListedColormap(hot_colors) - fig, axx = plt.subplots(figsize=(12,12), dpi=300) - fig.patch.set_facecolor('black') - axx.set_facecolor('black') + fig, axx = plt.subplots(figsize=(12, 12), dpi=300) + fig.patch.set_facecolor("black") + axx.set_facecolor("black") # Create the frontal projection thres = 300 @@ -55,10 +53,12 @@ def createMipPlot( calc_mask_proj = np.flip(np.transpose(calc_mask.sum(axis=1)), axis=0) # normalize both views for the heat map if not calc_mask_proj.max() == 0: - calc_mask_proj = calc_mask_proj/calc_mask_proj.max() + calc_mask_proj = calc_mask_proj / calc_mask_proj.max() - aorta_mask_proj = np.flip(np.transpose(aorta_mask.max(axis=1)), axis=0)*2 - plane_mask_proj = np.where(np.flip(np.transpose( (plane_mask == 1).max(axis=(0,1))), axis=0))[0][0] + aorta_mask_proj = np.flip(np.transpose(aorta_mask.max(axis=1)), axis=0) * 2 + plane_mask_proj = np.where( + np.flip(np.transpose((plane_mask == 1).max(axis=(0, 1))), axis=0) + )[0][0] # Create the side projection ct_proj_side = np.flip(np.transpose(ct.max(axis=0)), axis=0) @@ -67,9 +67,9 @@ def createMipPlot( calc_mask_proj_side = np.flip(np.transpose(calc_mask.sum(axis=0)), axis=0) # normalize both views for the heat map if not calc_mask_proj_side.max() == 0: - calc_mask_proj_side = calc_mask_proj_side/calc_mask_proj_side.max() - - aorta_mask_proj_side = np.flip(np.transpose(aorta_mask.max(axis=0)), axis=0)*2 + calc_mask_proj_side = calc_mask_proj_side / calc_mask_proj_side.max() + + aorta_mask_proj_side = np.flip(np.transpose(aorta_mask.max(axis=0)), axis=0) * 2 # Concatenate together ct_proj_all = np.hstack([ct_proj, ct_proj_side]) @@ -77,79 +77,171 @@ def createMipPlot( aorta_mask_proj_all = np.hstack([aorta_mask_proj, aorta_mask_proj_side]) # Plot the results - axx.imshow(ct_proj_all, cmap='gray', vmin=thres, vmax=1600, aspect=pix_size[2]/pix_size[0], alpha=1) + axx.imshow( + ct_proj_all, + cmap="gray", + vmin=thres, + vmax=1600, + aspect=pix_size[2] / pix_size[0], + alpha=1, + ) # Aorta mask and calcification - aorta_im = axx.imshow(aorta_mask_proj_all, cmap=map_object_seg, zorder=1, - vmin=0, vmax=10, aspect=pix_size[2]/pix_size[0], alpha=0.6, interpolation='nearest') - calc_im = axx.imshow(calc_mask_proj_all, cmap=hot_transparent, aspect=pix_size[2]/pix_size[0], alpha=1, - interpolation='nearest', zorder=2) + aorta_im = axx.imshow( + aorta_mask_proj_all, + cmap=map_object_seg, + zorder=1, + vmin=0, + vmax=10, + aspect=pix_size[2] / pix_size[0], + alpha=0.6, + interpolation="nearest", + ) + calc_im = axx.imshow( + calc_mask_proj_all, + cmap=hot_transparent, + aspect=pix_size[2] / pix_size[0], + alpha=1, + interpolation="nearest", + zorder=2, + ) # Ab and Th separating plane - axx.plot([0, ct_proj_all.shape[1]], [plane_mask_proj, plane_mask_proj], - color=map_object_seg(3), lw=0.8, alpha=0.8, zorder=0) - axx.text(30, plane_mask_proj - 8, 'Thoracic', color=map_object_seg(3), - va='center', ha='left', alpha=0.8, fontsize=10) - axx.text(30, plane_mask_proj + 8, 'Abdominal', color=map_object_seg(3), - va='center', ha='left', alpha=0.8, fontsize=10) + axx.plot( + [0, ct_proj_all.shape[1]], + [plane_mask_proj, plane_mask_proj], + color=map_object_seg(3), + lw=0.8, + alpha=0.8, + zorder=0, + ) + axx.text( + 30, + plane_mask_proj - 8, + "Thoracic", + color=map_object_seg(3), + va="center", + ha="left", + alpha=0.8, + fontsize=10, + ) + axx.text( + 30, + plane_mask_proj + 8, + "Abdominal", + color=map_object_seg(3), + va="center", + ha="left", + alpha=0.8, + fontsize=10, + ) axx.set_xticks([]) - axx.set_ylabel('Slice number', color='white', fontsize=10) - axx.tick_params(axis='y', colors='white', labelsize=10) + axx.set_ylabel("Slice number", color="white", fontsize=10) + axx.tick_params(axis="y", colors="white", labelsize=10) - # extend black background + # extend black background axx.set_xlim(0, ct_proj_all.shape[1]) # wrap plot in Image tight_bbox = fig.get_tightbbox(fig.canvas.get_renderer()) # Create a new bounding box with padding only on the left # The bbox coordinates are [left, bottom, right, top] - custom_bbox = Bbox([[tight_bbox.x0 - 0.05, tight_bbox.y0-0.07], # Add 0.5 inches to the left only - [tight_bbox.x1, tight_bbox.y1]]) + custom_bbox = Bbox( + [ + [ + tight_bbox.x0 - 0.05, + tight_bbox.y0 - 0.07, + ], # Add 0.5 inches to the left only + [tight_bbox.x1, tight_bbox.y1], + ] + ) buf_mip = io.BytesIO() - fig.savefig(buf_mip, bbox_inches=custom_bbox, pad_inches=0, dpi=300, format='png') + fig.savefig(buf_mip, bbox_inches=custom_bbox, pad_inches=0, dpi=300, format="png") plt.close(fig) buf_mip.seek(0) image_mip = Image.open(buf_mip) - ''' + """ Generate the text box - ''' + """ spacing = 23 indent = 1 - text_box_x_offset = 20 report_text = [] - report_text.append(r'$\bf{Calcification\ Report}$') - + report_text.append(r"$\bf{Calcification\ Report}$") + for i, (region, region_metrics) in enumerate(metrics.items()): - report_text.append(region) - report_text.append('{:<{}}{:<{}}{}'.format('',indent, 'Total number:', spacing, region_metrics['num_calc'])) - report_text.append('{:<{}}{:<{}}{:.3f}'.format('',indent,'Total volume (cm³):', spacing, region_metrics['volume_total'])) - report_text.append('{:<{}}{:<{}}{:.3f}{}{:.3f}'.format('',indent,'Mean volume (cm³):', spacing, - np.mean(region_metrics["volume"]),r'$\pm$',np.std(region_metrics["volume"]))) - report_text.append('{:<{}}{:<{}}{:.1f}{}{:.1f}'.format('',indent,'Median HU:', spacing, - np.mean(region_metrics["median_hu"]),r'$\pm$',np.std(region_metrics["median_hu"]))) - report_text.append('{:<{}}{:<{}}{:.3f}'.format('',indent,'% Volume calcified:', spacing, - np.mean(region_metrics["perc_calcified"]))) - - if 'agatston_score' in region_metrics: - report_text.append('{:<{}}{:<{}}{:.0f}'.format('',indent,'Agatston:', spacing, region_metrics['agatston_score'])) - - report_text.append('\n') - - report_text.append('{:<{}}{:<{}}{}'.format('',indent, 'Threshold (HU):', spacing, HU_val)) - - fig_t, axx_t = plt.subplots(figsize=(5.85,5.85), dpi=300) - fig_t.patch.set_facecolor('black') - axx_t.set_facecolor('black') - - axx_t.imshow(np.ones((100,65)), cmap='gray') + report_text.append(region) + report_text.append( + "{:<{}}{:<{}}{}".format( + "", indent, "Total number:", spacing, region_metrics["num_calc"] + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}".format( + "", + indent, + "Total volume (cm³):", + spacing, + region_metrics["volume_total"], + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}{}{:.3f}".format( + "", + indent, + "Mean volume (cm³):", + spacing, + np.mean(region_metrics["volume"]), + r"$\pm$", + np.std(region_metrics["volume"]), + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.1f}{}{:.1f}".format( + "", + indent, + "Median HU:", + spacing, + np.mean(region_metrics["median_hu"]), + r"$\pm$", + np.std(region_metrics["median_hu"]), + ) + ) + report_text.append( + "{:<{}}{:<{}}{:.3f}".format( + "", + indent, + "% Volume calcified:", + spacing, + np.mean(region_metrics["perc_calcified"]), + ) + ) + + if "agatston_score" in region_metrics: + report_text.append( + "{:<{}}{:<{}}{:.0f}".format( + "", indent, "Agatston:", spacing, region_metrics["agatston_score"] + ) + ) + + report_text.append("\n") + + report_text.append( + "{:<{}}{:<{}}{}".format("", indent, "Threshold (HU):", spacing, HU_val) + ) + + fig_t, axx_t = plt.subplots(figsize=(5.85, 5.85), dpi=300) + fig_t.patch.set_facecolor("black") + axx_t.set_facecolor("black") + + axx_t.imshow(np.ones((100, 65)), cmap="gray") bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5) text_obj = axx_t.text( x=0.3, y=1, - s='\n'.join(report_text), + s="\n".join(report_text), color="white", # fontsize=12.12 * font_scale, fontsize=15.7, @@ -159,17 +251,26 @@ def createMipPlot( ha="left", ) - axx_t.set_aspect(0.69) - axx_t.axis('off') + axx_t.set_aspect(0.69) + axx_t.axis("off") fig.canvas.draw_idle() plt.tight_layout() - + # wrap text box in Image tight_bbox_text = fig_t.get_tightbbox(fig_t.canvas.get_renderer()) - custom_bbox_text = Bbox([[tight_bbox_text.x0 - 0.05, tight_bbox_text.y0-0.14], # Add 0.5 inches to the left only - [tight_bbox_text.x1+0.15, tight_bbox_text.y1+0.05]]) + custom_bbox_text = Bbox( + [ + [ + tight_bbox_text.x0 - 0.05, + tight_bbox_text.y0 - 0.14, + ], # Add 0.5 inches to the left only + [tight_bbox_text.x1 + 0.15, tight_bbox_text.y1 + 0.05], + ] + ) buf_text = io.BytesIO() - fig_t.savefig(buf_text, bbox_inches=custom_bbox_text, pad_inches=0, dpi=300, format='png') + fig_t.savefig( + buf_text, bbox_inches=custom_bbox_text, pad_inches=0, dpi=300, format="png" + ) plt.close(fig_t) buf_text.seek(0) image_text = Image.open(buf_text) @@ -180,30 +281,34 @@ def createMipPlot( aspect_ratio = image_text.height / image_text.width adjusted_width = int(image_mip.height / aspect_ratio) - image_text_resample = image_text.resize([adjusted_width, image_mip.height], Image.LANCZOS) - + image_text_resample = image_text.resize( + [adjusted_width, image_mip.height], Image.LANCZOS + ) + # Merge into one image - result = Image.new("RGB", (image_mip.width + image_text_resample.width, image_mip.height)) + result = Image.new( + "RGB", (image_mip.width + image_text_resample.width, image_mip.height) + ) result.paste(im=image_mip, box=(0, 0)) result.paste(im=image_text_resample, box=(image_mip.width, 0)) # create path and save - path = os.path.join(save_root, 'sub_figures') + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) - result.save(os.path.join(path,'projection.png'), dpi=(300, 300)) - + result.save(os.path.join(path, "projection.png"), dpi=(300, 300)) + def crop_and_pad_image(image, pad_percent=0.025, pad_color=(0, 0, 0, 255)): # Ensure image has alpha channel - image = image.convert('RGBA') - + image = image.convert("RGBA") + # Create a binary mask: 255 for non-black, 0 for black - mask = image.convert('RGB').point(lambda p: 255 if p != 0 else 0).convert('L') + mask = image.convert("RGB").point(lambda p: 255 if p != 0 else 0).convert("L") bbox = mask.getbbox() if not bbox: return image # No content found; return original - + # Crop the image cropped = image.crop(bbox) @@ -212,7 +317,7 @@ def crop_and_pad_image(image, pad_percent=0.025, pad_color=(0, 0, 0, 255)): pad = int(width * pad_percent) # Add padding - padded = Image.new('RGBA', (width + pad * 2, height + pad * 2), pad_color) + padded = Image.new("RGBA", (width + pad * 2, height + pad * 2), pad_color) padded.paste(cropped, (pad, pad)) return padded @@ -225,13 +330,13 @@ def createCalciumMosaic( spine_mask: NDArray, pix_size: NDArray, save_root: str, - mosaic_type: str = 'all' - ) -> None: - ''' - Wrapper function that calls different functions for creating the mosaic + mosaic_type: str = "all", +) -> None: + """ + Wrapper function that calls different functions for creating the mosaic depending on the "mosaic_type". - ''' - if mosaic_type == 'all': + """ + if mosaic_type == "all": createCalciumMosaicAll( ct, calc_mask, @@ -239,7 +344,7 @@ def createCalciumMosaic( pix_size, save_root, ) - elif mosaic_type == 'vertebrae': + elif mosaic_type == "vertebrae": createCalciumMosaicVertebrae( ct, calc_mask, @@ -249,7 +354,8 @@ def createCalciumMosaic( save_root, ) else: - raise ValueError('mosaic_type not recognized, got: ' + str(mosaic_type)) + raise ValueError("mosaic_type not recognized, got: " + str(mosaic_type)) + def createCalciumMosaicAll( ct: NDArray, @@ -257,9 +363,9 @@ def createCalciumMosaicAll( aorta_mask_dil: NDArray, pix_size: NDArray, save_root: str, - ) -> None: - - calc_idx = np.where( calc_mask.sum(axis=(0,1)) )[0] +) -> None: + + calc_idx = np.where(calc_mask.sum(axis=(0, 1)))[0] per_row = 15 @@ -268,7 +374,7 @@ def createCalciumMosaicAll( # target size of 60 mm crop size rounded to nearest multiple of 2 crop_size = round(60 / pix_size[0]) - crop_size = 2*round(crop_size/2) + crop_size = 2 * round(crop_size / 2) ct_crops = [] mask_crops = [] @@ -279,32 +385,38 @@ def createCalciumMosaicAll( aorta_dil_tmp = [] for i, idx in enumerate(calc_idx[::-1]): - ct_slice = np.flip(np.transpose(ct[:,:,idx]), axis=(0,1)) - mask_slice = np.flip(np.transpose(calc_mask[:,:,idx]), axis=(0,1)) - aorta_slice = np.flip(np.transpose(aorta_mask_dil[:,:,idx]), axis=(0,1)) - + ct_slice = np.flip(np.transpose(ct[:, :, idx]), axis=(0, 1)) + mask_slice = np.flip(np.transpose(calc_mask[:, :, idx]), axis=(0, 1)) + aorta_slice = np.flip(np.transpose(aorta_mask_dil[:, :, idx]), axis=(0, 1)) + x_center = np.where(aorta_slice.sum(axis=1))[0] - x_center = x_center[len(x_center)//2] - + x_center = x_center[len(x_center) // 2] + y_center = np.where(aorta_slice.sum(axis=0))[0] - y_center = y_center[len(y_center)//2] - - ct_tmp.append(ct_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - mask_tmp.append(mask_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - aorta_dil_tmp.append(aorta_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - if (i+1) % per_row == 0: + y_center = y_center[len(y_center) // 2] + + ct_tmp.append( + ct_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + mask_tmp.append( + mask_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + aorta_dil_tmp.append( + aorta_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + if (i + 1) % per_row == 0: # print('got here') ct_crops.append(np.hstack(ct_tmp)) mask_crops.append(np.hstack(mask_tmp)) @@ -315,13 +427,33 @@ def createCalciumMosaicAll( aorta_dil_tmp = [] if len(ct_tmp) > 0: - pad_len = per_row*crop_size - len(ct_tmp)*crop_size - - ct_crops.append(np.pad(np.hstack(ct_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=-400)) - mask_crops.append(np.pad(np.hstack(mask_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) - aorta_dil_crops.append(np.pad(np.hstack(aorta_dil_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) + pad_len = per_row * crop_size - len(ct_tmp) * crop_size + + ct_crops.append( + np.pad( + np.hstack(ct_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=-400, + ) + ) + mask_crops.append( + np.pad( + np.hstack(mask_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) + aorta_dil_crops.append( + np.pad( + np.hstack(aorta_dil_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) - ct_crops = np.vstack(ct_crops) mask_crops = np.vstack(mask_crops) aorta_dil_crops = np.vstack(aorta_dil_crops) @@ -330,29 +462,47 @@ def createCalciumMosaicAll( combined_crops[aorta_dil_crops == 1] = 2 combined_crops[mask_crops == 1] = 1 - fig, axx = plt.subplots(1, figsize=(20,20)) + fig, axx = plt.subplots(1, figsize=(20, 20)) + + axx.imshow(ct_crops, cmap="gray", vmin=-150, vmax=250) + axx.imshow( + combined_crops, + vmin=0, + vmax=10, + cmap=map_object_seg, + interpolation="nearest", + alpha=0.4, + ) - axx.imshow(ct_crops, cmap='gray', vmin=-150, vmax=250) - axx.imshow(combined_crops, vmin=0, vmax=10, cmap=map_object_seg, interpolation='nearest', alpha=0.4) - for i, idx_ in enumerate(calc_idx[::-1]): # account for how the MIP is displayed idx = calc_mask.shape[2] - idx_ - x_ = crop_size*(i % per_row) + 4 # crop_size//2 - y_ = crop_size*(i // per_row) + 4 - axx.text(x_, y_, s=str(idx), color = 'white', fontsize=10, va='top', ha='left', - bbox=dict(facecolor='black', edgecolor='black', boxstyle='round,pad=0.1')) + x_ = crop_size * (i % per_row) + 4 # crop_size//2 + y_ = crop_size * (i // per_row) + 4 + axx.text( + x_, + y_, + s=str(idx), + color="white", + fontsize=10, + va="top", + ha="left", + bbox=dict(facecolor="black", edgecolor="black", boxstyle="round,pad=0.1"), + ) axx.set_xticks([]) axx.set_yticks([]) - - path = os.path.join(save_root, 'sub_figures') + + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) # Save with the custom bounding box - fig.savefig(os.path.join(path,'mosaic.png'),bbox_inches="tight", pad_inches=0, dpi=300) + fig.savefig( + os.path.join(path, "mosaic.png"), bbox_inches="tight", pad_inches=0, dpi=300 + ) plt.close(fig) + def createCalciumMosaicVertebrae( ct: NDArray, calc_mask: NDArray, @@ -360,8 +510,8 @@ def createCalciumMosaicVertebrae( spine_mask: NDArray, pix_size: NDArray, save_root: str, - ) -> None: - +) -> None: + vertebrae_num = { 26: "vertebrae_S1", 27: "vertebrae_L5", @@ -387,11 +537,12 @@ def createCalciumMosaicVertebrae( 47: "vertebrae_C4", 48: "vertebrae_C3", 49: "vertebrae_C2", - 50: "vertebrae_C1"} + 50: "vertebrae_C1", + } vertebrae_name = {v: k for k, v in vertebrae_num.items()} - calc_idx = np.where( calc_mask.sum(axis=(0,1)) )[0] + calc_idx = np.where(calc_mask.sum(axis=(0, 1)))[0] per_row = 5 found_vertebras = np.unique(spine_mask) @@ -400,7 +551,7 @@ def createCalciumMosaicVertebrae( # target size of 120 mm crop size rounded to nearest multiple of 2 crop_size = round(120 / pix_size[0]) - crop_size = 2*round(crop_size/2) + crop_size = 2 * round(crop_size / 2) ct_crops = [] mask_crops = [] @@ -416,42 +567,47 @@ def createCalciumMosaicVertebrae( for vert_num in found_vertebras[::-1]: # only vertebraes if vert_num < 26: - continue - + continue + # Find middle of the vertebra tmp_spine_mask = spine_mask == vert_num - tmp_spine_idx = np.where(tmp_spine_mask.sum(axis=(0,1)))[0] - idx = tmp_spine_idx[len(tmp_spine_idx)//2] + tmp_spine_idx = np.where(tmp_spine_mask.sum(axis=(0, 1)))[0] + idx = tmp_spine_idx[len(tmp_spine_idx) // 2] + + ct_slice = np.flip(np.transpose(ct[:, :, idx]), axis=(0, 1)) + mask_slice = np.flip(np.transpose(calc_mask[:, :, idx]), axis=(0, 1)) + aorta_slice = np.flip(np.transpose(aorta_mask_dil[:, :, idx]), axis=(0, 1)) - ct_slice = np.flip(np.transpose(ct[:,:,idx]), axis=(0,1)) - mask_slice = np.flip(np.transpose(calc_mask[:,:,idx]), axis=(0,1)) - aorta_slice = np.flip(np.transpose(aorta_mask_dil[:,:,idx]), axis=(0,1)) - x_center = np.where(aorta_slice.sum(axis=1))[0] # skip if no aorta is present if len(x_center) == 0: continue - x_center = x_center[len(x_center)//2] - + x_center = x_center[len(x_center) // 2] + y_center = np.where(aorta_slice.sum(axis=0))[0] - y_center = y_center[len(y_center)//2] - - ct_tmp.append(ct_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ] - ) - mask_tmp.append(mask_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - aorta_dil_tmp.append(aorta_slice[ - x_center-crop_size//2:x_center+crop_size//2, - y_center-crop_size//2:y_center+crop_size//2, - ]) - - if (i+1) % per_row == 0: + y_center = y_center[len(y_center) // 2] + + ct_tmp.append( + ct_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + mask_tmp.append( + mask_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + aorta_dil_tmp.append( + aorta_slice[ + x_center - crop_size // 2 : x_center + crop_size // 2, + y_center - crop_size // 2 : y_center + crop_size // 2, + ] + ) + + if (i + 1) % per_row == 0: # print('got here') ct_crops.append(np.hstack(ct_tmp)) mask_crops.append(np.hstack(mask_tmp)) @@ -460,18 +616,38 @@ def createCalciumMosaicVertebrae( ct_tmp = [] mask_tmp = [] aorta_dil_tmp = [] - - vertebra_names.append(vertebrae_num[vert_num].split('_')[1]) + + vertebra_names.append(vertebrae_num[vert_num].split("_")[1]) i += 1 if len(ct_tmp) > 0: - pad_len = per_row*crop_size - len(ct_tmp)*crop_size - - ct_crops.append(np.pad(np.hstack(ct_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=-400)) - mask_crops.append(np.pad(np.hstack(mask_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) - aorta_dil_crops.append(np.pad(np.hstack(aorta_dil_tmp), ((0,0), (0,pad_len)), mode='constant', constant_values=0)) + pad_len = per_row * crop_size - len(ct_tmp) * crop_size + + ct_crops.append( + np.pad( + np.hstack(ct_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=-400, + ) + ) + mask_crops.append( + np.pad( + np.hstack(mask_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) + aorta_dil_crops.append( + np.pad( + np.hstack(aorta_dil_tmp), + ((0, 0), (0, pad_len)), + mode="constant", + constant_values=0, + ) + ) - ct_crops = np.vstack(ct_crops) mask_crops = np.vstack(mask_crops) aorta_dil_crops = np.vstack(aorta_dil_crops) @@ -480,51 +656,74 @@ def createCalciumMosaicVertebrae( combined_crops[aorta_dil_crops == 1] = 2 combined_crops[mask_crops == 1] = 1 - fig, axx = plt.subplots(1, figsize=(20,20)) + fig, axx = plt.subplots(1, figsize=(20, 20)) + + axx.imshow(ct_crops, cmap="gray", vmin=-150, vmax=250) + axx.imshow( + combined_crops, + vmin=0, + vmax=10, + cmap=map_object_seg, + interpolation="nearest", + alpha=0.4, + ) - axx.imshow(ct_crops, cmap='gray', vmin=-150, vmax=250) - axx.imshow(combined_crops, vmin=0, vmax=10, cmap=map_object_seg, interpolation='nearest', alpha=0.4) - # for i, idx_ in enumerate(calc_idx[::-1]): for i, name in enumerate(vertebra_names): - x_ = crop_size*(i % per_row) + 4 # crop_size//2 - y_ = crop_size*(i // per_row) + 4 - axx.text(x_, y_, s=name, color = 'white', fontsize=18, va='top', ha='left', - bbox=dict(facecolor='black', edgecolor='black', boxstyle='round,pad=0.1')) + x_ = crop_size * (i % per_row) + 4 # crop_size//2 + y_ = crop_size * (i // per_row) + 4 + axx.text( + x_, + y_, + s=name, + color="white", + fontsize=18, + va="top", + ha="left", + bbox=dict(facecolor="black", edgecolor="black", boxstyle="round,pad=0.1"), + ) axx.set_xticks([]) axx.set_yticks([]) - - path = os.path.join(save_root, 'sub_figures') + + path = os.path.join(save_root, "sub_figures") os.makedirs(path, exist_ok=True) # Save with the custom bounding box - fig.savefig(os.path.join(path,'mosaic.png'),bbox_inches="tight", pad_inches=0, dpi=300) + fig.savefig( + os.path.join(path, "mosaic.png"), bbox_inches="tight", pad_inches=0, dpi=300 + ) plt.close(fig) - - + + def mergeMipAndMosaic(save_root: str) -> None: - ''' + """ This function loads the MIP and mosaic images and merges them into a single final image. - ''' - - if os.path.isfile(os.path.join(save_root, 'sub_figures/mosaic.png')): - img_proj = Image.open(os.path.join(save_root, 'sub_figures/projection.png')) - img_mosaic = Image.open(os.path.join(save_root, 'sub_figures/mosaic.png')) - + """ + + if os.path.isfile(os.path.join(save_root, "sub_figures/mosaic.png")): + img_proj = Image.open(os.path.join(save_root, "sub_figures/projection.png")) + img_mosaic = Image.open(os.path.join(save_root, "sub_figures/mosaic.png")) + # Match the width to the projection image aspect_ratio = img_mosaic.height / img_mosaic.width new_height = int(img_proj.width * aspect_ratio) - img_mosaic_resample = img_mosaic.resize([img_proj.width, new_height], Image.LANCZOS) + img_mosaic_resample = img_mosaic.resize( + [img_proj.width, new_height], Image.LANCZOS + ) - result = Image.new("RGB", (img_proj.width, img_proj.height + img_mosaic_resample.height)) + result = Image.new( + "RGB", (img_proj.width, img_proj.height + img_mosaic_resample.height) + ) result.paste(im=img_proj, box=(0, 0)) result.paste(im=img_mosaic_resample, box=(0, img_proj.height)) - result.save(os.path.join(save_root,'overview.png'), dpi=(300, 300)) + result.save(os.path.join(save_root, "overview.png"), dpi=(300, 300)) else: - shutil.copy2(os.path.join(save_root, 'sub_figures/projection.png'), os.path.join(save_root,'overview.png')) - \ No newline at end of file + shutil.copy2( + os.path.join(save_root, "sub_figures/projection.png"), + os.path.join(save_root, "overview.png"), + ) diff --git a/comp2comp/contrast_phase/contrast_inf.py b/comp2comp/contrast_phase/contrast_inf.py index 456d217f..a99bef9c 100644 --- a/comp2comp/contrast_phase/contrast_inf.py +++ b/comp2comp/contrast_phase/contrast_inf.py @@ -8,8 +8,10 @@ import scipy import SimpleITK as sitk from scipy import ndimage as ndi + # import xgboost + def loadNiiToArray(path): NiImg = nib.load(path) array = np.array(NiImg.dataobj) @@ -176,7 +178,7 @@ def getFeatures(TSArray, scanArray): kidneyLMask = getClassBinaryMask(TSArray, 3) kidneyRMask = getClassBinaryMask(TSArray, 2) adRMask = getClassBinaryMask(TSArray, 11) - + # aortaMask = getClassBinaryMask(TSArray, 52) # IVCMask = getClassBinaryMask(TSArray, 63) # portalMask = getClassBinaryMask(TSArray, 64) @@ -421,16 +423,11 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False): model = loadModel() # TS_array, image_array = loadNiftis(TS_output_nifti_path, image_nifti_path) featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array) - + y_pred_proba = model.predict_proba([featureArray])[0] y_pred = np.argmax(y_pred_proba) - - phase_dict = { - 0: "non-contrast", - 1: "arterial", - 2: "venous", - 3: "delayed" - } + + phase_dict = {0: "non-contrast", 1: "arterial", 2: "venous", 3: "delayed"} pred_phase = phase_dict[y_pred] @@ -447,15 +444,15 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False): if not os.path.exists(output_path_metrics): os.makedirs(output_path_metrics) outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt") - + with open(outputTxt, "w") as text_file: - text_file.write('phase,'+pred_phase + '\n') + text_file.write("phase," + pred_phase + "\n") for i in range(len(y_pred_proba)): - text_file.write('{},{:.3f}\n'.format(phase_dict[i], y_pred_proba[i])) + text_file.write("{},{:.3f}\n".format(phase_dict[i], y_pred_proba[i])) - print('Predicted phase: ' + pred_phase) + print("Predicted phase: " + pred_phase) for i in range(len(y_pred_proba)): - print('{},{:.3f}'.format(phase_dict[i], y_pred_proba[i])) + print("{},{:.3f}".format(phase_dict[i], y_pred_proba[i])) output_path_images = os.path.join(outputPath, "images") if not os.path.exists(output_path_images): diff --git a/comp2comp/contrast_phase/contrast_phase.py b/comp2comp/contrast_phase/contrast_phase.py index 74dab756..ca71a047 100644 --- a/comp2comp/contrast_phase/contrast_phase.py +++ b/comp2comp/contrast_phase/contrast_phase.py @@ -8,11 +8,12 @@ nostdout, setup_nnunet, ) -# from totalsegmentatorv2.python_api import totalsegmentator from comp2comp.contrast_phase.contrast_inf import predict_phase from comp2comp.inference_class_base import InferenceClass +# from totalsegmentatorv2.python_api import totalsegmentator + class ContrastPhaseDetection(InferenceClass): """Contrast Phase Detection.""" @@ -94,7 +95,7 @@ def run_segmentation( verbose=False, test=0, ) - + # seg = totalsegmentator( # input = os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), # output = os.path.join(self.output_dir_segmentations, "segmentation.nii"), diff --git a/comp2comp/inference_pipeline.py b/comp2comp/inference_pipeline.py index 0aca5daf..63d4163c 100644 --- a/comp2comp/inference_pipeline.py +++ b/comp2comp/inference_pipeline.py @@ -27,7 +27,7 @@ def __init__( self.inference_classes = inference_classes - def __call__(self, inference_pipeline=None, **kwargs): + def __call__(self, inference_pipeline=None, input_path: str = None, **kwargs): # print out the class names for each inference class print("") print("Inference pipeline:") @@ -37,12 +37,16 @@ def __call__(self, inference_pipeline=None, **kwargs): print("Starting inference pipeline for:\n") - if inference_pipeline: - for key, value in kwargs.items(): - setattr(inference_pipeline, key, value) - else: - for key, value in kwargs.items(): - setattr(self, key, value) + # pick whether we’re writing onto self or the wrapped pipeline + target = inference_pipeline or self + + # 1) stash input_path if provided + if input_path is not None: + target.input_path = input_path + + # 2) stash everything else (output_dir, model_dir, etc.) + for key, value in kwargs.items(): + setattr(target, key, value) output = {} for inference_class in self.inference_classes: diff --git a/comp2comp/io/io.py b/comp2comp/io/io.py index 9870e41d..843a1640 100644 --- a/comp2comp/io/io.py +++ b/comp2comp/io/io.py @@ -95,8 +95,10 @@ def __call__(self, inference_pipeline): # if self.input_path is a folder if self.input_path.is_dir(): # store a dcm object for retrieving dicom tags - dcm_files = [d for d in os.listdir(self.input_path) if d.endswith('.dcm')] - inference_pipeline.dcm = pydicom.read_file(os.path.join(self.input_path, dcm_files[0])) + dcm_files = [d for d in os.listdir(self.input_path) if d.endswith(".dcm")] + inference_pipeline.dcm = pydicom.read_file( + os.path.join(self.input_path, dcm_files[0]) + ) ds = dicom_series_to_nifti( self.input_path, diff --git a/comp2comp/io/io_utils.py b/comp2comp/io/io_utils.py index d062a598..4d154bd8 100644 --- a/comp2comp/io/io_utils.py +++ b/comp2comp/io/io_utils.py @@ -58,11 +58,17 @@ def get_dicom_or_nifti_paths_and_num(path): with open(path, "r") as f: for dicom_folder_path in f: dicom_folder_path = dicom_folder_path.strip() - if dicom_folder_path.endswith(".nii") or dicom_folder_path.endswith(".nii.gz"): - dicom_nifti_paths.append( (dicom_folder_path, getNumSlicesNifti(dicom_folder_path))) + if dicom_folder_path.endswith(".nii") or dicom_folder_path.endswith( + ".nii.gz" + ): + dicom_nifti_paths.append( + (dicom_folder_path, getNumSlicesNifti(dicom_folder_path)) + ) else: - dicom_nifti_paths.append( (dicom_folder_path, len(os.listdir(dicom_folder_path)))) - else: + dicom_nifti_paths.append( + (dicom_folder_path, len(os.listdir(dicom_folder_path))) + ) + else: for root, dirs, files in os.walk(path): if len(files) > 0: # if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files): diff --git a/comp2comp/spine/spine.py b/comp2comp/spine/spine.py index d46667df..756291f5 100644 --- a/comp2comp/spine/spine.py +++ b/comp2comp/spine/spine.py @@ -15,6 +15,11 @@ import pandas as pd import wget from PIL import Image +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) from totalsegmentatorv2.python_api import totalsegmentator from comp2comp.inference_class_base import InferenceClass @@ -22,12 +27,6 @@ from comp2comp.spine import spine_utils from comp2comp.visualization.dicom import to_dicom -from totalsegmentator.libs import ( - download_pretrained_weights, - nostdout, - setup_nnunet, -) - class SpineSegmentation(InferenceClass): """Spine segmentation.""" @@ -45,7 +44,7 @@ def __call__(self, inference_pipeline): os.makedirs(self.output_dir_segmentations) self.model_dir = inference_pipeline.model_dir - + inference_pipeline.spine_model_name = self.model_name seg, mv = self.spine_seg( @@ -53,7 +52,7 @@ def __call__(self, inference_pipeline): self.output_dir_segmentations + "spine.nii.gz", inference_pipeline.model_dir, ) - + os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir mv = nib.load( @@ -72,7 +71,7 @@ def __call__(self, inference_pipeline): inference_pipeline.segmentation = seg inference_pipeline.medical_volume = mv inference_pipeline.save_segmentations = self.save_segmentations - + return {} def setup_nnunet_c2c(self, model_dir: Union[str, Path]): @@ -137,9 +136,11 @@ def spine_seg( os.environ["SCRATCH"] = self.model_dir os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir - if self.model_name == 'ts_spine': + if self.model_name == "ts_spine": seg = totalsegmentator( - input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + input=os.path.join( + self.output_dir_segmentations, "converted_dcm.nii.gz" + ), output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), task_ids=[292], ml=True, @@ -165,11 +166,11 @@ def spine_seg( statistics_exclude_masks_at_border=True, no_derived_masks=False, v1_order=False, - ) - + ) + img = None - elif self.model_name == 'stanford_spine_v0.0.1': + elif self.model_name == "stanford_spine_v0.0.1": # Setup nnunet model = "3d_fullres" folds = [0] @@ -285,20 +286,24 @@ def __call__(self, inference_pipeline): inference_pipeline.medical_volume = medical_volume if self.save: - nib.save( - segmentation, - os.path.join( - inference_pipeline.output_dir, "segmentations", "spine.nii.gz" - ), + # Save spine segmentation + seg_path = os.path.join( + inference_pipeline.output_dir, "segmentations", "spine.nii.gz" ) - nib.save( - medical_volume, - os.path.join( - inference_pipeline.output_dir, - "segmentations", - "converted_dcm.nii.gz", - ), + # remove stale file & ensure directory exists + if os.path.exists(seg_path): + os.remove(seg_path) + os.makedirs(os.path.dirname(seg_path), exist_ok=True) + nib.save(segmentation, seg_path) + + # Save converted DICOM volume + converted_path = os.path.join( + inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz" ) + if os.path.exists(converted_path): + os.remove(converted_path) + os.makedirs(os.path.dirname(converted_path), exist_ok=True) + nib.save(medical_volume, converted_path) return {} diff --git a/comp2comp/utils/process.py b/comp2comp/utils/process.py index ab997c85..9a7a8d29 100644 --- a/comp2comp/utils/process.py +++ b/comp2comp/utils/process.py @@ -128,7 +128,7 @@ def process_3d(args, pipeline_builder): pipeline = pipeline_builder(path, args) - pipeline(output_dir=output_dir, model_dir=model_dir) + pipeline(input_path=path, output_dir=output_dir, model_dir=model_dir) if not args.save_segmentations: # remove the segmentations folder