Skip to content

Commit 8535c64

Browse files
committed
Commit Message to fix
1 parent 36ab5a3 commit 8535c64

File tree

19 files changed

+392
-385
lines changed

19 files changed

+392
-385
lines changed

CorpusCallosum/cc_visualization.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,19 @@ def make_parser() -> argparse.ArgumentParser:
9393
default=0,
9494
help="Enable verbose (pass twice for debug-output).",
9595
)
96-
9796
return parser
9897

98+
9999
def options_parse() -> argparse.Namespace:
100100
"""Parse command line arguments for the pipeline."""
101101
parser = make_parser()
102102
args = parser.parse_args()
103103

104104
# Create output directory if it doesn't exist
105105
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
106-
107106
return args
108107

109108

110-
111-
112109
def load_contours_from_template_dir(
113110
template_dir: Path, resolution: float, smoothing_window: int
114111
) -> list[CCContour]:
@@ -122,7 +119,6 @@ def load_contours_from_template_dir(
122119
)
123120

124121
fsaverage_contour = None
125-
126122
contours: list[CCContour] = []
127123
for thickness_file in thickness_files:
128124
try:
@@ -176,14 +172,11 @@ def main(
176172
_, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH)
177173

178174
contours = load_contours_from_template_dir(
179-
Path(template_dir), resolution=resolution, smoothing_window=smoothing_window
175+
Path(template_dir), resolution=resolution, smoothing_window=smoothing_window,
180176
)
181177

182178
# 2D visualization
183179
mid_contour = contours[len(contours) // 2]
184-
185-
186-
187180

188181
# for now, we only support thickness visualization, this is preparing to plot also p-values and icc values
189182
mode = "thickness"
@@ -215,8 +208,7 @@ def main(
215208
cc_mesh.plot_mesh(**plot_kwargs)
216209
cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs)
217210

218-
219-
cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr)
211+
cc_mesh = cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr)
220212
logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}")
221213
cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk"))
222214
logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}")

CorpusCallosum/fastsurfer_cc.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@
6161
from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod
6262
from FastSurferCNN.data_loader.conform import conform, is_conform
6363
from FastSurferCNN.segstats import HelpFormatter
64-
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage
64+
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage, Image4d
6565
from FastSurferCNN.utils.arg_types import path_or_none
6666
from FastSurferCNN.utils.common import SubjectDirectory, find_device
6767
from FastSurferCNN.utils.lta import write_lta
68-
from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor
68+
from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor, get_num_threads, serial_executor
6969
from FastSurferCNN.utils.parser_defaults import modify_argument
7070
from recon_surf.align_points import find_rigid
7171

