@@ -126,8 +126,9 @@ def on_epoch_begin(self, epoch, logs={}):
126126 dynamic_ncols = True ,
127127 unit = self .mode )
128128
129- self .seen = 0
129+ self .num_samples_seen = 0
130130 self .steps_to_update = 0
131+ self .steps_so_far = 0
131132 self .logs = defaultdict (float )
132133
133134 def on_epoch_end (self , epoch , logs = {}):
@@ -154,10 +155,11 @@ def on_batch_end(self, batch, logs={}):
154155 else :
155156 batch_size = 1
156157
157- self .seen += batch_size
158- self .steps_to_update += batch_size
158+ self .num_samples_seen += batch_size
159+ self .steps_to_update += 1
160+ self .steps_so_far += 1
159161
160- if self .seen < self .total_steps :
162+ if self .steps_so_far < self .total_steps :
161163
162164 for metric , value in logs .items ():
163165 self .logs [metric ] += value * batch_size
@@ -167,7 +169,7 @@ def on_batch_end(self, batch, logs={}):
167169 if self .show_epoch_progress and time_diff >= self .update_interval :
168170
169171 # update the epoch progress bar
170- metrics = self .format_metrics (self .logs , self .seen )
172+ metrics = self .format_metrics (self .logs , self .num_samples_seen )
171173 self .epoch_progress_tqdm .desc = metrics
172174 self .epoch_progress_tqdm .update (self .steps_to_update )
173175
0 commit comments