1 change: 1 addition & 0 deletions deepem/train/option.py
@@ -62,6 +62,7 @@ def initialize(self):
         self.parser.add_argument('--chkpt_num', type=int, default=0)
         self.parser.add_argument('--no_eval', action='store_true')
         self.parser.add_argument('--pretrain', default=None)
+        self.parser.add_argument('--grad_accum_steps', type=int, default=1)

         # WandB logging
         self.parser.add_argument('--wandb_pad_output', action='store_true')
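Note on the new option: --grad_accum_steps controls how many forward/backward passes are accumulated before each optimizer step in the training loop changed below, so the effective batch size becomes the per-iteration batch multiplied by this value. A minimal arithmetic sketch, using illustrative numbers and variable names that are not from the repo:

# Illustrative only: relationship between --grad_accum_steps and effective batch size.
per_iter_batch = 8                                     # samples per forward/backward pass (assumed)
grad_accum_steps = 4                                   # value passed as --grad_accum_steps
effective_batch = per_iter_batch * grad_accum_steps    # 32 samples per optimizer step
# Each loss is divided by grad_accum_steps before backward(), so the summed
# gradients approximate the gradient of one batch of effective_batch samples.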
49 changes: 40 additions & 9 deletions deepem/train/run.py
@@ -1,6 +1,7 @@
 import os
 import time

+from collections import defaultdict
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -25,6 +26,8 @@ def cleanup_distributed():


 def train(opt):
+    assert not (opt.size_average and opt.grad_accum_steps > 1), \
+        "size_average and grad_accum_steps > 1 are not supported"
     # Model
     if opt.parallel == "DDP":
         # Make sure samewise finished syncing files
@@ -66,8 +69,11 @@ def train(opt):
     # Timer
     t0 = time.time()

-    for i in range(opt.chkpt_num, opt.max_iter):
+    grad_accum_steps = opt.grad_accum_steps
+    accum_losses = defaultdict(float)
+    accum_nmasks = defaultdict(float)
+
+    for i in range(opt.chkpt_num, opt.max_iter):
         # Load training samples.
         sample = train_loader()

@@ -81,23 +87,48 @@
                 losses, nmasks, preds = forward(model, sample, opt)
                 total_loss = sum([w*losses[k] for k, w in opt.loss_weight.items()])
             # Backward passes under autocast are not recommended.
-            scaler.scale(total_loss).backward()
+            scaler.scale(total_loss / grad_accum_steps).backward()
+        else:
+            losses, nmasks, preds = forward(model, sample, opt)
+            total_loss = sum([w * losses[k] for k, w in opt.loss_weight.items()])
+            (total_loss / grad_accum_steps).backward()
+
+        # Accumulate metrics for logging
+        with torch.no_grad():
+            for k, v in losses.items():
+                accum_losses[k] += v
+            for k, v in nmasks.items():
+                accum_nmasks[k] += v
+
+        if ((i - opt.chkpt_num) + 1) % grad_accum_steps != 0:
+            continue
+
+        # --- From here on, code only runs on optimizer step ---
+        if opt.mixed_precision:
             scaler.step(optimizer)
             scaler.update()
-            losses = {k: v.float() for k, v in losses.items()}
-            nmasks = {k: v.float() for k, v in nmasks.items()}
-            preds = {k: v.float() for k, v in preds.items()}
         else:
-            losses, nmasks, preds = forward(model, sample, opt)
-            total_loss = sum([w*losses[k] for k, w in opt.loss_weight.items()])
-            total_loss.backward()
             optimizer.step()

+        # Average accumulated losses
+        avg_losses = {k: v / grad_accum_steps for k, v in accum_losses.items()}
+        avg_nmasks = {k: v / grad_accum_steps for k, v in accum_nmasks.items()}
+
+        if opt.mixed_precision:
+            avg_losses = {k: v.float() for k, v in avg_losses.items()}
+            avg_nmasks = {k: v.float() for k, v in avg_nmasks.items()}
+            preds = {k: v.float() for k, v in preds.items()}
+
+
         # Elapsed time
         elapsed = time.time() - t0

         # Record keeping
-        logger.record('train', losses, nmasks, elapsed=elapsed)
+        logger.record("train", avg_losses, avg_nmasks, elapsed=elapsed)
+
+        # Reset accumulators
+        accum_losses = defaultdict(float)
+        accum_nmasks = defaultdict(float)

         # Log & display averaged stats.
         if (i+1) % opt.avgs_intv == 0 or i < opt.warm_up:
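For reference, the loop change above is the standard gradient-accumulation pattern. Below is a self-contained sketch of that pattern with stand-in names (model, loader, loss_fn, optimizer) rather than the repo's actual classes, and without the mixed-precision and DDP branches:

from collections import defaultdict

import torch


def train_with_accumulation(model, loader, optimizer, loss_fn, max_iter, grad_accum_steps=1):
    # Step the optimizer once every grad_accum_steps iterations.
    accum_losses = defaultdict(float)
    for i in range(max_iter):
        inputs, targets = next(loader)      # loader assumed to be an iterator of (input, target) pairs
        loss = loss_fn(model(inputs), targets)
        # Scale the loss so the accumulated gradient matches that of one large batch.
        (loss / grad_accum_steps).backward()
        with torch.no_grad():
            accum_losses["loss"] += loss
        if (i + 1) % grad_accum_steps != 0:
            continue                        # keep accumulating; no optimizer step yet
        optimizer.step()
        optimizer.zero_grad()
        avg_loss = accum_losses["loss"] / grad_accum_steps
        print(f"iter {i + 1}: loss={avg_loss.item():.4f}")
        accum_losses = defaultdict(float)   # reset for the next accumulation window

The diff keeps the same structure but also divides the loss inside the scaler.scale call for the mixed-precision path and averages the accumulated losses and nmasks over the window before passing them to logger.record.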