@@ -163,7 +163,7 @@ def _set_help_sid(action):
163163
"cost of precision.",
164164
)
165165
def _slice_selection(a: str) -> SliceSelection:
166-
if b := a.lower() in ("middle", "all"):
166+
if (b := a.lower()) in ("middle", "all"):
167167
return b
168168
return int(a)
169169
parser.add_argument(
@@ -494,7 +494,7 @@ def segment_cc(
494494
pc_coords: Vector2d,
495495
aseg_nib: nibabelImage,
496496
model_segmentation: "torch.nn.Module",
497-
) -> tuple[Mask3d, Image3d]:
497+
) -> tuple[Mask3d, Image4d]:
498498
"""Segment the corpus callosum using a trained model.
499499
500500
Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical
@@ -518,7 +518,7 @@ def segment_cc(
518518
cc_seg_labels : np.ndarray
519519
Binary cc_seg_labels of the corpus callosum.
520520
cc_softlabels : np.ndarray
521-
Soft cc_seg_labels probabilities.
521+
Soft cc_seg_labels probabilities of shape (H, W, D, C=3).
522522
"""
523523
pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice(
524524
model_segmentation,
@@ -732,7 +732,7 @@ def main(
732732
sys.exit(1)
733733

734734
logger.info("Performing centroid registration to fsaverage space")
735-
orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header, fsavg_vox2ras_tkr = (
735+
orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header, _ = (
736736
register_centroids_to_fsavg(aseg_img)
737737
)
738738

@@ -754,11 +754,11 @@ def main(
754754
affine_x_offset = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE / vox_size[0])
755755
fsavg2midslab_in_vox2vox: AffineMatrix4x4 = affine_x_offset(slices_to_analyze // 2)
756756
# first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space
757-
midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox)
757+
fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox)
758758

759759
# calculate vox2vox for input resampling volumes
760-
def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
761-
fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + extra_slices // 2)
760+
def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
761+
fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + additional_context // 2)
762762
# first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox
763763
return fsavg2midslab @ orig2fsavg_vox2vox
764764

@@ -769,16 +769,16 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
769769
ac_coords, pc_coords = localize_ac_pc(
770770
np.asarray(orig.dataobj),
771771
aseg_img,
772-
_orig2midslab_vox2vox(extra_slices=2),
772+
_orig2midslab_vox2vox(additional_context=2),
773773
_model_localization.result(),
774774
target_shape,
775775
)
776776
logger.info("Starting corpus callosum segmentation")
777-
# "+ 8" in x-direction for context slices
778-
target_shape: Shape3d = (slices_to_analyze + 8, fsavg_header["dims"][1], fsavg_header["dims"][2])
777+
extra_slices = 8 # 8 extra in x-direction for context slices
778+
target_shape: Shape3d = (slices_to_analyze + extra_slices, fsavg_header["dims"][1], fsavg_header["dims"][2])
779779
midslices: Image3d = affine_transform(
780780
np.asarray(orig.dataobj),
781-
np.linalg.inv(_orig2midslab_vox2vox(extra_slices=8)), # inverse is required for affine_transform
781+
np.linalg.inv(_orig2midslab_vox2vox(additional_context=extra_slices)), # inverse is required for affine_transform
782782
output_shape=target_shape,
783783
order=2, # @ClePol unclear, why this is not order=3
784784
mode="constant",
@@ -799,7 +799,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
799799
logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}")
800800
io_futures.append(thread_executor().submit(
801801
nib.save,
802-
nib.MGHImage(cc_fn_softlabels[..., i], midslab_vox2ras, orig.header),
802+
nib.MGHImage(cc_fn_softlabels[..., i], fsaverage_midslab_vox2ras, orig.header),
803803
sd.filename_by_attribute(f"cc_softlabels_{attr}"),
804804
))
805805

@@ -819,7 +819,6 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
819819
subdivision_method=subdivision_method,
820820
contour_smoothing=contour_smoothing,
821821
vox_size=vox_size,
822-
vox2ras_tkr=fsavg_vox2ras_tkr,
823822
subject_dir=sd,
824823
)
825824
io_futures.extend(slice_io_futures)
@@ -837,7 +836,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
837836
cc_subseg_midslice = make_subdivision_mask(
838837
(cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]),
839838
middle_slice_result["split_contours"],
840-
vox_size[1:],
839+
vox_size[1:3],
841840
)
842841
else:
843842
logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.")
@@ -847,19 +846,21 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
847846
if sd.has_attribute("cc_segmentation"):
848847
io_futures.append(thread_executor().submit(
849848
nib.save,
850-
nib.MGHImage(cc_fn_seg_labels, midslab_vox2ras, orig.header),
849+
nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header),
851850
sd.filename_by_attribute("cc_segmentation"),
852851
))
853852
# map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels)
854853
if sd.has_attribute("cc_orig_segfile"):
855-
io_futures.append(thread_executor().submit(
854+
# if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit
855+
executor = thread_executor() if get_num_threads() > 2 else serial_executor()
856+
io_futures.append(executor.submit(
856857
map_softlabels_to_orig,
857858
cc_fn_softlabels=cc_fn_softlabels,
858-
orig_fsaverage_vox2vox=orig2fsavg_vox2vox,
859859
orig=orig,
860860
orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"),
861-
fsaverage_middle=FSAVERAGE_MIDDLE,
861+
orig2slab_vox2vox=_orig2midslab_vox2vox(),
862862
cc_subseg_midslice=cc_subseg_midslice,
863+
orig2midslice_vox2vox=affine_x_offset(0) @ orig2fsavg_vox2vox, # orig2fsavg, then full2midslice
863864
))
864865

865866
METRICS = [
@@ -1004,6 +1005,7 @@ def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object])
10041005
conf_name=options.conf_name,
10051006
aseg_name=options.aseg_name,
10061007
subject_dir=options.subject_dir,
1008+
#FIXME: slice_selection is True/bool
10071009
slice_selection=options.slice_selection,
10081010
num_thickness_points=options.num_thickness_points,
10091011
subdivisions=list(options.subdivisions),

CorpusCallosum/localization/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def predict(
188188
outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device)
189189

190190
t_crops = [(t_dict['crop_left'] + t_dict['crop_top']) * 2]
191-
outs: np.ndarray[tuple[int, Literal[4]], np.dtype[float]] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float)
191+
outs: np.ndarray[tuple[int, Literal[4]], np.dtype[np.float_]] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float)
192192
crop_offsets: tuple[int, int] = (t_dict["crop_left"][0], t_dict["crop_top"][0])
193193
return outs[:, :2], outs[:, 2:], crop_offsets
194194

