import numpy as np
from tqdm import tqdm
from monai.inferers import SlidingWindowInferer
+import torch.nn.functional as F

# CellMap utilities
from cellmap_segmentation_challenge.utils import TEST_CROPS, load_safe_config
from omegaconf import OmegaConf


-def find_scale_level(zarr_path, target_resolution):
-    """Find the scale level that best matches target resolution."""
+def select_scale_level(zarr_path, target_resolution):
+    """Return the scale path, voxel size, and translation closest to the target resolution."""
    store = zarr.open(zarr_path, mode='r')

-    # Read OME-NGFF multiscale metadata
    multiscale_meta = store.attrs.get('multiscales', [{}])[0]
    datasets_meta = multiscale_meta.get('datasets', [])

+    # Default fallback if metadata is missing
    if not datasets_meta:
-        # Fallback: use s2 (typically 8nm)
-        return 2
+        return {
+            "path": "s2",
+            "voxel_size": np.array(target_resolution, dtype=float),
+            "translation": np.zeros(3, dtype=float),
+        }

-    # Find closest scale to target resolution
-    best_scale = 0
+    best = datasets_meta[0]
    min_diff = float('inf')

-    for i, ds_meta in enumerate(datasets_meta):
-        transforms = ds_meta.get('coordinateTransformations', [{}])
-        scale = transforms[0].get('scale', [1, 1, 1]) if transforms else [1, 1, 1]
-        # scale is [z, y, x] in nm
+    for ds_meta in datasets_meta:
+        transforms = ds_meta.get('coordinateTransformations', [])
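+        # next() with a default: a dataset entry without a 'scale' transform
+        # falls back to isotropic unit voxels instead of raising StopIteration.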
+        scale = next(
+            (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+            np.ones(3, dtype=float),
+        )
        avg_resolution = np.mean(scale)
        diff = abs(avg_resolution - np.mean(target_resolution))
-
        if diff < min_diff:
            min_diff = diff
-            best_scale = i
+            best = ds_meta
+
+    transforms = best.get('coordinateTransformations', [])
+    voxel_size = next(
+        (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+        np.array(target_resolution, dtype=float),
+    )
+    translation = next(
+        (np.array(t.get('translation', [0, 0, 0]), dtype=float) for t in transforms if t.get('type') == 'translation'),
+        np.zeros(3, dtype=float),
+    )

-    return best_scale
+    return {
+        "path": best.get('path', 's0'),
+        "voxel_size": voxel_size,
+        "translation": translation,
+    }
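+
+# Illustrative call (hypothetical path and values, assuming OME-NGFF v0.4
+# metadata where 's1' is the 8 nm level):
+#   info = select_scale_level('/data/example.zarr/recon-1/em/fibsem-uint8', (8.0, 8.0, 8.0))
+#   info -> {'path': 's1', 'voxel_size': array([8., 8., 8.]), 'translation': array([0., 0., 0.])}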


def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
@@ -132,14 +150,19 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
    model = model.to(device)
    print(f"Using device: {device}")

-    # Setup sliding window inferer (MONAI)
-    inferer = SlidingWindowInferer(
-        roi_size=(128, 128, 128),
-        sw_batch_size=4,
-        overlap=0.5,
-        mode='gaussian',
-        device=torch.device(device),
-    )
+    base_roi = (128, 128, 128)
+    inferer_cache: dict[tuple[int, int, int], SlidingWindowInferer] = {}
+
+    def get_inferer(roi_size: tuple[int, int, int]) -> SlidingWindowInferer:
+        if roi_size not in inferer_cache:
+            inferer_cache[roi_size] = SlidingWindowInferer(
+                roi_size=roi_size,
+                sw_batch_size=4,
+                overlap=0.5,
+                mode='gaussian',
+                device=torch.device(device),
+            )
+        return inferer_cache[roi_size]
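+    # One inferer per distinct ROI size: e.g. a crop only 96 voxels deep
+    # reuses a cached (96, 128, 128) inferer instead of rebuilding the
+    # Gaussian importance map for every crop.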

    # Filter test crops if specified
    if crop_filter:
@@ -167,16 +190,20 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):

        # Find appropriate scale level for target resolution
        em_path = f"{zarr_path}/recon-1/em/fibsem-uint8"
-        scale_level = find_scale_level(em_path, target_resolution)
-        print(f" Using scale level: s{scale_level} (target resolution: {target_resolution} nm)")
+        scale_info = select_scale_level(em_path, target_resolution)
+        scale_level = scale_info['path']
+        scale_voxel_size = scale_info['voxel_size']
+        scale_translation = scale_info['translation']
+        print(f" Using scale level: {scale_level} (voxel size: {scale_voxel_size} nm)")

        # Load EM data once for all crops in this dataset
        try:
-            raw_array = zarr.open(f"{em_path}/s{scale_level}", mode='r')
+            raw_array = zarr.open(f"{em_path}/{scale_level}", mode='r')
        except Exception as e:
            print(f" Error loading EM data: {e}")
            print(f" Skipping dataset {dataset}")
            continue
+        raw_shape = np.array(raw_array.shape, dtype=int)
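+        # Cached once per dataset; reused below to clamp crop indices in bounds.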

        for crop in tqdm(dataset_crops, desc=f" Crops in {dataset}"):
            crop_id = crop.id
@@ -186,32 +213,57 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
            if class_label not in classes:
                continue

-            # Extract crop region from full volume
-            # Note: This is simplified - in production, use crop.translation and crop.shape
-            # to extract exact region
+            # Extract crop region using precise metadata
            crop_output_dir = f"{output_dir}/{dataset}/crop{crop_id}"
            os.makedirs(crop_output_dir, exist_ok=True)

-            # Load a reasonable-sized region (simplified)
-            # In production, use crop metadata to extract exact region
            try:
-                # Get raw data shape
-                raw_shape = raw_array.shape
+                target_shape = np.array(crop.shape, dtype=int)
+                target_voxel = np.array(crop.voxel_size, dtype=float)
+                translation_nm = np.array(crop.translation, dtype=float)
+
+                physical_extent = target_shape * target_voxel
+                start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int)
+                end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int)
+
+                end_idx = np.maximum(end_idx, start_idx + 1)
+                start_idx = np.clip(start_idx, 0, np.maximum(raw_shape - 1, 0))
+                end_idx = np.clip(end_idx, start_idx + 1, raw_shape)
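+                # Worked example (hypothetical numbers): a crop translated to
+                # 1600 nm with 100 voxels of 4 nm spans 400 nm, so on an 8 nm
+                # level with zero translation it reads floor(1600/8)=200 up to
+                # ceil(2000/8)=250 along each axis.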

-                # Simple extraction (center crop for demo)
-                # TODO: Use actual crop.translation and crop.shape for exact extraction
-                d, h, w = min(256, raw_shape[0]), min(256, raw_shape[1]), min(256, raw_shape[2])
-                raw_volume = raw_array[:d, :h, :w]
+                slices = tuple(slice(int(s), int(e)) for s, e in zip(start_idx, end_idx))
+                raw_volume = raw_array[slices]
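+                # Slicing the zarr array reads only the chunks overlapping the
+                # requested region, so the full volume never has to fit in memory.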

                # Normalize and convert to tensor
                raw_volume = np.array(raw_volume).astype(np.float32) / 255.0
                raw_tensor = torch.from_numpy(raw_volume[None, None, ...]).to(device)  # (1, 1, D, H, W)

+                roi_size = tuple(
+                    int(max(1, min(base_dim, vol_dim)))
+                    for base_dim, vol_dim in zip(base_roi, raw_volume.shape)
+                )
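+                # Clamping the ROI to the volume keeps the sliding window from
+                # padding small crops out to the full 128^3 window.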
+                inferer = get_inferer(roi_size)
+
                # Run inference
                with torch.no_grad():
                    predictions = inferer(raw_tensor, model)
                    predictions = torch.sigmoid(predictions).cpu().numpy()[0]  # (C, D, H, W)

+                # Resize predictions back to the official crop shape if needed
+                target_shape_tuple = tuple(int(x) for x in target_shape)
+                if predictions.shape[1:] != target_shape_tuple:
+                    pred_tensor = torch.from_numpy(predictions).unsqueeze(0)
+                    predictions = (
+                        F.interpolate(
+                            pred_tensor,
+                            size=target_shape_tuple,
+                            mode="trilinear",
+                            align_corners=False,
+                        )
+                        .squeeze(0)
+                        .cpu()
+                        .numpy()
+                    )
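+                # Interpolating probabilities rather than binary masks keeps
+                # boundaries smooth; the 0.5 threshold below is applied only
+                # after resampling onto the official crop grid.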
+
                # Save predictions for each class
                for i, cls in enumerate(classes):
                    pred_array = (predictions[i] > 0.5).astype(np.uint8)