
Commit d9f751f

Fix CellMap training issues: resolve NaN losses and infinite epochs
- Add data type conversion (_prepare_images/_prepare_labels) to prevent uint8/float16 mismatch
- Implement CellMapBalancedLoss with comprehensive NaN/inf handling for sparse segmentation
- Add batch skipping for empty batches to prevent NaN contamination in progress bars
- Replace problematic pos_weight computation with uniform weights to avoid numerical instability
- Add limit_train_batches/limit_val_batches to prevent infinite epochs from large datasets
- Enhance loss function with gradient-connected safety values and extensive numerical stability checks
1 parent a18b392 commit d9f751f
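The loss changes live in the training script, which is one of the five changed files but is not shown in the diffs below. The snippet that follows is therefore only a hedged sketch of the pattern the commit message describes, not the repo's actual CellMapBalancedLoss: every name is illustrative. It sanitizes logits and targets before the loss sees them, drops non-finite elements, and falls back to a gradient-connected zero when a sparse batch has nothing valid to average (the empty-batch skipping itself happens in the training loop, also not shown).

import torch
import torch.nn as nn

class NaNSafeBCELoss(nn.Module):
    """Illustrative stand-in for the commit's CellMapBalancedLoss idea."""

    def __init__(self):
        super().__init__()
        # Uniform weights: deriving pos_weight from sparse label counts can
        # divide by ~0 and overflow under 16-mixed precision, so no pos_weight.
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, targets):
        # Sanitize both sides before computing the loss.
        logits = torch.nan_to_num(logits, nan=0.0, posinf=20.0, neginf=-20.0)
        targets = torch.nan_to_num(targets.float(), nan=0.0).clamp(0.0, 1.0)
        loss = self.bce(logits, targets)
        finite = torch.isfinite(loss)
        if not finite.any():
            # Gradient-connected safety value: multiplying the model output
            # by zero keeps the autograd graph intact, unlike a fresh constant,
            # so backward() still succeeds on a fully degenerate batch.
            return (logits * 0.0).sum()
        return loss[finite].mean()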

File tree

5 files changed: +596, -93 lines

scripts/cellmap/configs/mednext_cos7.py
Lines changed: 4 additions & 0 deletions

@@ -24,6 +24,7 @@
     'scale': (8, 8, 8),  # 8nm isotropic resolution
 }
 target_array_info = input_array_info
+force_all_classes = 'both'  # keep every organelle present in both splits
 
 # Output paths
 output_dir = 'outputs/cellmap_cos7'
@@ -52,6 +53,9 @@
 epochs = 500  # Maximum epochs
 num_gpus = 1  # Number of GPUs
 precision = '16-mixed'  # Mixed precision training
+iterations_per_epoch = None  # Keep dataloader on the cheap shuffle path
+train_batches_per_epoch = 2000  # Lightning caps epoch length at 2k steps
+val_batches_per_epoch = 200  # Limit validation passes per epoch
 
 # Learning rate scheduler (constant for MedNeXt)
 scheduler_config = {
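Per the commit message, these knobs feed Lightning's limit_train_batches/limit_val_batches. A minimal sketch of that wiring, assuming a hypothetical build_trainer helper and attribute-style config access (the actual training script is not part of the diffs shown here):

import pytorch_lightning as pl

def build_trainer(cfg):
    # Forward the new config knobs to Lightning so an epoch over a huge
    # random-crop dataset ends after a fixed step budget instead of running
    # effectively forever.
    return pl.Trainer(
        max_epochs=cfg.epochs,
        devices=cfg.num_gpus,
        precision=cfg.precision,                          # '16-mixed'
        limit_train_batches=cfg.train_batches_per_epoch,  # 2000 steps/epoch
        limit_val_batches=cfg.val_batches_per_epoch,      # 200 val steps
    )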

scripts/cellmap/configs/mednext_mito.py
Lines changed: 4 additions & 0 deletions

@@ -22,6 +22,7 @@
     'scale': (4, 4, 4),  # 4nm isotropic (higher resolution)
 }
 target_array_info = input_array_info
+force_all_classes = 'both'  # ensure mito voxels present in both train/val splits
 
 # Output paths
 output_dir = 'outputs/cellmap_mito'
@@ -50,6 +51,9 @@
 epochs = 1000  # More epochs for single class
 num_gpus = 1  # Number of GPUs
 precision = '16-mixed'  # Mixed precision training
+iterations_per_epoch = None  # Leave None so dataloader avoids huge subset shuffles
+train_batches_per_epoch = 2000  # Cap Lightning's epoch length instead
+val_batches_per_epoch = 200  # Limit validation passes per epoch
 
 # Learning rate scheduler
 scheduler_config = {

scripts/cellmap/configs/monai_unet_quick.py
Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,9 @@
 # Classes to segment (just 2 classes for quick test)
 classes = ['nuc', 'mito']
 
+# Data root path
+data_root = '/projects/weilab/dataset/cellmap'
+
 # Data configuration
 input_array_info = {
     'shape': (64, 64, 64),  # Small patches for speed

scripts/cellmap/predict_cellmap.py
Lines changed: 88 additions & 36 deletions

@@ -35,6 +35,7 @@
 import numpy as np
 from tqdm import tqdm
 from monai.inferers import SlidingWindowInferer
+import torch.nn.functional as F
 
 # CellMap utilities
 from cellmap_segmentation_challenge.utils import TEST_CROPS, load_safe_config
@@ -45,34 +46,51 @@
 from omegaconf import OmegaConf
 
 
-def find_scale_level(zarr_path, target_resolution):
-    """Find the scale level that best matches target resolution."""
+def select_scale_level(zarr_path, target_resolution):
+    """Return the scale path plus voxel size/translation metadata closest to target resolution."""
     store = zarr.open(zarr_path, mode='r')
 
-    # Read OME-NGFF multiscale metadata
     multiscale_meta = store.attrs.get('multiscales', [{}])[0]
     datasets_meta = multiscale_meta.get('datasets', [])
 
+    # Default fallback if metadata is missing
     if not datasets_meta:
-        # Fallback: use s2 (typically 8nm)
-        return 2
+        return {
+            "path": "s2",
+            "voxel_size": np.array(target_resolution, dtype=float),
+            "translation": np.zeros(3, dtype=float),
+        }
 
-    # Find closest scale to target resolution
-    best_scale = 0
+    best = datasets_meta[0]
     min_diff = float('inf')
 
-    for i, ds_meta in enumerate(datasets_meta):
-        transforms = ds_meta.get('coordinateTransformations', [{}])
-        scale = transforms[0].get('scale', [1, 1, 1]) if transforms else [1, 1, 1]
-        # scale is [z, y, x] in nm
+    for ds_meta in datasets_meta:
+        transforms = ds_meta.get('coordinateTransformations', [])
+        scale = next(
+            (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+            np.ones(3, dtype=float),
+        )
         avg_resolution = np.mean(scale)
         diff = abs(avg_resolution - np.mean(target_resolution))
-
         if diff < min_diff:
             min_diff = diff
-            best_scale = i
+            best = ds_meta
+
+    transforms = best.get('coordinateTransformations', [])
+    voxel_size = next(
+        (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'),
+        np.array(target_resolution, dtype=float),
+    )
+    translation = next(
+        (np.array(t.get('translation', [0, 0, 0]), dtype=float) for t in transforms if t.get('type') == 'translation'),
+        np.zeros(3, dtype=float),
+    )
 
-    return best_scale
+    return {
+        "path": best.get('path', 's0'),
+        "voxel_size": voxel_size,
+        "translation": translation,
+    }
 
 
 def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
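For reference, the 'multiscales' attribute this function walks typically looks like the following in the OME-NGFF 0.4 layout that CellMap volumes follow; each dataset entry carries a 'scale' transform and an optional 'translation' transform. The numbers are illustrative:

multiscales_attr = [{
    "datasets": [
        {"path": "s0",
         "coordinateTransformations": [
             {"type": "scale", "scale": [4.0, 4.0, 4.0]},  # nm per voxel
             {"type": "translation", "translation": [0.0, 0.0, 0.0]},
         ]},
        {"path": "s1",
         "coordinateTransformations": [
             {"type": "scale", "scale": [8.0, 8.0, 8.0]},
             {"type": "translation", "translation": [2.0, 2.0, 2.0]},
         ]},
    ],
}]
# With target_resolution = (8, 8, 8), select_scale_level would pick s1 and
# return {"path": "s1", "voxel_size": [8, 8, 8], "translation": [2, 2, 2]}.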
@@ -132,14 +150,19 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
     model = model.to(device)
     print(f"Using device: {device}")
 
-    # Setup sliding window inferer (MONAI)
-    inferer = SlidingWindowInferer(
-        roi_size=(128, 128, 128),
-        sw_batch_size=4,
-        overlap=0.5,
-        mode='gaussian',
-        device=torch.device(device),
-    )
+    base_roi = (128, 128, 128)
+    inferer_cache: dict[tuple[int, int, int], SlidingWindowInferer] = {}
+
+    def get_inferer(roi_size: tuple[int, int, int]) -> SlidingWindowInferer:
+        if roi_size not in inferer_cache:
+            inferer_cache[roi_size] = SlidingWindowInferer(
+                roi_size=roi_size,
+                sw_batch_size=4,
+                overlap=0.5,
+                mode='gaussian',
+                device=torch.device(device),
+            )
+        return inferer_cache[roi_size]
 
     # Filter test crops if specified
     if crop_filter:
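The cache exists because SlidingWindowInferer fixes roi_size at construction, and some test crops are smaller than the 128^3 base window (the clamping happens in the crop loop further down). Illustrative behavior:

inferer_a = get_inferer((64, 64, 64))  # constructed on first request
inferer_b = get_inferer((64, 64, 64))  # cache hit: same object, no rebuild
assert inferer_a is inferer_b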
@@ -167,16 +190,20 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
 
         # Find appropriate scale level for target resolution
         em_path = f"{zarr_path}/recon-1/em/fibsem-uint8"
-        scale_level = find_scale_level(em_path, target_resolution)
-        print(f" Using scale level: s{scale_level} (target resolution: {target_resolution} nm)")
+        scale_info = select_scale_level(em_path, target_resolution)
+        scale_level = scale_info['path']
+        scale_voxel_size = scale_info['voxel_size']
+        scale_translation = scale_info['translation']
+        print(f" Using scale level: {scale_level} (voxel size: {scale_voxel_size} nm)")
 
         # Load EM data once for all crops in this dataset
         try:
-            raw_array = zarr.open(f"{em_path}/s{scale_level}", mode='r')
+            raw_array = zarr.open(f"{em_path}/{scale_level}", mode='r')
         except Exception as e:
             print(f" Error loading EM data: {e}")
             print(f" Skipping dataset {dataset}")
             continue
+        raw_shape = np.array(raw_array.shape, dtype=int)
 
         for crop in tqdm(dataset_crops, desc=f" Crops in {dataset}"):
             crop_id = crop.id
@@ -186,32 +213,57 @@ def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None):
             if class_label not in classes:
                 continue
 
-            # Extract crop region from full volume
-            # Note: This is simplified - in production, use crop.translation and crop.shape
-            # to extract exact region
+            # Extract crop region using precise metadata
             crop_output_dir = f"{output_dir}/{dataset}/crop{crop_id}"
             os.makedirs(crop_output_dir, exist_ok=True)
 
-            # Load a reasonable-sized region (simplified)
-            # In production, use crop metadata to extract exact region
             try:
-                # Get raw data shape
-                raw_shape = raw_array.shape
+                target_shape = np.array(crop.shape, dtype=int)
+                target_voxel = np.array(crop.voxel_size, dtype=float)
+                translation_nm = np.array(crop.translation, dtype=float)
+
+                physical_extent = target_shape * target_voxel
+                start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int)
+                end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int)
+
+                end_idx = np.maximum(end_idx, start_idx + 1)
+                start_idx = np.clip(start_idx, 0, np.maximum(raw_shape - 1, 0))
+                end_idx = np.clip(end_idx, start_idx + 1, raw_shape)
 
-                # Simple extraction (center crop for demo)
-                # TODO: Use actual crop.translation and crop.shape for exact extraction
-                d, h, w = min(256, raw_shape[0]), min(256, raw_shape[1]), min(256, raw_shape[2])
-                raw_volume = raw_array[:d, :h, :w]
+                slices = tuple(slice(int(s), int(e)) for s, e in zip(start_idx, end_idx))
+                raw_volume = raw_array[slices]
 
                 # Normalize and convert to tensor
                 raw_volume = np.array(raw_volume).astype(np.float32) / 255.0
                 raw_tensor = torch.from_numpy(raw_volume[None, None, ...]).to(device)  # (1, 1, D, H, W)
 
+                roi_size = tuple(
+                    int(max(1, min(base_dim, vol_dim)))
+                    for base_dim, vol_dim in zip(base_roi, raw_volume.shape)
+                )
+                inferer = get_inferer(roi_size)
+
                 # Run inference
                 with torch.no_grad():
                     predictions = inferer(raw_tensor, model)
                 predictions = torch.sigmoid(predictions).cpu().numpy()[0]  # (C, D, H, W)
 
+                # Resize predictions back to the official crop shape if needed
+                target_shape_tuple = tuple(int(x) for x in target_shape)
+                if predictions.shape[1:] != target_shape_tuple:
+                    pred_tensor = torch.from_numpy(predictions).unsqueeze(0)
+                    predictions = (
+                        F.interpolate(
+                            pred_tensor,
+                            size=target_shape_tuple,
+                            mode="trilinear",
+                            align_corners=False,
+                        )
+                        .squeeze(0)
+                        .cpu()
+                        .numpy()
+                    )
+
                 # Save predictions for each class
                 for i, cls in enumerate(classes):
                     pred_array = (predictions[i] > 0.5).astype(np.uint8)
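As a sanity check on the nm-to-voxel arithmetic in the hunk above, here is a worked example with illustrative numbers: a 64^3 crop at 8 nm voxels, offset 16000 nm along each axis, read from a scale level with 16 nm voxels and zero translation.

import numpy as np

scale_voxel_size = np.array([16.0, 16.0, 16.0])  # nm per voxel at this level
scale_translation = np.zeros(3)
target_shape = np.array([64, 64, 64])
target_voxel = np.array([8.0, 8.0, 8.0])
translation_nm = np.array([16000.0, 16000.0, 16000.0])

physical_extent = target_shape * target_voxel  # 512 nm per axis
start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int)
end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int)

print(start_idx, end_idx)  # [1000 1000 1000] [1032 1032 1032]
# 32 voxels at 16 nm cover the 512 nm extent, so inference runs on a 32^3
# volume and F.interpolate upsamples the result back to the official 64^3
# crop shape before thresholding.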
