6161from CorpusCallosum .utils .types import CCMeasuresDict , SliceSelection , SubdivisionMethod
6262from FastSurferCNN .data_loader .conform import conform , is_conform
6363from FastSurferCNN .segstats import HelpFormatter
64- from FastSurferCNN .utils import AffineMatrix4x4 , Image3d , Mask3d , Shape3d , Vector2d , logging , nibabelImage
64+ from FastSurferCNN .utils import AffineMatrix4x4 , Image3d , Mask3d , Shape3d , Vector2d , logging , nibabelImage , Image4d
6565from FastSurferCNN .utils .arg_types import path_or_none
6666from FastSurferCNN .utils .common import SubjectDirectory , find_device
6767from FastSurferCNN .utils .lta import write_lta
68- from FastSurferCNN .utils .parallel import shutdown_executors , thread_executor
68+ from FastSurferCNN .utils .parallel import shutdown_executors , thread_executor , get_num_threads , serial_executor
6969from FastSurferCNN .utils .parser_defaults import modify_argument
7070from recon_surf .align_points import find_rigid
7171
@@ -163,7 +163,7 @@ def _set_help_sid(action):
163163 "cost of precision." ,
164164 )
165165 def _slice_selection (a : str ) -> SliceSelection :
166- if b := a .lower () in ("middle" , "all" ):
166+ if ( b := a .lower () ) in ("middle" , "all" ):
167167 return b
168168 return int (a )
169169 parser .add_argument (
@@ -494,7 +494,7 @@ def segment_cc(
494494 pc_coords : Vector2d ,
495495 aseg_nib : nibabelImage ,
496496 model_segmentation : "torch.nn.Module" ,
497- ) -> tuple [Mask3d , Image3d ]:
497+ ) -> tuple [Mask3d , Image4d ]:
498498 """Segment the corpus callosum using a trained model.
499499
500500 Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical
@@ -518,7 +518,7 @@ def segment_cc(
518518 cc_seg_labels : np.ndarray
519519 Binary cc_seg_labels of the corpus callosum.
520520 cc_softlabels : np.ndarray
521- Soft cc_seg_labels probabilities.
521+ Soft cc_seg_labels probabilities of shape (H, W, D, C=3) .
522522 """
523523 pre_clean_segmentation , inputs , cc_softlabels = segmentation_inference .run_inference_on_slice (
524524 model_segmentation ,
@@ -732,7 +732,7 @@ def main(
732732 sys .exit (1 )
733733
734734 logger .info ("Performing centroid registration to fsaverage space" )
735- orig2fsavg_vox2vox , orig2fsavg_ras2ras , fsavg_vox2ras , fsavg_header , fsavg_vox2ras_tkr = (
735+ orig2fsavg_vox2vox , orig2fsavg_ras2ras , fsavg_vox2ras , fsavg_header , _ = (
736736 register_centroids_to_fsavg (aseg_img )
737737 )
738738
@@ -754,11 +754,11 @@ def main(
754754 affine_x_offset = partial (create_sag_slice_vox2vox , fsaverage_middle = FSAVERAGE_MIDDLE / vox_size [0 ])
755755 fsavg2midslab_in_vox2vox : AffineMatrix4x4 = affine_x_offset (slices_to_analyze // 2 )
756756 # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space
757- midslab_vox2ras : AffineMatrix4x4 = fsavg_vox2ras @ np .linalg .inv (fsavg2midslab_in_vox2vox )
757+ fsaverage_midslab_vox2ras : AffineMatrix4x4 = fsavg_vox2ras @ np .linalg .inv (fsavg2midslab_in_vox2vox )
758758
759759 # calculate vox2vox for input resampling volumes
760- def _orig2midslab_vox2vox (extra_slices : int ) -> AffineMatrix4x4 :
761- fsavg2midslab = affine_x_offset (slices_to_analyze // 2 + extra_slices // 2 )
760+ def _orig2midslab_vox2vox (additional_context : int = 0 ) -> AffineMatrix4x4 :
761+ fsavg2midslab = affine_x_offset (slices_to_analyze // 2 + additional_context // 2 )
762762 # first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox
763763 return fsavg2midslab @ orig2fsavg_vox2vox
764764
@@ -769,16 +769,16 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
769769 ac_coords , pc_coords = localize_ac_pc (
770770 np .asarray (orig .dataobj ),
771771 aseg_img ,
772- _orig2midslab_vox2vox (extra_slices = 2 ),
772+ _orig2midslab_vox2vox (additional_context = 2 ),
773773 _model_localization .result (),
774774 target_shape ,
775775 )
776776 logger .info ("Starting corpus callosum segmentation" )
777- # "+ 8" in x-direction for context slices
778- target_shape : Shape3d = (slices_to_analyze + 8 , fsavg_header ["dims" ][1 ], fsavg_header ["dims" ][2 ])
777+ extra_slices = 8 # 8 extra in x-direction for context slices
778+ target_shape : Shape3d = (slices_to_analyze + extra_slices , fsavg_header ["dims" ][1 ], fsavg_header ["dims" ][2 ])
779779 midslices : Image3d = affine_transform (
780780 np .asarray (orig .dataobj ),
781- np .linalg .inv (_orig2midslab_vox2vox (extra_slices = 8 )), # inverse is required for affine_transform
781+ np .linalg .inv (_orig2midslab_vox2vox (additional_context = extra_slices )), # inverse is required for affine_transform
782782 output_shape = target_shape ,
783783 order = 2 , # @ClePol unclear, why this is not order=3
784784 mode = "constant" ,
@@ -799,7 +799,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
799799 logger .info (f"Saving { name } softlabels to { sd .filename_by_attribute (f'cc_softlabels_{ attr } ' )} " )
800800 io_futures .append (thread_executor ().submit (
801801 nib .save ,
802- nib .MGHImage (cc_fn_softlabels [..., i ], midslab_vox2ras , orig .header ),
802+ nib .MGHImage (cc_fn_softlabels [..., i ], fsaverage_midslab_vox2ras , orig .header ),
803803 sd .filename_by_attribute (f"cc_softlabels_{ attr } " ),
804804 ))
805805
@@ -819,7 +819,6 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
819819 subdivision_method = subdivision_method ,
820820 contour_smoothing = contour_smoothing ,
821821 vox_size = vox_size ,
822- vox2ras_tkr = fsavg_vox2ras_tkr ,
823822 subject_dir = sd ,
824823 )
825824 io_futures .extend (slice_io_futures )
@@ -837,7 +836,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
837836 cc_subseg_midslice = make_subdivision_mask (
838837 (cc_fn_seg_labels .shape [1 ], cc_fn_seg_labels .shape [2 ]),
839838 middle_slice_result ["split_contours" ],
840- vox_size [1 :],
839+ vox_size [1 :3 ],
841840 )
842841 else :
843842 logger .warning ("Too many subsegments for lookup table, skipping sub-division of output segmentation." )
@@ -847,19 +846,21 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4:
847846 if sd .has_attribute ("cc_segmentation" ):
848847 io_futures .append (thread_executor ().submit (
849848 nib .save ,
850- nib .MGHImage (cc_fn_seg_labels , midslab_vox2ras , orig .header ),
849+ nib .MGHImage (cc_fn_seg_labels , fsaverage_midslab_vox2ras , orig .header ),
851850 sd .filename_by_attribute ("cc_segmentation" ),
852851 ))
853852 # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels)
854853 if sd .has_attribute ("cc_orig_segfile" ):
855- io_futures .append (thread_executor ().submit (
854+ # if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit
855+ executor = thread_executor () if get_num_threads () > 2 else serial_executor ()
856+ io_futures .append (executor .submit (
856857 map_softlabels_to_orig ,
857858 cc_fn_softlabels = cc_fn_softlabels ,
858- orig_fsaverage_vox2vox = orig2fsavg_vox2vox ,
859859 orig = orig ,
860860 orig_space_segmentation_path = sd .filename_by_attribute ("cc_orig_segfile" ),
861- fsaverage_middle = FSAVERAGE_MIDDLE ,
861+ orig2slab_vox2vox = _orig2midslab_vox2vox () ,
862862 cc_subseg_midslice = cc_subseg_midslice ,
863+ orig2midslice_vox2vox = affine_x_offset (0 ) @ orig2fsavg_vox2vox , # orig2fsavg, then full2midslice
863864 ))
864865
865866 METRICS = [
@@ -1004,6 +1005,7 @@ def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object])
10041005 conf_name = options .conf_name ,
10051006 aseg_name = options .aseg_name ,
10061007 subject_dir = options .subject_dir ,
1008+ #FIXME: slice_selection is True/bool
10071009 slice_selection = options .slice_selection ,
10081010 num_thickness_points = options .num_thickness_points ,
10091011 subdivisions = list (options .subdivisions ),
0 commit comments