|
19 | 19 | """ |
20 | 20 |
|
21 | 21 | import logging |
22 | | -from typing import Any, Dict, List, Optional, Union |
| 22 | +import math |
| 23 | +from typing import Any, Dict, Iterable, List, Optional, Union |
23 | 24 |
|
24 | 25 | import torch |
25 | 26 | from torch import Tensor |
@@ -373,6 +374,45 @@ def load_state_dict(self, state_dict: Dict[str, Dict], strict: bool = True): |
373 | 374 |
|
374 | 375 | modifiers_index[key].load_state_dict(val) |
375 | 376 |
|
| 377 | + def apply( |
| 378 | + self, |
| 379 | + module: Module, |
| 380 | + epoch: float = math.inf, |
| 381 | + loggers: Optional[LoggerManager] = None, |
| 382 | + finalize: bool = True, |
| 383 | + **kwargs, |
| 384 | + ): |
| 385 | + """ |
| 386 | + Applies the lifecycle of each stage in the manager/recipe |
| 387 | + by calling into initialize and finalize for each modifier for each stage |
| 388 | +
|
| 389 | + :param module: the PyTorch model/module to modify |
| 390 | + :param epoch: the epoch to apply the modifier at, defaults to math.inf (end) |
| 391 | + :param loggers: Optional logger manager to log the modification process to |
| 392 | + :param finalize: True to invoke finalize after initialize, False otherwise. |
| 393 | + If training after one shot, set finalize=False to keep modifiers applied. |
| 394 | + :param kwargs: Optional kwargs to support specific arguments |
| 395 | + for individual modifiers (passed to initialize and finalize). |
| 396 | + """ |
| 397 | + if not self.initialized: |
| 398 | + super().initialize(module, epoch, loggers, **kwargs) |
| 399 | + self._initialize_epoch = epoch |
| 400 | + |
| 401 | + modifier_lists = ( |
| 402 | + self._modifiers |
| 403 | + if isinstance(self._modifiers, List) |
| 404 | + else list(self._modifiers.values()) |
| 405 | + ) |
| 406 | + |
| 407 | + for modifier_list in modifier_lists: |
| 408 | + |
| 409 | + self._initialize_modifiers( |
| 410 | + modifier_list, module, epoch, loggers=loggers, **kwargs |
| 411 | + ) |
| 412 | + |
| 413 | + if finalize: |
| 414 | + self._finalize_modifiers(modifier_list, module, **kwargs) |
| 415 | + |
376 | 416 | def apply_structure( |
377 | 417 | self, |
378 | 418 | module: Module, |
@@ -422,12 +462,9 @@ def initialize( |
422 | 462 | super().initialize(module, epoch, loggers, **kwargs) |
423 | 463 | self._initialize_epoch = epoch |
424 | 464 |
|
425 | | - for mod in self.iter_modifiers(): |
426 | | - if mod.initialized: |
427 | | - # check in case modifier was initialized from apply_structure |
428 | | - continue |
429 | | - |
430 | | - mod.initialize(module, epoch, loggers, **kwargs) |
| 465 | + self._initialize_modifiers( |
| 466 | + self.iter_modifiers(), module, epoch, loggers, **kwargs |
| 467 | + ) |
431 | 468 |
|
432 | 469 | def initialize_loggers(self, loggers: Union[None, LoggerManager, List[BaseLogger]]): |
433 | 470 | """ |
@@ -521,8 +558,7 @@ def finalize( |
521 | 558 | """ |
522 | 559 | super().finalize(module, reset_loggers, **kwargs) |
523 | 560 |
|
524 | | - for mod in self.iter_modifiers(): |
525 | | - mod.finalize(module, reset_loggers, **kwargs) |
| 561 | + self._finalize_modifiers(self.iter_modifiers(), module, reset_loggers, **kwargs) |
526 | 562 |
|
527 | 563 | def update( |
528 | 564 | self, |
@@ -635,3 +671,34 @@ def optimizer_post_step( |
635 | 671 | continue |
636 | 672 |
|
637 | 673 | mod.optimizer_post_step(module, optimizer, epoch, steps_per_epoch) |
| 674 | + |
| 675 | + def _initialize_modifiers( |
| 676 | + self, |
| 677 | + modifiers: Iterable[Modifier], |
| 678 | + module: Module, |
| 679 | + epoch: float = 0, |
| 680 | + loggers: Union[None, LoggerManager, List[BaseLogger]] = None, |
| 681 | + **kwargs, |
| 682 | + ): |
| 683 | + if isinstance(modifiers, Modifier): |
| 684 | + modifiers = [modifiers] |
| 685 | + |
| 686 | + for mod in modifiers: |
| 687 | + if mod.initialized: |
| 688 | + # check in case modifier was initialized from apply_structure |
| 689 | + continue |
| 690 | + |
| 691 | + mod.initialize(module, epoch, loggers, **kwargs) |
| 692 | + |
| 693 | + def _finalize_modifiers( |
| 694 | + self, |
| 695 | + modifiers: Iterable[Modifier], |
| 696 | + module: Optional[Module] = None, |
| 697 | + reset_loggers: bool = True, |
| 698 | + **kwargs, |
| 699 | + ): |
| 700 | + if isinstance(modifiers, Modifier): |
| 701 | + modifiers = [modifiers] |
| 702 | + |
| 703 | + for mod in modifiers: |
| 704 | + mod.finalize(module, reset_loggers, **kwargs) |
0 commit comments