
Commit d3ba607

Author: dmoi
Commit message: add commitment and lr scheduling
1 parent 04c7042, commit d3ba607

File tree: 6 files changed, +685 -270 lines changed


foldtree2/learn_monodecoder.py

Lines changed: 196 additions & 68 deletions
@@ -23,6 +23,19 @@
 from datetime import datetime
 warnings.filterwarnings("ignore", category=UserWarning)

+# Try to import transformers schedulers, fall back to PyTorch schedulers
+try:
+    from transformers import (
+        get_linear_schedule_with_warmup,
+        get_cosine_schedule_with_warmup,
+        get_cosine_with_hard_restarts_schedule_with_warmup,
+        get_polynomial_decay_schedule_with_warmup
+    )
+    TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    print("Warning: transformers library not available. Using PyTorch schedulers only.")
+    TRANSFORMERS_AVAILABLE = False
+
 # Add argparse for CLI configuration
 parser = argparse.ArgumentParser(description='Train model with MultiMonoDecoders for sequence and geometry prediction')
 parser.add_argument('--config', '-c', type=str, default=None,
@@ -74,10 +87,30 @@
                     help='Save current configuration to file (YAML format)')
 parser.add_argument('--lr-warmup-steps', type=int, default=0,
                     help='Number of steps for learning rate warmup (default: 0, no warmup)')
-parser.add_argument('--lr-schedule', type=str, default='plateau', choices=['plateau', 'cosine', 'linear', 'none'],
-                    help='Learning rate schedule: plateau (ReduceLROnPlateau), cosine, linear decay, or none (default: plateau)')
+parser.add_argument('--lr-warmup-ratio', type=float, default=0.0,
+                    help='Warmup ratio (fraction of total steps). Overrides --lr-warmup-steps if > 0 (default: 0.0)')
+parser.add_argument('--lr-schedule', type=str, default='plateau',
+                    choices=['plateau', 'cosine', 'linear', 'cosine_restarts', 'polynomial', 'none'],
+                    help='Learning rate schedule (default: plateau)')
 parser.add_argument('--lr-min', type=float, default=1e-6,
                     help='Minimum learning rate for cosine/linear schedules (default: 1e-6)')
+parser.add_argument('--gradient-accumulation-steps', '--grad-accum', type=int, default=1,
+                    help='Number of gradient accumulation steps (default: 1, no accumulation)')
+parser.add_argument('--num-cycles', type=int, default=3,
+                    help='Number of cycles for cosine_restarts scheduler (default: 3)')
+
+# Commitment cost scheduling
+parser.add_argument('--commitment-cost', type=float, default=0.9,
+                    help='Final commitment cost for VQ-VAE (default: 0.9)')
+parser.add_argument('--use-commitment-scheduling', action='store_true',
+                    help='Enable commitment cost scheduling (warmup from low to final value)')
+parser.add_argument('--commitment-schedule', type=str, default='cosine',
+                    choices=['cosine', 'linear', 'none'],
+                    help='Commitment cost schedule type (default: cosine)')
+parser.add_argument('--commitment-warmup-steps', type=int, default=5000,
+                    help='Number of steps to warmup commitment cost (default: 5000)')
+parser.add_argument('--commitment-start', type=float, default=0.1,
+                    help='Starting commitment cost when using scheduling (default: 0.1)')

 # Print an overview of the arguments and example command if no arguments provided
 if len(sys.argv) == 1:
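The encoder-side implementation of the commitment warmup is not part of this diff (only the CLI flags and the kwargs passed to `ft2.mk1_Encoder` appear here). As a rough sketch of what a schedule matching these flags could look like, assuming the quantizer simply interpolates from `--commitment-start` to `--commitment-cost` over `--commitment-warmup-steps` (the helper name `scheduled_commitment_cost` is hypothetical, not from the repository):

```python
import math

def scheduled_commitment_cost(step, start=0.1, final=0.9,
                              warmup_steps=5000, schedule='cosine'):
    """Hypothetical helper mirroring the CLI defaults above: interpolate the
    VQ-VAE commitment cost from `start` to `final` over `warmup_steps`."""
    if schedule == 'none' or warmup_steps <= 0 or step >= warmup_steps:
        return final
    progress = step / warmup_steps
    if schedule == 'linear':
        return start + (final - start) * progress
    # 'cosine': half-cosine ease-in from start to final
    return start + (final - start) * 0.5 * (1.0 - math.cos(math.pi * progress))

# scheduled_commitment_cost(0) -> 0.1, scheduled_commitment_cost(2500) -> 0.5, scheduled_commitment_cost(5000) -> 0.9
```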
@@ -172,8 +205,18 @@
 print(f" Run Name: {args.run_name if args.run_name else 'auto-generated'}")
 print(f" LR Schedule: {args.lr_schedule}")
 print(f" LR Warmup Steps: {args.lr_warmup_steps}")
-if args.lr_schedule in ['cosine', 'linear']:
+print(f" LR Warmup Ratio: {args.lr_warmup_ratio}")
+print(f" Gradient Accumulation Steps: {args.gradient_accumulation_steps}")
+if args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial']:
     print(f" LR Min: {args.lr_min}")
+if args.lr_schedule == 'cosine_restarts':
+    print(f" Num Cycles: {args.num_cycles}")
+print(f" Commitment Cost: {args.commitment_cost}")
+print(f" Use Commitment Scheduling: {args.use_commitment_scheduling}")
+if args.use_commitment_scheduling:
+    print(f" Commitment Schedule: {args.commitment_schedule}")
+    print(f" Commitment Warmup Steps: {args.commitment_warmup_steps}")
+    print(f" Commitment Start: {args.commitment_start}")

 # Save configuration if requested
 if args.save_config:
@@ -264,15 +307,19 @@
         out_channels=args.embedding_dim,
         metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
         num_embeddings=args.num_embeddings,
-        commitment_cost=0.8,
+        commitment_cost=args.commitment_cost,
         edge_dim=1,
         encoder_hidden=hidden_size,
-        EMA= args.EMA,
+        EMA=args.EMA,
         nheads=5,
         dropout_p=0.005,
         reset_codes=False,
         flavor='transformer',
-        fftin=True
+        fftin=True,
+        use_commitment_scheduling=args.use_commitment_scheduling,
+        commitment_warmup_steps=args.commitment_warmup_steps,
+        commitment_schedule=args.commitment_schedule,
+        commitment_start=args.commitment_start
     )
 else:
     encoder = ft2.mk1_Encoder(
@@ -281,15 +328,19 @@
         out_channels=args.embedding_dim,
         metadata={'edge_types': [('res','contactPoints','res')]},
         num_embeddings=args.num_embeddings,
-        commitment_cost=0.9,
+        commitment_cost=args.commitment_cost,
         edge_dim=1,
         encoder_hidden=hidden_size,
         EMA=args.EMA,
         nheads=8,
         dropout_p=0.01,
         reset_codes=False,
         flavor='transformer',
-        fftin=True
+        fftin=True,
+        use_commitment_scheduling=args.use_commitment_scheduling,
+        commitment_warmup_steps=args.commitment_warmup_steps,
+        commitment_schedule=args.commitment_schedule,
+        commitment_start=args.commitment_start
     )

 if args.hetero_gae:
@@ -433,37 +484,82 @@
 optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=args.learning_rate , weight_decay=0.000001)

 # Learning rate scheduler setup
-total_steps = len(train_loader) * args.epochs
-warmup_steps = args.lr_warmup_steps
+total_steps = len(train_loader) * args.epochs // args.gradient_accumulation_steps # Adjust for gradient accumulation

-if args.lr_schedule == 'plateau':
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)
-elif args.lr_schedule == 'cosine':
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)
-elif args.lr_schedule == 'linear':
-    # Linear decay from learning_rate to lr_min
-    lambda_lr = lambda step: max(args.lr_min / args.learning_rate, 1.0 - (step - warmup_steps) / (total_steps - warmup_steps))
-    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)
-else: # 'none'
-    scheduler = None
-
-# Learning rate warmup function
-def get_warmup_lr(current_step, warmup_steps, base_lr):
-    """Linear warmup from 0 to base_lr over warmup_steps"""
-    if warmup_steps == 0 or current_step >= warmup_steps:
-        return base_lr
-    return base_lr * (current_step / warmup_steps)
+# Calculate warmup steps
+if args.lr_warmup_ratio > 0:
+    warmup_steps = int(total_steps * args.lr_warmup_ratio)
+    print(f"Using warmup ratio {args.lr_warmup_ratio:.2%}, calculated warmup_steps: {warmup_steps}")
+else:
+    warmup_steps = args.lr_warmup_steps

-def set_lr(optimizer, lr):
-    """Set learning rate for all parameter groups"""
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
+# Initialize scheduler
+scheduler = None
+scheduler_step_mode = 'epoch' # 'step' or 'epoch'

-print(f"Learning rate schedule: {args.lr_schedule}")
-print(f"Warmup steps: {warmup_steps}")
-print(f"Total training steps: {total_steps}")
-if args.lr_schedule in ['cosine', 'linear']:
-    print(f"Min learning rate: {args.lr_min}")
+if args.lr_schedule == 'plateau':
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
+    scheduler_step_mode = 'epoch'
+
+elif args.lr_schedule == 'cosine' and TRANSFORMERS_AVAILABLE:
+    scheduler = get_cosine_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=total_steps
+    )
+    scheduler_step_mode = 'step'
+
+elif args.lr_schedule == 'linear' and TRANSFORMERS_AVAILABLE:
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=total_steps
+    )
+    scheduler_step_mode = 'step'
+
+elif args.lr_schedule == 'cosine_restarts' and TRANSFORMERS_AVAILABLE:
+    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=total_steps,
+        num_cycles=args.num_cycles
+    )
+    scheduler_step_mode = 'step'
+
+elif args.lr_schedule == 'polynomial' and TRANSFORMERS_AVAILABLE:
+    scheduler = get_polynomial_decay_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=total_steps,
+        power=2.0
+    )
+    scheduler_step_mode = 'step'
+
+elif args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial'] and not TRANSFORMERS_AVAILABLE:
+    # Fallback to PyTorch schedulers
+    print(f"Warning: transformers not available, falling back to PyTorch CosineAnnealingLR for {args.lr_schedule}")
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer,
+        T_max=total_steps - warmup_steps,
+        eta_min=args.lr_min
+    )
+    scheduler_step_mode = 'step'
+
+elif args.lr_schedule == 'none':
+    scheduler = None
+    print("No learning rate scheduling (constant LR)")
+else:
+    print(f"Unknown scheduler: {args.lr_schedule}, using no scheduling")
+
+print(f"\nScheduler Configuration:")
+print(f" Schedule type: {args.lr_schedule}")
+print(f" Scheduler step mode: {scheduler_step_mode}")
+print(f" Warmup steps: {warmup_steps}")
+print(f" Total training steps: {total_steps}")
+print(f" Gradient accumulation steps: {args.gradient_accumulation_steps}")
+print(f" Effective batch size: {args.batch_size * args.gradient_accumulation_steps}")
+if args.lr_schedule in ['cosine', 'linear', 'cosine_restarts', 'polynomial']:
+    print(f" Min learning rate: {args.lr_min}")

 # Function to analyze gradient norms
 def analyze_gradient_norms(model, top_k=3):
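For context, the `get_*_schedule_with_warmup` helpers are standard `transformers` optimization utilities that return a `torch.optim.lr_scheduler.LambdaLR` and are stepped once per optimizer update. A minimal standalone sketch of the cosine-with-warmup case (the toy model, step counts, and warmup ratio below are placeholders, not values from this script):

```python
# Standalone sketch of the step-based scheduling wired up above.
# Assumes `transformers` is installed; model and step counts are illustrative.
import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(8, 8)  # stand-in for encoder + decoder parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-6)

total_steps = 1000                      # len(train_loader) * epochs // grad_accum
warmup_steps = int(total_steps * 0.05)  # e.g. --lr-warmup-ratio 0.05 -> 50 steps

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

for step in range(total_steps):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()        # one optimizer update ...
    optimizer.zero_grad()
    scheduler.step()        # ... then one scheduler step (scheduler_step_mode == 'step')
```

Note that the transformers schedules decay toward zero and do not use `--lr-min`; that floor only applies to the `CosineAnnealingLR` fallback path.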
@@ -496,6 +592,16 @@ def analyze_gradient_norms(model, top_k=3):
     f.write(f'Embedding dimension: {args.embedding_dim}\n')
     f.write(f'Number of embeddings: {args.num_embeddings}\n')
     f.write(f'Loss weights - Edge: {edgeweight}, X: {xweight}, FFT2: {fft2weight}, VQ: {vqweight}\n')
+    f.write(f'LR Schedule: {args.lr_schedule}\n')
+    f.write(f'LR Warmup Steps: {warmup_steps}\n')
+    f.write(f'Gradient Accumulation Steps: {args.gradient_accumulation_steps}\n')
+    f.write(f'Effective Batch Size: {args.batch_size * args.gradient_accumulation_steps}\n')
+    f.write(f'Commitment Cost: {args.commitment_cost}\n')
+    f.write(f'Use Commitment Scheduling: {args.use_commitment_scheduling}\n')
+    if args.use_commitment_scheduling:
+        f.write(f'Commitment Schedule: {args.commitment_schedule}\n')
+        f.write(f'Commitment Warmup Steps: {args.commitment_warmup_steps}\n')
+        f.write(f'Commitment Start: {args.commitment_start}\n')

 # Save configuration to TensorBoard
 config_text = "\n".join([f"{k}: {v}" for k, v in vars(args).items()])
@@ -522,9 +628,18 @@ def analyze_gradient_norms(model, top_k=3):
 best_loss = float('inf')
 done_burn = False
 after_burn_in = args.epochs - burn_in if burn_in else args.epochs
-global_step = 0 # Track global training steps for warmup
+global_step = 0 # Track global training steps for warmup and scheduling
+accumulation_step = 0 # Track steps within gradient accumulation
+
+print(f"\nTraining Configuration:")
+print(f" Total epochs: {args.epochs}")
+print(f" Burn-in epochs: {burn_in}")
+print(f" After burn-in epochs: {after_burn_in}")
+print(f" Gradient accumulation steps: {args.gradient_accumulation_steps}")
+print(f" Steps per epoch: {len(train_loader)}")
+print(f" Effective steps per epoch: {len(train_loader) // args.gradient_accumulation_steps}")
+print()

-print(f"Total epochs: {args.epochs}, Burn-in epochs: {burn_in}, After burn-in epochs: {after_burn_in}")
 for epoch in range(args.epochs):
     if burn_in and epoch < burn_in:
         print(f"Burn-in epoch {epoch+1}/{args.epochs}: Adjusting loss weights")
@@ -557,16 +672,9 @@ def analyze_gradient_norms(model, top_k=3):
     total_loss_fft2 = 0
     total_logit_loss = 0

-    for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}"):
+    for batch_idx, data in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")):
         data = data.to(device)

-        # Learning rate warmup
-        if warmup_steps > 0 and global_step < warmup_steps:
-            warmup_lr = get_warmup_lr(global_step, warmup_steps, args.learning_rate)
-            set_lr(optimizer, warmup_lr)
-
-        optimizer.zero_grad()
-
         # Forward through encoder
         z, vqloss = encoder(data)
         data['res'].x = z
@@ -612,29 +720,36 @@ def analyze_gradient_norms(model, top_k=3):
                 fft2loss * fft2weight + angles_loss * angles_weight +
                 logitloss * logitweight)

-        # Backward and optimize
+        # Scale loss for gradient accumulation
+        loss = loss / args.gradient_accumulation_steps
+
+        # Backward
         loss.backward()

-        if clip_grad:
-            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
-            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
-
-        optimizer.step()
+        accumulation_step += 1

-        # Step the scheduler (for step-based schedulers)
-        if scheduler is not None and args.lr_schedule in ['cosine', 'linear']:
-            if global_step >= warmup_steps:
+        # Optimize after accumulating gradients
+        if accumulation_step % args.gradient_accumulation_steps == 0:
+            if clip_grad:
+                torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
+                torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
+
+            optimizer.step()
+            optimizer.zero_grad()
+
+            # Step the scheduler (for step-based schedulers)
+            if scheduler is not None and scheduler_step_mode == 'step':
                 scheduler.step()
+
+        global_step += 1

-        global_step += 1
-
-        # Accumulate metrics
-        total_loss_x += xloss.item()
-        total_logit_loss += logitloss.item()
-        total_loss_edge += edgeloss.item()
-        total_loss_fft2 += fft2loss.item()
-        total_angles_loss += angles_loss.item()
-        total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
+        # Accumulate metrics (scale back up since loss was scaled down)
+        total_loss_x += xloss.item() * args.gradient_accumulation_steps
+        total_logit_loss += logitloss.item() * args.gradient_accumulation_steps
+        total_loss_edge += edgeloss.item() * args.gradient_accumulation_steps
+        total_loss_fft2 += fft2loss.item() * args.gradient_accumulation_steps
+        total_angles_loss += angles_loss.item() * args.gradient_accumulation_steps
+        total_vq += (vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)) * args.gradient_accumulation_steps

     # Calculate average losses
     avg_loss_x = total_loss_x / len(train_loader)
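The accumulation logic above can be read in isolation as the usual pattern: divide the loss by the accumulation factor so the summed gradients average over the micro-batches, and only clip, step, and zero the gradients every `--gradient-accumulation-steps` batches. A self-contained sketch under those assumptions (toy model and data; it logs the un-scaled loss directly rather than rescaling it afterwards):

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum = 4          # --gradient-accumulation-steps
running_loss = 0.0

micro_batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

optimizer.zero_grad()
for i, (x, y) in enumerate(micro_batches, start=1):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accum).backward()    # scaled so accumulated gradients average over `accum` batches
    running_loss += loss.item()  # log the un-scaled loss
    if i % accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()         # one optimizer (and step-mode scheduler) update per `accum` batches
        optimizer.zero_grad()
```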
@@ -647,16 +762,24 @@ def analyze_gradient_norms(model, top_k=3):
             avg_loss_fft2 + avg_angles_loss + avg_logit_loss)

     # Update learning rate scheduler (for epoch-based schedulers)
-    if scheduler is not None and args.lr_schedule == 'plateau':
-        scheduler.step(avg_loss_x)
+    if scheduler is not None and scheduler_step_mode == 'epoch':
+        if args.lr_schedule == 'plateau':
+            scheduler.step(avg_loss_x)
+        else:
+            scheduler.step()

     # Print metrics
     print(f"Epoch {epoch+1}: AA Loss: {avg_loss_x:.4f}, "
           f"Edge Loss: {avg_loss_edge:.4f}, VQ Loss: {avg_loss_vq:.4f}, "
           f"FFT2 Loss: {avg_loss_fft2:.4f}, Angles Loss: {avg_angles_loss:.4f}, "
           f"Logit Loss: {avg_logit_loss:.4f}")
-    print(f"Total Loss: {avg_total_loss:.4f}, "
-          f"LR: {optimizer.param_groups[0]['lr']:.6f}")
+    current_lr = optimizer.param_groups[0]['lr']
+    print(f"Total Loss: {avg_total_loss:.4f}, LR: {current_lr:.6f}")
+
+    # Print commitment cost if using scheduling
+    if args.use_commitment_scheduling and hasattr(encoder, 'vector_quantizer'):
+        current_commitment = encoder.vector_quantizer.get_commitment_cost()
+        print(f"Commitment Cost: {current_commitment:.4f}")

     #if avg_loss_edge > avg_loss_x:
     #    edgeweight *= 1.5
@@ -679,6 +802,11 @@ def analyze_gradient_norms(model, top_k=3):
     writer.add_scalar('Loss/Total', avg_total_loss, epoch)
     writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch)

+    # Log commitment cost if using scheduling
+    if args.use_commitment_scheduling and hasattr(encoder, 'vector_quantizer'):
+        current_commitment = encoder.vector_quantizer.get_commitment_cost()
+        writer.add_scalar('Training/Commitment_Cost', current_commitment, epoch)
+
     # Log loss weights
     writer.add_scalar('Weights/Edge', edgeweight, epoch)
     writer.add_scalar('Weights/X', xweight, epoch)
