Skip to content

Commit 9232375

Browse files
committed
Fixing problems introduced by incomplete changes in review
resolving several issues from the review, like using concurrent.futures.* Cleanup, optimizations and formatting (e.g. variable names)
1 parent 22906d0 commit 9232375

21 files changed

+1008
-1123
lines changed

CorpusCallosum/cc_visualization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH
99
from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template
1010
from CorpusCallosum.data.read_write import load_fsaverage_data
11-
from CorpusCallosum.shape.cc_mesh import CC_Mesh
11+
from CorpusCallosum.shape.cc_mesh import CCMesh
1212

1313

1414
def make_parser() -> argparse.ArgumentParser:
@@ -110,7 +110,7 @@ def main(
110110
output_dir = Path(output_dir)
111111

112112
# Load data and create mesh
113-
cc_mesh = CC_Mesh(num_slices=1) # Will be resized when loading data
113+
cc_mesh = CCMesh(num_slices=1) # Will be resized when loading data
114114

115115
_, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH)
116116

CorpusCallosum/data/constants.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919

2020
### Constants
2121
WEIGHTS_PATH = FASTSURFER_ROOT / "checkpoints"
22-
FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_centroids.json"
23-
FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_data.json" # Contains both affine and header
22+
FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_centroids.json"
23+
# Contains both affine and header
24+
FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_data.json"
2425
FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space
2526
CC_LABEL = 192 # Label value for corpus callosum in segmentation
2627
FORNIX_LABEL = 250 # Label value for fornix in segmentation
2728
SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation
2829

2930

3031
STANDARD_INPUT_PATHS = {
31-
"t1": "mri/orig.mgz",
32+
"conf_name": "mri/orig.mgz",
3233
"aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz",
3334
}
3435

CorpusCallosum/data/fsaverage_cc_template.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from scipy import ndimage
2121

2222
from CorpusCallosum.data import constants
23-
from CorpusCallosum.shape.cc_postprocessing import process_slice
23+
from CorpusCallosum.shape.cc_postprocessing import recon_cc_surf_measure
2424
from FastSurferCNN.utils.brainvolstats import mask_in_array
2525

2626

@@ -121,16 +121,16 @@ def load_fsaverage_cc_template() -> tuple[
121121
cc_mask = cc_mask_smoothed.astype(int) * 192
122122

123123
(_, contour_with_thickness, anterior_endpoint_idx,
124-
posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None],
125-
slice_idx=0,
126-
ac_coords=AC,
127-
pc_coords=PC,
128-
affine=fsaverage_seg.affine,
129-
num_thickness_points=100,
130-
subdivisions=[1/6, 1/2, 2/3, 3/4],
131-
subdivision_method="shape",
132-
contour_smoothing=5,
133-
vox_size=1)
124+
posterior_endpoint_idx) = recon_cc_surf_measure(segmentation=cc_mask[None],
125+
slice_idx=0,
126+
ac_coords=AC,
127+
pc_coords=PC,
128+
affine=fsaverage_seg.affine,
129+
num_thickness_points=100,
130+
subdivisions=[1/6, 1/2, 2/3, 3/4],
131+
subdivision_method="shape",
132+
contour_smoothing=5,
133+
vox_size=1)
134134
outside_contour = contour_with_thickness[0].T
135135

136136

CorpusCallosum/data/read_write.py

Lines changed: 36 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -15,73 +15,42 @@
1515
import json
1616
import multiprocessing
1717
from pathlib import Path
18-
from typing import overload
18+
from typing import TypedDict
1919

2020
import nibabel as nib
21+
from numpy import typing as npt
2122
import numpy as np
2223

2324
import FastSurferCNN.utils.logging as logging
2425

25-
logger = logging.get_logger(__name__)
26-
27-
28-
def run_in_background(function: callable, debug: bool = False, *args, **kwargs) -> multiprocessing.Process | None:
29-
"""Run a function in the background using multiprocessing.
30-
31-
Parameters
32-
----------
33-
function : callable
34-
The function to execute.
35-
debug : bool, optional
36-
If True, run synchronously in current process, by default False.
37-
*args
38-
Positional arguments to pass to the function.
39-
**kwargs
40-
Keyword arguments to pass to the function.
41-
42-
Returns
43-
-------
44-
multiprocessing.Process or None
45-
Process object if running in background, None if in debug mode.
46-
"""
47-
if debug:
48-
function(*args, **kwargs)
49-
process = None
50-
else:
51-
process = multiprocessing.Process(target=function, args=args, kwargs=kwargs)
52-
process.start()
53-
return process
54-
26+
class FSAverageHeader(TypedDict):
27+
dims: npt.NDArray[int]
28+
delta: npt.NDArray[float]
29+
Mdc: npt.NDArray[float]
30+
Pxyz_c: npt.NDArray[float]
5531

56-
@overload
57-
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: None = None) -> dict[int, np.ndarray]:
58-
...
32+
logger = logging.get_logger(__name__)
5933

60-
@overload
61-
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int]) -> tuple[dict[int, np.ndarray], list[int]]:
62-
...
6334

