Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit e0524d6

Browse files
committed
pytorch manager - group modifiers by stage in apply (#873)
* pytorch manager - group modifiers by stage in apply * update docstring * wrap modifier in iterable if necessary
1 parent 3f6d851 commit e0524d6

File tree

1 file changed

+76
-9
lines changed

1 file changed

+76
-9
lines changed

src/sparseml/pytorch/optim/manager.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
"""
2020

2121
import logging
22-
from typing import Any, Dict, List, Optional, Union
22+
import math
23+
from typing import Any, Dict, Iterable, List, Optional, Union
2324

2425
import torch
2526
from torch import Tensor
@@ -373,6 +374,45 @@ def load_state_dict(self, state_dict: Dict[str, Dict], strict: bool = True):
373374

374375
modifiers_index[key].load_state_dict(val)
375376

377+
def apply(
378+
self,
379+
module: Module,
380+
epoch: float = math.inf,
381+
loggers: Optional[LoggerManager] = None,
382+
finalize: bool = True,
383+
**kwargs,
384+
):
385+
"""
386+
Applies the lifecycle of each stage in the manager/recipe
387+
by calling into initialize and finalize for each modifier for each stage
388+
389+
:param module: the PyTorch model/module to modify
390+
:param epoch: the epoch to apply the modifier at, defaults to math.inf (end)
391+
:param loggers: Optional logger manager to log the modification process to
392+
:param finalize: True to invoke finalize after initialize, False otherwise.
393+
If training after one shot, set finalize=False to keep modifiers applied.
394+
:param kwargs: Optional kwargs to support specific arguments
395+
for individual modifiers (passed to initialize and finalize).
396+
"""
397+
if not self.initialized:
398+
super().initialize(module, epoch, loggers, **kwargs)
399+
self._initialize_epoch = epoch
400+
401+
modifier_lists = (
402+
self._modifiers
403+
if isinstance(self._modifiers, List)
404+
else list(self._modifiers.values())
405+
)
406+
407+
for modifier_list in modifier_lists:
408+
409+
self._initialize_modifiers(
410+
modifier_list, module, epoch, loggers=loggers, **kwargs
411+
)
412+
413+
if finalize:
414+
self._finalize_modifiers(modifier_list, module, **kwargs)
415+
376416
def apply_structure(
377417
self,
378418
module: Module,
@@ -422,12 +462,9 @@ def initialize(
422462
super().initialize(module, epoch, loggers, **kwargs)
423463
self._initialize_epoch = epoch
424464

425-
for mod in self.iter_modifiers():
426-
if mod.initialized:
427-
# check in case modifier was initialized from apply_structure
428-
continue
429-
430-
mod.initialize(module, epoch, loggers, **kwargs)
465+
self._initialize_modifiers(
466+
self.iter_modifiers(), module, epoch, loggers, **kwargs
467+
)
431468

432469
def initialize_loggers(self, loggers: Union[None, LoggerManager, List[BaseLogger]]):
433470
"""
@@ -521,8 +558,7 @@ def finalize(
521558
"""
522559
super().finalize(module, reset_loggers, **kwargs)
523560

524-
for mod in self.iter_modifiers():
525-
mod.finalize(module, reset_loggers, **kwargs)
561+
self._finalize_modifiers(self.iter_modifiers(), module, reset_loggers, **kwargs)
526562

527563
def update(
528564
self,
@@ -635,3 +671,34 @@ def optimizer_post_step(
635671
continue
636672

637673
mod.optimizer_post_step(module, optimizer, epoch, steps_per_epoch)
674+
675+
def _initialize_modifiers(
676+
self,
677+
modifiers: Iterable[Modifier],
678+
module: Module,
679+
epoch: float = 0,
680+
loggers: Union[None, LoggerManager, List[BaseLogger]] = None,
681+
**kwargs,
682+
):
683+
if isinstance(modifiers, Modifier):
684+
modifiers = [modifiers]
685+
686+
for mod in modifiers:
687+
if mod.initialized:
688+
# check in case modifier was initialized from apply_structure
689+
continue
690+
691+
mod.initialize(module, epoch, loggers, **kwargs)
692+
693+
def _finalize_modifiers(
694+
self,
695+
modifiers: Iterable[Modifier],
696+
module: Optional[Module] = None,
697+
reset_loggers: bool = True,
698+
**kwargs,
699+
):
700+
if isinstance(modifiers, Modifier):
701+
modifiers = [modifiers]
702+
703+
for mod in modifiers:
704+
mod.finalize(module, reset_loggers, **kwargs)

0 commit comments

Comments
 (0)