4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Fixed

- Fixed `StochasticWeightAveraging` with infinite epochs (`max_epochs=-1`) ([#21396](https://github.com/Lightning-AI/pytorch-lightning/pull/21396))


## [2.6.0] - 2025-11-28

21 changes: 14 additions & 7 deletions src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -139,6 +139,8 @@ def swa_start(self) -> int:

@property
def swa_end(self) -> int:
if self._max_epochs == -1:
return float("inf") # type: ignore[return-value]
return self._max_epochs - 1 # 0-based

@staticmethod
@@ -163,12 +165,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

assert trainer.max_epochs is not None
if isinstance(self._swa_epoch_start, float):
if trainer.max_epochs == -1:
raise MisconfigurationException(
"SWA with `swa_epoch_start` as a float is not supported when `max_epochs=-1`. "
"Please provide `swa_epoch_start` as an integer."
)
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
if self._model_contains_batch_norm and trainer.max_epochs != -1:
# virtually increase max_epochs to perform batch norm update on latest epoch.
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs += 1
@@ -243,7 +250,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._latest_update_epoch = trainer.current_epoch

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:
if self._max_epochs != -1 and trainer.current_epoch == self.swa_end + 1:
# Transfer weights from average model to pl_module
assert self._average_model is not None
self.transfer_weights(self._average_model, pl_module)
@@ -267,17 +274,17 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
@override
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# the trainer increases the current epoch before this hook is called
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
if self._model_contains_batch_norm and self._max_epochs != -1 and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.fit_loop.max_batches -= 1
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch - 1 == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
assert self._average_model is not None
self.transfer_weights(self._average_model, pl_module)
elif trainer.current_epoch - 1 == self.swa_end or self._max_epochs == -1:
# Last SWA epoch or infinite training. Transfer weights from average model to pl_module
if self._average_model is not None:
self.transfer_weights(self._average_model, pl_module)

@staticmethod
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
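
For orientation, here is a minimal usage sketch (not part of the diff) of the configuration this change targets: infinite training driven by `max_steps`, with the SWA callback given an integer `swa_epoch_start`. It mirrors the new test below; `model` stands in for any `LightningModule`.

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import StochasticWeightAveraging

    # Infinite epochs: the run is stopped by `max_steps`, not `max_epochs`.
    trainer = pl.Trainer(
        max_epochs=-1,
        max_steps=30,
        callbacks=[StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)],
    )
    # A float `swa_epoch_start` is rejected in `on_fit_start` with a
    # MisconfigurationException, since there is no finite max_epochs to scale it by.
    # trainer.fit(model)  # averaged weights are copied back to the module in on_train_end
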
30 changes: 30 additions & 0 deletions tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -387,5 +387,35 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
trainer.fit(model)


def test_swa_with_infinite_epochs_and_batchnorm(tmp_path):
"""Test that SWA works correctly with max_epochs=-1 (infinite training) and BatchNorm."""
model = SwaTestModel(batchnorm=True)
swa_callback = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)

trainer = Trainer(
default_root_dir=tmp_path,
enable_progress_bar=False,
enable_model_summary=False,
max_epochs=-1,
max_steps=30, # Use max_steps as stopping condition
limit_train_batches=5,
limit_val_batches=0,
callbacks=[swa_callback],
logger=False,
)
assert trainer.max_epochs == -1
assert trainer.fit_loop.max_epochs == -1

trainer.fit(model)
assert trainer.current_epoch >= 5
assert trainer.global_step == 30
assert trainer.max_epochs == -1

# Verify SWA was actually applied (update_parameters should have been called)
# SWA starts at epoch 2, so with 6 epochs (0-5), we should have 4 updates (epochs 2, 3, 4, 5)
assert swa_callback.n_averaged is not None
assert swa_callback.n_averaged > 0, "SWA should have updated parameters"


def _backward_patch(trainer: Trainer) -> AbstractContextManager:
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)