@@ -65,7 +65,7 @@ class AutoConfigPlanner:
 
     def __init__(
         self,
-        architecture: str = 'mednext',
+        architecture: str = "mednext",
         target_spacing: Optional[List[float]] = None,
         median_shape: Optional[List[int]] = None,
         manual_overrides: Optional[Dict[str, Any]] = None,
@@ -93,33 +93,33 @@ def __init__(
     def _get_architecture_defaults(self) -> Dict[str, Any]:
         """Get architecture-specific default parameters."""
         defaults = {
-            'mednext': {
-                'base_features': 32,
-                'max_features': 320,
-                'lr': 1e-3,  # MedNeXt paper recommends 1e-3
-                'use_scheduler': False,  # MedNeXt uses constant LR
+            "mednext": {
+                "base_features": 32,
+                "max_features": 320,
+                "lr": 1e-3,  # MedNeXt paper recommends 1e-3
+                "use_scheduler": False,  # MedNeXt uses constant LR
             },
-            'mednext_custom': {
-                'base_features': 32,
-                'max_features': 320,
-                'lr': 1e-3,
-                'use_scheduler': False,
+            "mednext_custom": {
+                "base_features": 32,
+                "max_features": 320,
+                "lr": 1e-3,
+                "use_scheduler": False,
             },
-            'monai_basic_unet3d': {
-                'base_features': 32,
-                'max_features': 512,
-                'lr': 1e-4,
-                'use_scheduler': True,
+            "monai_basic_unet3d": {
+                "base_features": 32,
+                "max_features": 512,
+                "lr": 1e-4,
+                "use_scheduler": True,
             },
-            'monai_unet': {
-                'base_features': 32,
-                'max_features': 512,
-                'lr': 1e-4,
-                'use_scheduler': True,
+            "monai_unet": {
+                "base_features": 32,
+                "max_features": 512,
+                "lr": 1e-4,
+                "use_scheduler": True,
             },
         }
 
-        return defaults.get(self.architecture, defaults['monai_basic_unet3d'])
+        return defaults.get(self.architecture, defaults["monai_basic_unet3d"])
 
     def plan(
         self,
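
A standalone sketch of the fallback in `_get_architecture_defaults` above: any architecture name without an entry falls back to the `monai_basic_unet3d` defaults via `dict.get`. The values are copied from this hunk; the unknown name "swin_unetr" is only an illustration.

    defaults = {
        "mednext": {"base_features": 32, "max_features": 320, "lr": 1e-3, "use_scheduler": False},
        "monai_basic_unet3d": {"base_features": 32, "max_features": 512, "lr": 1e-4, "use_scheduler": True},
    }

    # Known name -> its own entry; unknown name -> monai_basic_unet3d fallback.
    print(defaults.get("mednext", defaults["monai_basic_unet3d"])["lr"])      # 0.001
    print(defaults.get("swin_unetr", defaults["monai_basic_unet3d"])["lr"])   # 0.0001 (fallback)
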
@@ -149,20 +149,20 @@ def plan(
         result.planning_notes.append(f"Patch size: {patch_size}")
 
         # Step 2: Get model parameters
-        result.base_features = self.arch_defaults['base_features']
-        result.max_features = self.arch_defaults['max_features']
+        result.base_features = self.arch_defaults["base_features"]
+        result.max_features = self.arch_defaults["max_features"]
 
         # Step 3: Determine precision
         result.precision = "16-mixed" if use_mixed_precision else "32"
 
         # Step 4: Estimate memory and determine batch size
-        if not self.gpu_info['cuda_available']:
+        if not self.gpu_info["cuda_available"]:
             result.batch_size = 1
             result.precision = "32"  # CPU doesn't support mixed precision well
             result.warnings.append("CUDA not available, using CPU with batch_size=1")
             result.planning_notes.append("Training on CPU (slow!)")
         else:
-            gpu_memory_gb = self.gpu_info['available_memory_gb'][0]  # Use first GPU
+            gpu_memory_gb = self.gpu_info["available_memory_gb"][0]  # Use first GPU
             result.available_gpu_memory_gb = gpu_memory_gb
 
             # Calculate number of pooling stages (log2 of patch size / 4)
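
For orientation, this is the apparent shape of `self.gpu_info` as used in this hunk and the ones below; the keys are inferred from the accesses in the diff, and the values are illustrative, not taken from the helper that actually builds the dict.

    gpu_info = {
        "cuda_available": True,          # False -> CPU path: batch_size=1, precision "32"
        "num_gpus": 1,                   # feeds get_optimal_num_workers(num_gpus) in Step 5
        "available_memory_gb": [24.0],   # one entry per GPU; the planner reads index 0
    }
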
@@ -211,12 +211,12 @@ def plan(
             )
 
         # Step 5: Determine num_workers
-        num_gpus = self.gpu_info['num_gpus'] if self.gpu_info['cuda_available'] else 0
+        num_gpus = self.gpu_info["num_gpus"] if self.gpu_info["cuda_available"] else 0
         result.num_workers = get_optimal_num_workers(num_gpus)
         result.planning_notes.append(f"Num workers: {result.num_workers}")
 
         # Step 6: Learning rate
-        result.lr = self.arch_defaults['lr']
+        result.lr = self.arch_defaults["lr"]
         result.planning_notes.append(f"Learning rate: {result.lr}")
 
         # Step 7: Apply manual overrides
@@ -266,8 +266,8 @@ def _plan_patch_size(self) -> List[int]:
 
         # If GPU memory is limited, may need to reduce patch size
         # (This is a simplified heuristic)
-        if self.gpu_info['cuda_available']:
-            gpu_memory_gb = self.gpu_info['available_memory_gb'][0]
+        if self.gpu_info["cuda_available"]:
+            gpu_memory_gb = self.gpu_info["available_memory_gb"][0]
             if gpu_memory_gb < 8:
                 # Very limited GPU, use smaller patches
                 patch_size = np.minimum(patch_size, [64, 64, 64])
@@ -289,8 +289,10 @@ def print_plan(self, result: AutoPlanResult):
         print(f"  Batch Size: {result.batch_size}")
         if result.accumulate_grad_batches > 1:
             effective_bs = result.batch_size * result.accumulate_grad_batches
-            print(f"  Gradient Accumulation: {result.accumulate_grad_batches} "
-                  f"(effective batch_size={effective_bs})")
+            print(
+                f"  Gradient Accumulation: {result.accumulate_grad_batches} "
+                f"(effective batch_size={effective_bs})"
+            )
         print(f"  Num Workers: {result.num_workers}")
         print()
 
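
The effective batch size printed above is just the per-step batch size times the number of accumulation steps; with illustrative numbers:

    batch_size = 2
    accumulate_grad_batches = 4
    effective_bs = batch_size * accumulate_grad_batches  # 8 samples contribute to each optimizer step
    print(f"  Gradient Accumulation: {accumulate_grad_batches} (effective batch_size={effective_bs})")
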
@@ -307,8 +309,10 @@ def print_plan(self, result: AutoPlanResult):
         if result.available_gpu_memory_gb > 0:
             print("💾 GPU Memory:")
             print(f"  Available: {result.available_gpu_memory_gb:.2f} GB")
-            print(f"  Estimated Usage: {result.estimated_gpu_memory_gb:.2f} GB "
-                  f"({result.estimated_gpu_memory_gb / result.available_gpu_memory_gb * 100:.1f}%)")
+            print(
+                f"  Estimated Usage: {result.estimated_gpu_memory_gb:.2f} GB "
+                f"({result.estimated_gpu_memory_gb / result.available_gpu_memory_gb * 100:.1f}%)"
+            )
             print(f"  Per Sample: {result.gpu_memory_per_sample_gb:.2f} GB")
             print()
 
@@ -349,45 +353,50 @@ def auto_plan_config(
         Updated config with auto-planned parameters
     """
     # Check if auto-planning is disabled
-    if hasattr(config, 'system') and hasattr(config.system, 'auto_plan'):
+    if hasattr(config, "system") and hasattr(config.system, "auto_plan"):
         if not config.system.auto_plan:
             print("ℹ️ Auto-planning disabled in config")
             return config
 
     # Extract relevant config values
-    architecture = config.model.architecture if hasattr(config.model, 'architecture') else 'mednext'
-    in_channels = config.model.in_channels if hasattr(config.model, 'in_channels') else 1
-    out_channels = config.model.out_channels if hasattr(config.model, 'out_channels') else 2
-    deep_supervision = config.model.deep_supervision if hasattr(config.model, 'deep_supervision') else False
+    architecture = config.model.architecture if hasattr(config.model, "architecture") else "mednext"
+    in_channels = config.model.in_channels if hasattr(config.model, "in_channels") else 1
+    out_channels = config.model.out_channels if hasattr(config.model, "out_channels") else 2
+    deep_supervision = (
+        config.model.deep_supervision if hasattr(config.model, "deep_supervision") else False
+    )
 
     # Get target spacing and median shape if provided
     target_spacing = None
-    if hasattr(config, 'data') and hasattr(config.data, 'target_spacing'):
+    if hasattr(config, "data") and hasattr(config.data, "target_spacing"):
         target_spacing = config.data.target_spacing
 
     median_shape = None
-    if hasattr(config, 'data') and hasattr(config.data, 'median_shape'):
+    if hasattr(config, "data") and hasattr(config.data, "median_shape"):
         median_shape = config.data.median_shape
 
     # Collect manual overrides (values explicitly set in config)
     manual_overrides = {}
-    if hasattr(config, 'data'):
-        if hasattr(config.data, 'batch_size') and config.data.batch_size is not None:
-            manual_overrides['batch_size'] = config.data.batch_size
-        if hasattr(config.data, 'num_workers') and config.data.num_workers is not None:
-            manual_overrides['num_workers'] = config.data.num_workers
-        if hasattr(config.data, 'patch_size') and config.data.patch_size is not None:
-            manual_overrides['patch_size'] = config.data.patch_size
-
-    if hasattr(config, 'training'):
-        if hasattr(config.training, 'precision') and config.training.precision is not None:
-            manual_overrides['precision'] = config.training.precision
-        if hasattr(config.training, 'accumulate_grad_batches') and config.training.accumulate_grad_batches is not None:
-            manual_overrides['accumulate_grad_batches'] = config.training.accumulate_grad_batches
-
-    if hasattr(config, 'optimizer'):
-        if hasattr(config.optimizer, 'lr') and config.optimizer.lr is not None:
-            manual_overrides['lr'] = config.optimizer.lr
+    if hasattr(config, "data"):
+        if hasattr(config.data, "batch_size") and config.data.batch_size is not None:
+            manual_overrides["batch_size"] = config.data.batch_size
+        if hasattr(config.data, "num_workers") and config.data.num_workers is not None:
+            manual_overrides["num_workers"] = config.data.num_workers
+        if hasattr(config.data, "patch_size") and config.data.patch_size is not None:
+            manual_overrides["patch_size"] = config.data.patch_size
+
+    if hasattr(config, "training"):
+        if hasattr(config.training, "precision") and config.training.precision is not None:
+            manual_overrides["precision"] = config.training.precision
+        if (
+            hasattr(config.training, "accumulate_grad_batches")
+            and config.training.accumulate_grad_batches is not None
+        ):
+            manual_overrides["accumulate_grad_batches"] = config.training.accumulate_grad_batches
+
+    if hasattr(config, "optimizer"):
+        if hasattr(config.optimizer, "lr") and config.optimizer.lr is not None:
+            manual_overrides["lr"] = config.optimizer.lr
 
     # Create planner
     planner = AutoConfigPlanner(
@@ -398,9 +407,11 @@ def auto_plan_config(
     )
 
     # Plan
-    use_mixed_precision = not (hasattr(config, 'training') and
-                               hasattr(config.training, 'precision') and
-                               config.training.precision == "32")
+    use_mixed_precision = not (
+        hasattr(config, "training")
+        and hasattr(config.training, "precision")
+        and config.training.precision == "32"
+    )
 
     result = planner.plan(
         in_channels=in_channels,
@@ -412,19 +423,19 @@ def auto_plan_config(
     # Update config with planned values (if not manually overridden)
     OmegaConf.set_struct(config, False)  # Allow adding new fields
 
-    if 'batch_size' not in manual_overrides:
+    if "batch_size" not in manual_overrides:
         config.data.batch_size = result.batch_size
-    if 'num_workers' not in manual_overrides:
+    if "num_workers" not in manual_overrides:
         config.data.num_workers = result.num_workers
-    if 'patch_size' not in manual_overrides:
+    if "patch_size" not in manual_overrides:
         config.data.patch_size = result.patch_size
 
-    if 'precision' not in manual_overrides:
+    if "precision" not in manual_overrides:
         config.training.precision = result.precision
-    if 'accumulate_grad_batches' not in manual_overrides:
+    if "accumulate_grad_batches" not in manual_overrides:
         config.training.accumulate_grad_batches = result.accumulate_grad_batches
 
-    if 'lr' not in manual_overrides:
+    if "lr" not in manual_overrides:
         config.optimizer.lr = result.lr
 
     OmegaConf.set_struct(config, True)  # Re-enable struct mode
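
A minimal sketch of the precedence rule these hunks implement: keys the user set explicitly are collected into `manual_overrides` earlier in the function and skipped when the planned values are written back. OmegaConf usage mirrors the diff; the planned numbers are made up for illustration.

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"data": {"batch_size": 4, "num_workers": None}})
    manual_overrides = {"batch_size": cfg.data.batch_size}  # recorded: set explicitly by the user

    planned = {"batch_size": 2, "num_workers": 8}  # stand-in for the planner's result
    if "batch_size" not in manual_overrides:
        cfg.data.batch_size = planned["batch_size"]  # skipped: the user value (4) wins
    if "num_workers" not in manual_overrides:
        cfg.data.num_workers = planned["num_workers"]  # written: the planned value (8)

    print(cfg.data.batch_size, cfg.data.num_workers)  # 4 8
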
@@ -436,14 +447,14 @@ def auto_plan_config(
     return config
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Test auto planning
     from connectomics.config import Config
     from omegaconf import OmegaConf
 
     # Create test config
     cfg = OmegaConf.structured(Config())
-    cfg.model.architecture = 'mednext'
+    cfg.model.architecture = "mednext"
     cfg.model.deep_supervision = True
 
     # Auto plan
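
The `__main__` block is cut off at this point in the diff; a plausible continuation, based only on what the diff shows (`auto_plan_config` takes a config and returns the updated one), would be:

    cfg = auto_plan_config(cfg)  # further parameters of auto_plan_config, if any, are not visible in this diff

    # Fields filled in by the planner unless explicitly overridden, per the hunks above:
    print(cfg.data.patch_size, cfg.data.batch_size, cfg.data.num_workers)
    print(cfg.training.precision, cfg.training.accumulate_grad_batches, cfg.optimizer.lr)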