@@ -165,24 +165,44 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
165165 trim_level = options .trim_level
166166
167167 ticks = training_data .num_rows / batch_size # iterations per epoch
168- total_iterations = ticks * num_epochs
169- scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
170-
171- # optimizer = optim.Adam(model.parameters(), lr=0.0001)
172- log = []
168+
169+ # Calculation of total iterations in non-rolling vs rolling training
170+ # ticks = num_rows/batch_size (total number of iterations per epoch)
171+ # Non-Rolling Training:
172+ # Total Iteration = num_epochs * ticks
173+ # Rolling Training:
174+ # irl = Initial_rolling_length (We are using 2)
175+ # If num_epochs <= max_rolling_length:
176+ # Total Iterations = sum(range(irl, irl + num_epochs))
177+ # If num_epochs > max_rolling_length:
178+ # Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length)*ticks
173179 if options .rolling :
174180 rolling_length = 2
175181 max_rolling_length = int (ticks )
176- if max_rolling_length > options .max_rolling_length :
177- max_rolling_length = options .max_rolling_length
182+ if max_rolling_length > options .max_rolling_length + rolling_length :
183+ max_rolling_length = options .max_rolling_length + rolling_length
178184 bag_count = 100
179185 hidden_bag_size = batch_size * bag_count
186+ if num_epochs + rolling_length < max_rolling_length :
187+ max_rolling_length = num_epochs + rolling_length
188+ total_iterations = sum (range (rolling_length , max_rolling_length ))
189+ if num_epochs + rolling_length > max_rolling_length :
190+ epochs_remaining = num_epochs + rolling_length - max_rolling_length
191+ total_iterations += epochs_remaining * training_data .num_rows / batch_size
192+ ticks = total_iterations / num_epochs
193+ else :
194+ total_iterations = ticks * num_epochs
195+
196+ scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
197+
198+ # optimizer = optim.Adam(model.parameters(), lr=0.0001)
199+ log = []
180200
181201 for epoch in range (num_epochs ):
182202 self .train ()
183203 if options .rolling :
184204 rolling_length += 1
185- if rolling_length < max_rolling_length :
205+ if rolling_length <= max_rolling_length :
186206 self .init_hidden_bag (hidden_bag_size , device )
187207 for i_batch , (audio , labels ) in enumerate (training_data .get_data_loader (batch_size )):
188208 if not self .batch_first :
@@ -196,7 +216,7 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
196216 # Also, we need to clear out the hidden state,
197217 # detaching it from its history on the last instance.
198218 if options .rolling :
199- if rolling_length < max_rolling_length :
219+ if rolling_length <= max_rolling_length :
200220 if (i_batch + 1 ) % rolling_length == 0 :
201221 self .init_hidden ()
202222 break
0 commit comments