Commit 514d3c1

fixed total iterations for rolling training

1 parent eb1b95a

1 file changed: +29 -9 lines

examples/pytorch/FastCells/train_classifier.py

@@ -165,24 +165,44 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=None):
         trim_level = options.trim_level
 
         ticks = training_data.num_rows / batch_size  # iterations per epoch
-        total_iterations = ticks * num_epochs
-        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
-
-        # optimizer = optim.Adam(model.parameters(), lr=0.0001)
-        log = []
+
+        # Calculation of total iterations in non-rolling vs rolling training
+        # ticks = num_rows / batch_size (total number of iterations per epoch)
+        # Non-Rolling Training:
+        #     Total Iterations = num_epochs * ticks
+        # Rolling Training:
+        #     irl = initial_rolling_length (we are using 2)
+        #     If num_epochs <= max_rolling_length:
+        #         Total Iterations = sum(range(irl, irl + num_epochs))
+        #     If num_epochs > max_rolling_length:
+        #         Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length) * ticks
         if options.rolling:
             rolling_length = 2
             max_rolling_length = int(ticks)
-            if max_rolling_length > options.max_rolling_length:
-                max_rolling_length = options.max_rolling_length
+            if max_rolling_length > options.max_rolling_length + rolling_length:
+                max_rolling_length = options.max_rolling_length + rolling_length
             bag_count = 100
             hidden_bag_size = batch_size * bag_count
+            if num_epochs + rolling_length < max_rolling_length:
+                max_rolling_length = num_epochs + rolling_length
+            total_iterations = sum(range(rolling_length, max_rolling_length))
+            if num_epochs + rolling_length > max_rolling_length:
+                epochs_remaining = num_epochs + rolling_length - max_rolling_length
+                total_iterations += epochs_remaining * training_data.num_rows / batch_size
+            ticks = total_iterations / num_epochs
+        else:
+            total_iterations = ticks * num_epochs
 
+        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
+
+        # optimizer = optim.Adam(model.parameters(), lr=0.0001)
+        log = []
 
         for epoch in range(num_epochs):
             self.train()
             if options.rolling:
                 rolling_length += 1
-                if rolling_length < max_rolling_length:
+                if rolling_length <= max_rolling_length:
                     self.init_hidden_bag(hidden_bag_size, device)
             for i_batch, (audio, labels) in enumerate(training_data.get_data_loader(batch_size)):
                 if not self.batch_first:
@@ -196,7 +216,7 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=None):
                 # Also, we need to clear out the hidden state,
                 # detaching it from its history on the last instance.
                 if options.rolling:
-                    if rolling_length < max_rolling_length:
+                    if rolling_length <= max_rolling_length:
                         if (i_batch + 1) % rolling_length == 0:
                             self.init_hidden()
                 break
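
To sanity-check the new arithmetic, here is a minimal standalone sketch of the rolling branch above. The helper name rolling_total_iterations and the example values are made up for illustration, and options.max_rolling_length is passed in as a plain argument; the logic itself mirrors the committed change:

# Standalone sketch of the rolling-training iteration count (illustrative only;
# the function name and the example values below are not from the repository).
def rolling_total_iterations(num_rows, batch_size, num_epochs, opt_max_rolling_length):
    ticks = num_rows / batch_size                 # iterations per (non-rolling) epoch
    rolling_length = 2                            # irl, the initial rolling length
    max_rolling_length = int(ticks)
    # Cap by the configured maximum, offset by irl as in the fix.
    if max_rolling_length > opt_max_rolling_length + rolling_length:
        max_rolling_length = opt_max_rolling_length + rolling_length
    # If training is too short to ever reach the cap, shrink the cap instead.
    if num_epochs + rolling_length < max_rolling_length:
        max_rolling_length = num_epochs + rolling_length
    # The rolling epochs run irl, irl + 1, ... iterations each.
    total_iterations = sum(range(rolling_length, max_rolling_length))
    # Any remaining epochs run at the full per-epoch tick count.
    if num_epochs + rolling_length > max_rolling_length:
        epochs_remaining = num_epochs + rolling_length - max_rolling_length
        total_iterations += epochs_remaining * num_rows / batch_size
    return total_iterations

# With num_rows=6400 and batch_size=64, ticks = 100:
print(rolling_total_iterations(6400, 64, 10, 5))   # sum(range(2, 7)) + 5 * 100 = 520.0
print(rolling_total_iterations(6400, 64, 30, 32))  # sum(range(2, 32)) = 495

The first call exercises the num_epochs > max_rolling_length branch of the comment block (520 iterations); the second exercises the num_epochs <= max_rolling_length branch (495 iterations).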
