@@ -166,24 +166,44 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
166166 trim_level = options .trim_level
167167
168168 ticks = training_data .num_rows / batch_size # iterations per epoch
169- total_iterations = ticks * num_epochs
170- scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
171-
172- # optimizer = optim.Adam(model.parameters(), lr=0.0001)
173- log = []
169+
170+ # Calculation of total iterations in non-rolling vs rolling training
171+ # ticks = num_rows/batch_size (total number of iterations per epoch)
172+ # Non-Rolling Training:
173+ # Total Iteration = num_epochs * ticks
174+ # Rolling Training:
175+ # irl = Initial_rolling_length (We are using 2)
176+ # If num_epochs <= max_rolling_length:
177+ # Total Iterations = sum(range(irl, irl + num_epochs))
178+ # If num_epochs > max_rolling_length:
179+ # Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length)*ticks
174180 if options .rolling :
175181 rolling_length = 2
176182 max_rolling_length = int (ticks )
177- if max_rolling_length > options .max_rolling_length :
178- max_rolling_length = options .max_rolling_length
183+ if max_rolling_length > options .max_rolling_length + rolling_length :
184+ max_rolling_length = options .max_rolling_length + rolling_length
179185 bag_count = 100
180186 hidden_bag_size = batch_size * bag_count
187+ if num_epochs + rolling_length < max_rolling_length :
188+ max_rolling_length = num_epochs + rolling_length
189+ total_iterations = sum (range (rolling_length , max_rolling_length ))
190+ if num_epochs + rolling_length > max_rolling_length :
191+ epochs_remaining = num_epochs + rolling_length - max_rolling_length
192+ total_iterations += epochs_remaining * training_data .num_rows / batch_size
193+ ticks = total_iterations / num_epochs
194+ else :
195+ total_iterations = ticks * num_epochs
196+
197+ scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
198+
199+ # optimizer = optim.Adam(model.parameters(), lr=0.0001)
200+ log = []
181201
182202 for epoch in range (num_epochs ):
183203 self .train ()
184204 if options .rolling :
185205 rolling_length += 1
186- if rolling_length < max_rolling_length :
206+ if rolling_length <= max_rolling_length :
187207 self .init_hidden_bag (hidden_bag_size , device )
188208 for i_batch , (audio , labels ) in enumerate (training_data .get_data_loader (batch_size )):
189209 if not self .batch_first :
@@ -197,7 +217,7 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
197217 # Also, we need to clear out the hidden state,
198218 # detaching it from its history on the last instance.
199219 if options .rolling :
200- if rolling_length < max_rolling_length :
220+ if rolling_length <= max_rolling_length :
201221 if (i_batch + 1 ) % rolling_length == 0 :
202222 self .init_hidden ()
203223 break
0 commit comments