Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 30b7ec0

Browse files
kylesayrsbfineran
authored andcommitted
Fixed bug with extra batch being run, unwrapped combined operators for clarity (#890)
1 parent b93040c commit 30b7ec0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/sparseml/pytorch/utils/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,9 @@ def run(
693693
epoch_timer = time.time()
694694

695695
for batch, data in data_iter:
696+
if 0 < max_steps and batch >= max_steps:
697+
break
698+
696699
step_timer = time.time()
697700
batch_size = self._run_funcs.batch_size(data) # type: int
698701

@@ -751,9 +754,6 @@ def run(
751754
if results is not None:
752755
results.append(batch_results, batch_size)
753756

754-
if 0 < max_steps <= batch:
755-
break
756-
757757
should_log = self._loggers and self._log_summary and results
758758
log_step = counter # log under the counter step for the summaries
759759

0 commit comments

Comments
 (0)