from datetime import datetime
warnings.filterwarnings("ignore", category=UserWarning)

+ # Try to import transformers schedulers, fall back to PyTorch schedulers
+ try:
+     from transformers import (
+         get_linear_schedule_with_warmup,
+         get_cosine_schedule_with_warmup,
+         get_cosine_with_hard_restarts_schedule_with_warmup,
+         get_polynomial_decay_schedule_with_warmup
+     )
+     TRANSFORMERS_AVAILABLE = True
+ except ImportError:
+     print("Warning: transformers library not available. Using PyTorch schedulers only.")
+     TRANSFORMERS_AVAILABLE = False
+
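All four helpers imported above share the same core signature and return a `LambdaLR` that is advanced once per optimizer update rather than once per epoch. A minimal, self-contained sketch of that usage pattern (dummy model and step counts, not this script's training loop):

import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)                          # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=100, num_training_steps=1000)
for _ in range(1000):
    opt.step()    # parameter update (loss/backward elided)
    sched.step()  # advance the LR schedule once per optimizer step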
# Add argparse for CLI configuration
parser = argparse.ArgumentParser(description='Train model with MultiMonoDecoders for sequence and geometry prediction')
parser.add_argument('--config', '-c', type=str, default=None,
…
                    help='Save current configuration to file (YAML format)')
parser.add_argument('--lr-warmup-steps', type=int, default=0,
                    help='Number of steps for learning rate warmup (default: 0, no warmup)')
- parser.add_argument('--lr-schedule', type=str, default='plateau', choices=['plateau', 'cosine', 'linear', 'none'],
-                     help='Learning rate schedule: plateau (ReduceLROnPlateau), cosine, linear decay, or none (default: plateau)')
+ parser.add_argument('--lr-warmup-ratio', type=float, default=0.0,
+                     help='Warmup ratio (fraction of total steps). Overrides --lr-warmup-steps if > 0 (default: 0.0)')
+ parser.add_argument('--lr-schedule', type=str, default='plateau',
+                     choices=['plateau', 'cosine', 'linear', 'cosine_restarts', 'polynomial', 'none'],
+                     help='Learning rate schedule (default: plateau)')
parser.add_argument('--lr-min', type=float, default=1e-6,
                    help='Minimum learning rate for cosine/linear schedules (default: 1e-6)')
+ parser.add_argument('--gradient-accumulation-steps', '--grad-accum', type=int, default=1,
+                     help='Number of gradient accumulation steps (default: 1, no accumulation)')
+ parser.add_argument('--num-cycles', type=int, default=3,
+                     help='Number of cycles for cosine_restarts scheduler (default: 3)')
+
+ # Commitment cost scheduling
+ parser.add_argument('--commitment-cost', type=float, default=0.9,
+                     help='Final commitment cost for VQ-VAE (default: 0.9)')
+ parser.add_argument('--use-commitment-scheduling', action='store_true',
+                     help='Enable commitment cost scheduling (warmup from low to final value)')
+ parser.add_argument('--commitment-schedule', type=str, default='cosine',
+                     choices=['cosine', 'linear', 'none'],
+                     help='Commitment cost schedule type (default: cosine)')
+ parser.add_argument('--commitment-warmup-steps', type=int, default=5000,
+                     help='Number of steps to warmup commitment cost (default: 5000)')
+ parser.add_argument('--commitment-start', type=float, default=0.1,
+                     help='Starting commitment cost when using scheduling (default: 0.1)')
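The commitment-cost warmup itself is implemented inside `ft2.mk1_Encoder` and is not part of this diff. As a rough sketch of what the new flags imply, assuming a ramp from `--commitment-start` to `--commitment-cost` over `--commitment-warmup-steps` (the function name and exact curve here are illustrative, not the module's actual API):

import math

def commitment_cost_at(step, start=0.1, final=0.9, warmup_steps=5000, schedule="cosine"):
    """Illustrative ramp from `start` to `final`; not the ft2 implementation."""
    if schedule == "none" or step >= warmup_steps:
        return final
    frac = step / warmup_steps
    if schedule == "linear":
        return start + (final - start) * frac
    # cosine ramp: slow at the ends, faster in the middle
    return start + (final - start) * 0.5 * (1.0 - math.cos(math.pi * frac))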

# Print an overview of the arguments and example command if no arguments provided
if len(sys.argv) == 1:
…
print(f"  Run Name: {args.run_name if args.run_name else 'auto-generated'}")
print(f"  LR Schedule: {args.lr_schedule}")
print(f"  LR Warmup Steps: {args.lr_warmup_steps}")
- if args.lr_schedule in ['cosine', 'linear']:
+ print(f"  LR Warmup Ratio: {args.lr_warmup_ratio}")
+ print(f"  Gradient Accumulation Steps: {args.gradient_accumulation_steps}")
+ if args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial']:
    print(f"  LR Min: {args.lr_min}")
+ if args.lr_schedule == 'cosine_restarts':
+     print(f"  Num Cycles: {args.num_cycles}")
+ print(f"  Commitment Cost: {args.commitment_cost}")
+ print(f"  Use Commitment Scheduling: {args.use_commitment_scheduling}")
+ if args.use_commitment_scheduling:
+     print(f"  Commitment Schedule: {args.commitment_schedule}")
+     print(f"  Commitment Warmup Steps: {args.commitment_warmup_steps}")
+     print(f"  Commitment Start: {args.commitment_start}")

# Save configuration if requested
if args.save_config:
…
            out_channels=args.embedding_dim,
            metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
            num_embeddings=args.num_embeddings,
-             commitment_cost=0.8,
+             commitment_cost=args.commitment_cost,
            edge_dim=1,
            encoder_hidden=hidden_size,
-             EMA=args.EMA,
+             EMA=args.EMA,
            nheads=5,
            dropout_p=0.005,
            reset_codes=False,
            flavor='transformer',
-             fftin=True
+             fftin=True,
+             use_commitment_scheduling=args.use_commitment_scheduling,
+             commitment_warmup_steps=args.commitment_warmup_steps,
+             commitment_schedule=args.commitment_schedule,
+             commitment_start=args.commitment_start
        )
    else:
        encoder = ft2.mk1_Encoder(
…
            out_channels=args.embedding_dim,
            metadata={'edge_types': [('res','contactPoints','res')]},
            num_embeddings=args.num_embeddings,
-             commitment_cost=0.9,
+             commitment_cost=args.commitment_cost,
            edge_dim=1,
            encoder_hidden=hidden_size,
            EMA=args.EMA,
            nheads=8,
            dropout_p=0.01,
            reset_codes=False,
            flavor='transformer',
-             fftin=True
+             fftin=True,
+             use_commitment_scheduling=args.use_commitment_scheduling,
+             commitment_warmup_steps=args.commitment_warmup_steps,
+             commitment_schedule=args.commitment_schedule,
+             commitment_start=args.commitment_start
        )

    if args.hetero_gae:
…
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=args.learning_rate, weight_decay=0.000001)

# Learning rate scheduler setup
- total_steps = len(train_loader) * args.epochs
- warmup_steps = args.lr_warmup_steps
+ total_steps = len(train_loader) * args.epochs // args.gradient_accumulation_steps  # Adjust for gradient accumulation

- if args.lr_schedule == 'plateau':
-     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)
- elif args.lr_schedule == 'cosine':
-     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)
- elif args.lr_schedule == 'linear':
-     # Linear decay from learning_rate to lr_min
-     lambda_lr = lambda step: max(args.lr_min / args.learning_rate, 1.0 - (step - warmup_steps) / (total_steps - warmup_steps))
-     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)
- else:  # 'none'
-     scheduler = None
-
- # Learning rate warmup function
- def get_warmup_lr(current_step, warmup_steps, base_lr):
-     """Linear warmup from 0 to base_lr over warmup_steps"""
-     if warmup_steps == 0 or current_step >= warmup_steps:
-         return base_lr
-     return base_lr * (current_step / warmup_steps)
+ # Calculate warmup steps
+ if args.lr_warmup_ratio > 0:
+     warmup_steps = int(total_steps * args.lr_warmup_ratio)
+     print(f"Using warmup ratio {args.lr_warmup_ratio:.2%}, calculated warmup_steps: {warmup_steps}")
+ else:
+     warmup_steps = args.lr_warmup_steps

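To make the interaction between gradient accumulation and the step counts concrete (hypothetical numbers, not from this run): with 1,000 batches per epoch, 10 epochs, and `--gradient-accumulation-steps 4`, `total_steps` is `1000 * 10 // 4 = 2500` optimizer updates; `--lr-warmup-ratio 0.1` then yields `warmup_steps = 250`, and the effective batch size grows to `batch_size * 4`.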
- def set_lr(optimizer, lr):
-     """Set learning rate for all parameter groups"""
-     for param_group in optimizer.param_groups:
-         param_group['lr'] = lr
+ # Initialize scheduler
+ scheduler = None
+ scheduler_step_mode = 'epoch'  # 'step' or 'epoch'

- print(f"Learning rate schedule: {args.lr_schedule}")
- print(f"Warmup steps: {warmup_steps}")
- print(f"Total training steps: {total_steps}")
- if args.lr_schedule in ['cosine', 'linear']:
-     print(f"Min learning rate: {args.lr_min}")
+ if args.lr_schedule == 'plateau':
+     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
+     scheduler_step_mode = 'epoch'
+
+ elif args.lr_schedule == 'cosine' and TRANSFORMERS_AVAILABLE:
+     scheduler = get_cosine_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=warmup_steps,
+         num_training_steps=total_steps
+     )
+     scheduler_step_mode = 'step'
+
+ elif args.lr_schedule == 'linear' and TRANSFORMERS_AVAILABLE:
+     scheduler = get_linear_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=warmup_steps,
+         num_training_steps=total_steps
+     )
+     scheduler_step_mode = 'step'
+
+ elif args.lr_schedule == 'cosine_restarts' and TRANSFORMERS_AVAILABLE:
+     scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=warmup_steps,
+         num_training_steps=total_steps,
+         num_cycles=args.num_cycles
+     )
+     scheduler_step_mode = 'step'
+
+ elif args.lr_schedule == 'polynomial' and TRANSFORMERS_AVAILABLE:
+     scheduler = get_polynomial_decay_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=warmup_steps,
+         num_training_steps=total_steps,
+         power=2.0
+     )
+     scheduler_step_mode = 'step'
+
+ elif args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial'] and not TRANSFORMERS_AVAILABLE:
+     # Fallback to PyTorch schedulers
+     print(f"Warning: transformers not available, falling back to PyTorch CosineAnnealingLR for {args.lr_schedule}")
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+         optimizer,
+         T_max=total_steps - warmup_steps,
+         eta_min=args.lr_min
+     )
+     scheduler_step_mode = 'step'
+
+ elif args.lr_schedule == 'none':
+     scheduler = None
+     print("No learning rate scheduling (constant LR)")
+ else:
+     print(f"Unknown scheduler: {args.lr_schedule}, using no scheduling")
+
+ print(f"\nScheduler Configuration:")
+ print(f"  Schedule type: {args.lr_schedule}")
+ print(f"  Scheduler step mode: {scheduler_step_mode}")
+ print(f"  Warmup steps: {warmup_steps}")
+ print(f"  Total training steps: {total_steps}")
+ print(f"  Gradient accumulation steps: {args.gradient_accumulation_steps}")
+ print(f"  Effective batch size: {args.batch_size * args.gradient_accumulation_steps}")
+ if args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial']:
+     print(f"  Min learning rate: {args.lr_min}")

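One note on the fallback branch above: the plain `CosineAnnealingLR` replacement has no warmup phase. If warmup is still wanted when `transformers` is missing, a pure-PyTorch option (a sketch, not part of this commit) is to chain `LinearLR` into `CosineAnnealingLR` with `SequentialLR`, reusing the script's `optimizer`, `warmup_steps`, and `total_steps`:

# Sketch only: warmup followed by cosine decay using built-in PyTorch schedulers.
# Assumes warmup_steps > 0; otherwise use CosineAnnealingLR alone.
warmup_sched = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps)
cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_sched, cosine_sched], milestones=[warmup_steps])
scheduler_step_mode = 'step'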
# Function to analyze gradient norms
def analyze_gradient_norms(model, top_k=3):
@@ -496,6 +592,16 @@ def analyze_gradient_norms(model, top_k=3):
    f.write(f'Embedding dimension: {args.embedding_dim}\n')
    f.write(f'Number of embeddings: {args.num_embeddings}\n')
    f.write(f'Loss weights - Edge: {edgeweight}, X: {xweight}, FFT2: {fft2weight}, VQ: {vqweight}\n')
+     f.write(f'LR Schedule: {args.lr_schedule}\n')
+     f.write(f'LR Warmup Steps: {warmup_steps}\n')
+     f.write(f'Gradient Accumulation Steps: {args.gradient_accumulation_steps}\n')
+     f.write(f'Effective Batch Size: {args.batch_size * args.gradient_accumulation_steps}\n')
+     f.write(f'Commitment Cost: {args.commitment_cost}\n')
+     f.write(f'Use Commitment Scheduling: {args.use_commitment_scheduling}\n')
+     if args.use_commitment_scheduling:
+         f.write(f'Commitment Schedule: {args.commitment_schedule}\n')
+         f.write(f'Commitment Warmup Steps: {args.commitment_warmup_steps}\n')
+         f.write(f'Commitment Start: {args.commitment_start}\n')

# Save configuration to TensorBoard
config_text = "\n".join([f"{k}: {v}" for k, v in vars(args).items()])
@@ -522,9 +628,18 @@ def analyze_gradient_norms(model, top_k=3):
best_loss = float('inf')
done_burn = False
after_burn_in = args.epochs - burn_in if burn_in else args.epochs
- global_step = 0  # Track global training steps for warmup
+ global_step = 0  # Track global training steps for warmup and scheduling
+ accumulation_step = 0  # Track steps within gradient accumulation
+
+ print(f"\nTraining Configuration:")
+ print(f"  Total epochs: {args.epochs}")
+ print(f"  Burn-in epochs: {burn_in}")
+ print(f"  After burn-in epochs: {after_burn_in}")
+ print(f"  Gradient accumulation steps: {args.gradient_accumulation_steps}")
+ print(f"  Steps per epoch: {len(train_loader)}")
+ print(f"  Effective steps per epoch: {len(train_loader) // args.gradient_accumulation_steps}")
+ print()

- print(f"Total epochs: {args.epochs}, Burn-in epochs: {burn_in}, After burn-in epochs: {after_burn_in}")
for epoch in range(args.epochs):
    if burn_in and epoch < burn_in:
        print(f"Burn-in epoch {epoch + 1}/{args.epochs}: Adjusting loss weights")
@@ -557,16 +672,9 @@ def analyze_gradient_norms(model, top_k=3):
    total_loss_fft2 = 0
    total_logit_loss = 0

-     for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}"):
+     for batch_idx, data in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}")):
        data = data.to(device)

-         # Learning rate warmup
-         if warmup_steps > 0 and global_step < warmup_steps:
-             warmup_lr = get_warmup_lr(global_step, warmup_steps, args.learning_rate)
-             set_lr(optimizer, warmup_lr)
-
-         optimizer.zero_grad()
-
        # Forward through encoder
        z, vqloss = encoder(data)
        data['res'].x = z
@@ -612,29 +720,36 @@ def analyze_gradient_norms(model, top_k=3):
                fft2loss * fft2weight + angles_loss * angles_weight +
                logitloss * logitweight)

-         # Backward and optimize
+         # Scale loss for gradient accumulation
+         loss = loss / args.gradient_accumulation_steps
+
+         # Backward
        loss.backward()

-         if clip_grad:
-             torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
-             torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
-
-         optimizer.step()
+         accumulation_step += 1

-         # Step the scheduler (for step-based schedulers)
-         if scheduler is not None and args.lr_schedule in ['cosine', 'linear']:
-             if global_step >= warmup_steps:
+         # Optimize after accumulating gradients
+         if accumulation_step % args.gradient_accumulation_steps == 0:
+             if clip_grad:
+                 torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
+                 torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
+
+             optimizer.step()
+             optimizer.zero_grad()
+
+             # Step the scheduler (for step-based schedulers)
+             if scheduler is not None and scheduler_step_mode == 'step':
                scheduler.step()
+
+             global_step += 1

-         global_step += 1
-
-         # Accumulate metrics
-         total_loss_x += xloss.item()
-         total_logit_loss += logitloss.item()
-         total_loss_edge += edgeloss.item()
-         total_loss_fft2 += fft2loss.item()
-         total_angles_loss += angles_loss.item()
-         total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
+         # Accumulate metrics (scale back up since loss was scaled down)
+         total_loss_x += xloss.item() * args.gradient_accumulation_steps
+         total_logit_loss += logitloss.item() * args.gradient_accumulation_steps
+         total_loss_edge += edgeloss.item() * args.gradient_accumulation_steps
+         total_loss_fft2 += fft2loss.item() * args.gradient_accumulation_steps
+         total_angles_loss += angles_loss.item() * args.gradient_accumulation_steps
+         total_vq += (vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)) * args.gradient_accumulation_steps

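Stripped of the model-specific losses, the accumulation pattern above reduces to the following self-contained sketch (dummy model and data; names are illustrative): the loss is divided by the accumulation factor so the summed gradients approximate one large batch, and the optimizer only steps every N micro-batches.

import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum = 4

opt.zero_grad()
for i in range(16):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accum  # scale so summed grads ≈ one big batch
    loss.backward()
    if (i + 1) % accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        opt.zero_grad()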
    # Calculate average losses
    avg_loss_x = total_loss_x / len(train_loader)
@@ -647,16 +762,24 @@ def analyze_gradient_norms(model, top_k=3):
                      avg_loss_fft2 + avg_angles_loss + avg_logit_loss)

    # Update learning rate scheduler (for epoch-based schedulers)
-     if scheduler is not None and args.lr_schedule == 'plateau':
-         scheduler.step(avg_loss_x)
+     if scheduler is not None and scheduler_step_mode == 'epoch':
+         if args.lr_schedule == 'plateau':
+             scheduler.step(avg_loss_x)
+         else:
+             scheduler.step()

    # Print metrics
    print(f"Epoch {epoch + 1}: AA Loss: {avg_loss_x:.4f}, "
          f"Edge Loss: {avg_loss_edge:.4f}, VQ Loss: {avg_loss_vq:.4f}, "
          f"FFT2 Loss: {avg_loss_fft2:.4f}, Angles Loss: {avg_angles_loss:.4f}, "
          f"Logit Loss: {avg_logit_loss:.4f}")
-     print(f"Total Loss: {avg_total_loss:.4f}, "
-           f"LR: {optimizer.param_groups[0]['lr']:.6f}")
+     current_lr = optimizer.param_groups[0]['lr']
+     print(f"Total Loss: {avg_total_loss:.4f}, LR: {current_lr:.6f}")
+
+     # Print commitment cost if using scheduling
+     if args.use_commitment_scheduling and hasattr(encoder, 'vector_quantizer'):
+         current_commitment = encoder.vector_quantizer.get_commitment_cost()
+         print(f"Commitment Cost: {current_commitment:.4f}")

    #if avg_loss_edge > avg_loss_x:
    #    edgeweight *= 1.5
@@ -679,6 +802,11 @@ def analyze_gradient_norms(model, top_k=3):
    writer.add_scalar('Loss/Total', avg_total_loss, epoch)
    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch)

+     # Log commitment cost if using scheduling
+     if args.use_commitment_scheduling and hasattr(encoder, 'vector_quantizer'):
+         current_commitment = encoder.vector_quantizer.get_commitment_cost()
+         writer.add_scalar('Training/Commitment_Cost', current_commitment, epoch)
+
    # Log loss weights
    writer.add_scalar('Weights/Edge', edgeweight, epoch)
    writer.add_scalar('Weights/X', xweight, epoch)