Skip to content

Commit 3dd560c

Browse files
shun-linseanpmorgan
authored andcommitted
fixes bugs when using fit generator (#829)
1 parent ac91cf2 commit 3dd560c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tensorflow_addons/callbacks/tqdm_progress_bar.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ def on_epoch_begin(self, epoch, logs={}):
126126
dynamic_ncols=True,
127127
unit=self.mode)
128128

129-
self.seen = 0
129+
self.num_samples_seen = 0
130130
self.steps_to_update = 0
131+
self.steps_so_far = 0
131132
self.logs = defaultdict(float)
132133

133134
def on_epoch_end(self, epoch, logs={}):
@@ -154,10 +155,11 @@ def on_batch_end(self, batch, logs={}):
154155
else:
155156
batch_size = 1
156157

157-
self.seen += batch_size
158-
self.steps_to_update += batch_size
158+
self.num_samples_seen += batch_size
159+
self.steps_to_update += 1
160+
self.steps_so_far += 1
159161

160-
if self.seen < self.total_steps:
162+
if self.steps_so_far < self.total_steps:
161163

162164
for metric, value in logs.items():
163165
self.logs[metric] += value * batch_size
@@ -167,7 +169,7 @@ def on_batch_end(self, batch, logs={}):
167169
if self.show_epoch_progress and time_diff >= self.update_interval:
168170

169171
# update the epoch progress bar
170-
metrics = self.format_metrics(self.logs, self.seen)
172+
metrics = self.format_metrics(self.logs, self.num_samples_seen)
171173
self.epoch_progress_tqdm.desc = metrics
172174
self.epoch_progress_tqdm.update(self.steps_to_update)
173175

0 commit comments

Comments
 (0)