1515import json
1616import multiprocessing
1717from pathlib import Path
18- from typing import overload
18+ from typing import TypedDict
1919
2020import nibabel as nib
21+ from numpy import typing as npt
2122import numpy as np
2223
2324import FastSurferCNN .utils .logging as logging
2425
25- logger = logging .get_logger (__name__ )
26-
27-
28- def run_in_background (function : callable , debug : bool = False , * args , ** kwargs ) -> multiprocessing .Process | None :
29- """Run a function in the background using multiprocessing.
30-
31- Parameters
32- ----------
33- function : callable
34- The function to execute.
35- debug : bool, optional
36- If True, run synchronously in current process, by default False.
37- *args
38- Positional arguments to pass to the function.
39- **kwargs
40- Keyword arguments to pass to the function.
41-
42- Returns
43- -------
44- multiprocessing.Process or None
45- Process object if running in background, None if in debug mode.
46- """
47- if debug :
48- function (* args , ** kwargs )
49- process = None
50- else :
51- process = multiprocessing .Process (target = function , args = args , kwargs = kwargs )
52- process .start ()
53- return process
54-
26+ class FSAverageHeader (TypedDict ):
27+ dims : npt .NDArray [int ]
28+ delta : npt .NDArray [float ]
29+ Mdc : npt .NDArray [float ]
30+ Pxyz_c : npt .NDArray [float ]
5531
56- @overload
57- def get_centroids_from_nib (seg_img : nib .Nifti1Image , label_ids : None = None ) -> dict [int , np .ndarray ]:
58- ...
32+ logger = logging .get_logger (__name__ )
5933
60- @overload
61- def get_centroids_from_nib (seg_img : nib .Nifti1Image , label_ids : list [int ]) -> tuple [dict [int , np .ndarray ], list [int ]]:
62- ...
6334
64- def get_centroids_from_nib (seg_img : nib .Nifti1Image , label_ids : list [int ] | None = None ):
35+ def get_centroids_from_nib (seg_img : nib .analyze .SpatialImage , label_ids : list [int ] | None = None ) \
36+ -> dict [int , np .ndarray | None ]:
6537 """Get centroids of segmentation labels in RAS coordinates.
6638
6739 Parameters
6840 ----------
69- seg_img : nibabel.Nifti1Image
41+ seg_img : nibabel.analyze.SpatialImage
7042 Input segmentation image.
7143 label_ids : list[int], optional
7244 List of label IDs to extract centroids for. If None, extracts all non-zero labels.
7345
7446 Returns
7547 -------
76- dict[int, np.ndarray]
77- If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates.
78- If label_ids is provided, returns a tuple containing:
79- - dict[int, np.ndarray]: Mapping of found label IDs to their centroids.
80- - list[int]: List of label IDs that were not found in the image.
48+ dict[int, np.ndarray | None]
49+ A dict mapping label IDs to their centroids (x,y,z) in RAS coordinates, None if label did not exist.
8150 """
8251 # Get segmentation data and affine
83- seg_data = seg_img . get_fdata ( )
84- vox2ras = seg_img .affine
52+ seg_data : npt . NDArray [ np . integer ] = np . asarray ( seg_img . dataobj )
53+ vox2ras : npt . NDArray [ float ] = seg_img .affine
8554
8655 # Get unique labels
8756 if label_ids is None :
@@ -90,61 +59,22 @@ def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None
9059 else :
9160 labels = label_ids
9261
62+ def _get_ras_centroid (mask_vox : npt .NDArray [np .integer ]) -> npt .NDArray [float ]: # Calculate centroid in voxel space
63+ vox_centroid = np .mean (mask_vox , axis = 1 , dtype = float )
64+
65+ # Convert to homogeneous coordinates
66+ vox_centroid = np .append (vox_centroid , 1 )
67+
68+ # Transform to RAS coordinates and return without homogeneous coordinate
69+ return (vox2ras @ vox_centroid )[:3 ]
70+
9371 centroids = {}
94- ids_not_found = []
9572 for label in labels :
9673 # Get voxel indices for this label
9774 vox_coords = np .array (np .where (seg_data == label ))
98- if vox_coords .size == 0 :
99- ids_not_found .append (label )
100- continue
101- # Calculate centroid in voxel space
102- vox_centroid = np .mean (vox_coords , axis = 1 )
103-
104- # Convert to homogeneous coordinates
105- vox_centroid = np .append (vox_centroid , 1 )
106-
107- # Transform to RAS coordinates
108- ras_centroid = vox2ras @ vox_centroid
75+ centroids [int (label )] = None if vox_coords .size == 0 else _get_ras_centroid (vox_coords )
10976
110- # Store without homogeneous coordinate
111- centroids [int (label )] = ras_centroid [:3 ]
112-
113- if label_ids is not None :
114- return centroids , ids_not_found
115- else :
116- return centroids
117-
118-
119-
120- def save_nifti_background (
121- io_processes : list ,
122- data : np .ndarray ,
123- affine : np .ndarray ,
124- header : nib .Nifti1Header ,
125- filepath : str | Path
126- ) -> None :
127- """Save a NIfTI image in a background process.
128-
129- Creates a MGHImage from the provided data and metadata, then saves it to disk
130- using a background process to avoid blocking the main execution.
131-
132- Parameters
133- ----------
134- io_processes : list
135- List to store background process handles.
136- data : np.ndarray
137- Image data array.
138- affine : np.ndarray
139- 4x4 affine transformation matrix.
140- header : nib.Nifti1Header
141- NIfTI header object containing metadata.
142- filepath : str or Path
143- Path where the image should be saved.
144- """
145- logger .info (f"Saving NIfTI image to { filepath } " )
146- io_processes .append (run_in_background (nib .save , False ,
147- nib .MGHImage (data , affine , header ), filepath ))
77+ return centroids
14878
14979
15080def convert_numpy_to_json_serializable (obj : object ) -> object :
@@ -173,7 +103,7 @@ def convert_numpy_to_json_serializable(obj: object) -> object:
173103 return obj
174104
175105
176- def load_fsaverage_centroids (centroids_path : str | Path ) -> dict [int , np . ndarray ]:
106+ def load_fsaverage_centroids (centroids_path : str | Path ) -> dict [int , npt . NDArray [ float ] ]:
177107 """Load fsaverage centroids from static JSON file.
178108
179109 Parameters
@@ -198,7 +128,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray
198128 return {int (label ): np .array (centroid ) for label , centroid in centroids_data .items ()}
199129
200130
201- def load_fsaverage_affine (affine_path : str | Path ) -> np . ndarray :
131+ def load_fsaverage_affine (affine_path : str | Path ) -> npt . NDArray [ float ] :
202132 """Load fsaverage affine matrix from static text file.
203133
204134 Parameters
@@ -216,15 +146,15 @@ def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray:
216146 if not affine_path .exists ():
217147 raise FileNotFoundError (f"Fsaverage affine file not found: { affine_path } " )
218148
219- affine_matrix = np .loadtxt (affine_path )
149+ affine_matrix = np .loadtxt (affine_path ). astype ( float )
220150
221151 if affine_matrix .shape != (4 , 4 ):
222152 raise ValueError (f"Expected 4x4 affine matrix, got shape { affine_matrix .shape } " )
223153
224154 return affine_matrix
225155
226156
227- def load_fsaverage_data (data_path : str | Path ) -> tuple [np . ndarray , dict , np . ndarray ]:
157+ def load_fsaverage_data (data_path : str | Path ) -> tuple [npt . NDArray [ float ], FSAverageHeader , npt . NDArray [ float ] ]:
228158 """Load fsaverage affine matrix and header fields from static JSON file.
229159
230160 Parameters
@@ -257,9 +187,7 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
257187 If the file is not valid JSON.
258188 ValueError
259189 If required fields are missing.
260-
261190 """
262-
263191 data_path = Path (data_path )
264192 if not data_path .exists ():
265193 raise FileNotFoundError (f"Fsaverage data file not found: { data_path } " )
@@ -281,9 +209,12 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
281209 # Convert lists back to numpy arrays
282210 affine_matrix = np .array (data ["affine" ])
283211 vox2ras_tkr = np .array (data ["vox2ras_tkr" ])
284- header_data = data ["header" ].copy ()
285- header_data ["Mdc" ] = np .array (header_data ["Mdc" ])
286- header_data ["Pxyz_c" ] = np .array (header_data ["Pxyz_c" ])
212+ header_data = FSAverageHeader (
213+ dims = data ["header" ]["dims" ],
214+ delta = data ["header" ]["delta" ],
215+ Mdc = np .array (data ["header" ]["Mdc" ]),
216+ Pxyz_c = np .array (data ["header" ]["Pxyz_c" ]),
217+ )
287218
288219 # Validate affine matrix shape
289220 if affine_matrix .shape != (4 , 4 ):
0 commit comments