CorpusCallosum/segmentation/inference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def run_inference(
9292
voxel_size: tuple[float, float],
9393
device: torch.device | None = None,
9494
transform: transforms.Transform | None = None
95-
) -> tuple[np.ndarray[Shape4d, np.dtype[int]], Image4d, Image4d]:
95+
) -> tuple[np.ndarray[Shape4d, np.dtype[np.int_]], Image4d, Image4d]:
9696
"""Run inference on a single image slice.
9797
9898
Parameters
@@ -229,15 +229,15 @@ def _load(label_path: str | Path) -> int:
229229
return images, ac_centers, pc_centers, label_widths, labels, subj_ids
230230

231231
@overload
232-
def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) -> np.ndarray[Shape3d, np.dtype[int]]: ...
232+
def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ...
233233

234234
@overload
235-
def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) -> np.ndarray[Shape2d, np.dtype[int]]: ...
235+
def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) -> np.ndarray[Shape2d, np.dtype[np.int_]]: ...
236236

237237
def one_hot_to_label(
238-
one_hot: np.ndarray[tuple[int, ...], np.dtype[bool]],
238+
one_hot: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
239239
label_ids: list[int] | None = None,
240-
) -> np.ndarray[tuple[int, ...], np.dtype[int]]:
240+
) -> np.ndarray[tuple[int, ...], np.dtype[np.int_]]:
241241
"""Convert one-hot encoded segmentation to label map.
242242
243243
Converts a one-hot encoded segmentation array to discrete labels by taking
@@ -273,7 +273,7 @@ def run_inference_on_slice(
273273
ac_center: Vector2d,
274274
pc_center: Vector2d,
275275
voxel_size: tuple[float, float],
276-
) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Image4d, Image4d]:
276+
) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Image4d, Image4d]:
277277
"""Run inference on a single slice.
278278
279279
Parameters

CorpusCallosum/segmentation/segmentation_postprocessing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,9 +500,9 @@ def extract_largest_connected_component(
500500

501501

502502
def clean_cc_segmentation(
503-
seg_arr: np.ndarray[Shape3d, np.dtype[int]],
503+
seg_arr: np.ndarray[Shape3d, np.dtype[np.int_]],
504504
max_connection_distance: float = 3.0,
505-
) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Mask3d]:
505+
) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Mask3d]:
506506
"""Clean corpus callosum segmentation by removing non-connected components.
507507
508508
Parameters

CorpusCallosum/shape/contour.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class CCContour:
7373

7474
def __init__(
7575
self,
76-
contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[float]],
77-
thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[float]],
76+
contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[np.float_]],
77+
thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[np.float_]],
7878
endpoint_idxs: tuple[int, int] | None = None,
7979
resolution: float = 1.0
8080
):
@@ -86,8 +86,10 @@ def __init__(
8686
Array of shape (N, 2) containing 2D contour points.
8787
thickness_values : np.ndarray
8888
Array of thickness measurements for each contour point.
89-
endpoint_idxs : tuple[int, int]
89+
endpoint_idxs : tuple[int, int], optional
9090
Tuple containing start and end indices for the contour.
91+
resolution : float, default=1.0
92+
The left-right spacing.
9193
"""
9294
self.contour = contour
9395
if self.contour.shape[1] != 2:

CorpusCallosum/shape/endpoint_heuristic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import skimage.measure
2020
from scipy.ndimage import label
2121

22-
from FastSurferCNN.utils import Vector2d
22+
from FastSurferCNN.utils import Vector2d, Mask2d
2323

2424

2525
def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]:
@@ -169,7 +169,7 @@ def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.nd
169169

170170
@overload
171171
def get_endpoints(
172-
cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
172+
cc_mask: Mask2d,
173173
ac_2d: Vector2d,
174174
pc_2d: Vector2d,
175175
resolution: tuple[float, float],
@@ -180,7 +180,7 @@ def get_endpoints(
180180

181181
@overload
182182
def get_endpoints(
183-
cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
183+
cc_mask: Mask2d,
184184
ac_2d: Vector2d,
185185
pc_2d: Vector2d,
186186
resolution: tuple[float, float],
@@ -190,7 +190,7 @@ def get_endpoints(
190190

191191

192192
def get_endpoints(
193-
cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
193+
cc_mask: Mask2d,
194194
ac_2d: Vector2d,
195195
pc_2d: Vector2d,
196196
resolution: tuple[float, float],

0 commit comments

Comments
 (0)