64-
def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None):
35+
def get_centroids_from_nib(seg_img: nib.analyze.SpatialImage, label_ids: list[int] | None = None) \
36+
-> dict[int, np.ndarray | None]:
6537
"""Get centroids of segmentation labels in RAS coordinates.
6638
6739
Parameters
6840
----------
69-
seg_img : nibabel.Nifti1Image
41+
seg_img : nibabel.analyze.SpatialImage
7042
Input segmentation image.
7143
label_ids : list[int], optional
7244
List of label IDs to extract centroids for. If None, extracts all non-zero labels.
7345
7446
Returns
7547
-------
76-
dict[int, np.ndarray]
77-
If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates.
78-
If label_ids is provided, returns a tuple containing:
79-
- dict[int, np.ndarray]: Mapping of found label IDs to their centroids.
80-
- list[int]: List of label IDs that were not found in the image.
48+
dict[int, np.ndarray | None]
49+
A dict mapping label IDs to their centroids (x,y,z) in RAS coordinates, None if label did not exist.
8150
"""
8251
# Get segmentation data and affine
83-
seg_data = seg_img.get_fdata()
84-
vox2ras = seg_img.affine
52+
seg_data: npt.NDArray[np.integer] = np.asarray(seg_img.dataobj)
53+
vox2ras: npt.NDArray[float] = seg_img.affine
8554

8655
# Get unique labels
8756
if label_ids is None:
@@ -90,61 +59,22 @@ def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None
9059
else:
9160
labels = label_ids
9261

62+
def _get_ras_centroid(mask_vox: npt.NDArray[np.integer]) -> npt.NDArray[float]: # Calculate centroid in voxel space
63+
vox_centroid = np.mean(mask_vox, axis=1, dtype=float)
64+
65+
# Convert to homogeneous coordinates
66+
vox_centroid = np.append(vox_centroid, 1)
67+
68+
# Transform to RAS coordinates and return without homogeneous coordinate
69+
return (vox2ras @ vox_centroid)[:3]
70+
9371
centroids = {}
94-
ids_not_found = []
9572
for label in labels:
9673
# Get voxel indices for this label
9774
vox_coords = np.array(np.where(seg_data == label))
98-
if vox_coords.size == 0:
99-
ids_not_found.append(label)
100-
continue
101-
# Calculate centroid in voxel space
102-
vox_centroid = np.mean(vox_coords, axis=1)
103-
104-
# Convert to homogeneous coordinates
105-
vox_centroid = np.append(vox_centroid, 1)
106-
107-
# Transform to RAS coordinates
108-
ras_centroid = vox2ras @ vox_centroid
75+
centroids[int(label)] = None if vox_coords.size == 0 else _get_ras_centroid(vox_coords)
10976

110-
# Store without homogeneous coordinate
111-
centroids[int(label)] = ras_centroid[:3]
112-
113-
if label_ids is not None:
114-
return centroids, ids_not_found
115-
else:
116-
return centroids
117-
118-
119-
120-
def save_nifti_background(
121-
io_processes: list,
122-
data: np.ndarray,
123-
affine: np.ndarray,
124-
header: nib.Nifti1Header,
125-
filepath: str | Path
126-
) -> None:
127-
"""Save a NIfTI image in a background process.
128-
129-
Creates a MGHImage from the provided data and metadata, then saves it to disk
130-
using a background process to avoid blocking the main execution.
131-
132-
Parameters
133-
----------
134-
io_processes : list
135-
List to store background process handles.
136-
data : np.ndarray
137-
Image data array.
138-
affine : np.ndarray
139-
4x4 affine transformation matrix.
140-
header : nib.Nifti1Header
141-
NIfTI header object containing metadata.
142-
filepath : str or Path
143-
Path where the image should be saved.
144-
"""
145-
logger.info(f"Saving NIfTI image to {filepath}")
146-
io_processes.append(run_in_background(nib.save, False,
147-
nib.MGHImage(data, affine, header), filepath))
77+
return centroids
14878

14979

15080
def convert_numpy_to_json_serializable(obj: object) -> object:
@@ -173,7 +103,7 @@ def convert_numpy_to_json_serializable(obj: object) -> object:
173103
return obj
174104

175105

176-
def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray]:
106+
def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, npt.NDArray[float]]:
177107
"""Load fsaverage centroids from static JSON file.
178108
179109
Parameters
@@ -198,7 +128,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray
198128
return {int(label): np.array(centroid) for label, centroid in centroids_data.items()}
199129

200130

201-
def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray:
131+
def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]:
202132
"""Load fsaverage affine matrix from static text file.
203133
204134
Parameters
@@ -216,15 +146,15 @@ def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray:
216146
if not affine_path.exists():
217147
raise FileNotFoundError(f"Fsaverage affine file not found: {affine_path}")
218148

219-
affine_matrix = np.loadtxt(affine_path)
149+
affine_matrix = np.loadtxt(affine_path).astype(float)
220150

221151
if affine_matrix.shape != (4, 4):
222152
raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}")
223153

224154
return affine_matrix
225155

226156

227-
def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.ndarray]:
157+
def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAverageHeader, npt.NDArray[float]]:
228158
"""Load fsaverage affine matrix and header fields from static JSON file.
229159
230160
Parameters
@@ -257,9 +187,7 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
257187
If the file is not valid JSON.
258188
ValueError
259189
If required fields are missing.
260-
261190
"""
262-
263191
data_path = Path(data_path)
264192
if not data_path.exists():
265193
raise FileNotFoundError(f"Fsaverage data file not found: {data_path}")
@@ -281,9 +209,12 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
281209
# Convert lists back to numpy arrays
282210
affine_matrix = np.array(data["affine"])
283211
vox2ras_tkr = np.array(data["vox2ras_tkr"])
284-
header_data = data["header"].copy()
285-
header_data["Mdc"] = np.array(header_data["Mdc"])
286-
header_data["Pxyz_c"] = np.array(header_data["Pxyz_c"])
212+
header_data = FSAverageHeader(
213+
dims=data["header"]["dims"],
214+
delta=data["header"]["delta"],
215+
Mdc=np.array(data["header"]["Mdc"]),
216+
Pxyz_c=np.array(data["header"]["Pxyz_c"]),
217+
)
287218

288219
# Validate affine matrix shape
289220
if affine_matrix.shape != (4, 4):

0 commit comments

Comments
 (0)