Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d83f526

Browse files
authored
Clear cache before and after the OBS pruning step (#960)
Pytorch caching doesn't seem to work consistently across different versions, so let's make sure we clear it explicitly before and after the pruning step.
1 parent d7c5dc4 commit d83f526

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def check_mask_update(
246246
return # not a one-shot run
247247

248248
_LOGGER.info("Running OBS Pruning")
249+
torch.cuda.empty_cache()
249250
if self._scorer._is_main_proc:
250251
# collect grads for empirical inverse Fisher estimation
251252
self._scorer._enabled_grad_buffering = True
@@ -254,6 +255,7 @@ def check_mask_update(
254255
self._scorer._enabled_grad_buffering = False
255256

256257
super().check_mask_update(module, epoch, steps_per_epoch, **kwargs)
258+
torch.cuda.empty_cache()
257259

258260
def _get_mask_creator(
259261
self, param_names: List[str], params: List[Parameter]

0 commit comments

Comments
 (0)