Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit c838611

Browse files
authored
Utility functions for stage epochs (#1424) (#1433)
* Utility functions to get bound epochs per stage * Handle start_epoch -1 case * Post-process dict to remove -1's as possible * Additional comments * Outsource logic to helper function * Iteratable bug fix * nit - empty space
1 parent bb73dd7 commit c838611

File tree

1 file changed

+66
-35
lines changed

1 file changed

+66
-35
lines changed

src/sparseml/optim/manager.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from collections import OrderedDict
2525
from copy import deepcopy
2626
from functools import cmp_to_key
27-
from typing import Any, Dict, Generator, List, Optional, Union
27+
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
2828

2929
from sparseml.optim.modifier import BaseModifier, BaseObject, ModifierProp
3030
from sparseml.sparsification.types import SparsificationTypes
@@ -343,47 +343,15 @@ def min_epochs(self) -> int:
343343
"""
344344
:return: the minimum epochs required by any of the modifiers under the manager
345345
"""
346-
vals = []
347-
vals.extend(
348-
[
349-
math.floor(mod.start_epoch)
350-
for mod in self.iter_modifiers()
351-
if mod.start_epoch > -1
352-
]
353-
)
354-
vals.extend(
355-
[
356-
math.floor(mod.end_epoch)
357-
for mod in self.iter_modifiers()
358-
if mod.end_epoch > -1
359-
]
360-
)
361-
362-
return min(vals) if len(vals) > 0 else -1
346+
return _min_modifier_epoch(self.iter_modifiers())
363347

364348
@ModifierProp(serializable=False)
365349
def max_epochs(self) -> int:
366350
"""
367351
:return: the maximum number of epochs required by any of the modifiers
368352
under the manager
369353
"""
370-
vals = []
371-
vals.extend(
372-
[
373-
math.ceil(mod.start_epoch)
374-
for mod in self.iter_modifiers()
375-
if mod.start_epoch > -1
376-
]
377-
)
378-
vals.extend(
379-
[
380-
math.ceil(mod.end_epoch)
381-
for mod in self.iter_modifiers()
382-
if mod.end_epoch > -1
383-
]
384-
)
385-
386-
return max(vals) if len(vals) > 0 else -1
354+
return _max_modifier_epoch(self.iter_modifiers())
387355

388356
def save(self, file_path: str, include_metadata: bool = True):
389357
"""
@@ -561,6 +529,44 @@ def qat_active(self, epoch: float) -> bool:
561529
else False
562530
)
563531

532+
def get_start_end_epochs(self) -> Dict[str, Tuple[float, float]]:
533+
"""
534+
Return an OrderedDict mapping each stage to its min and max epoch. If not a
535+
staged manager, map 'all' to the the min and max epochs
536+
"""
537+
if isinstance(self.modifiers, List):
538+
return OrderedDict({"all": (self.min_epochs, self.max_epochs)})
539+
else:
540+
stage_max_min = OrderedDict()
541+
for stage, mod_list in self.modifiers.items():
542+
epoch_min = _min_modifier_epoch(mod_list)
543+
epoch_max = _max_modifier_epoch(mod_list)
544+
stage_max_min[stage] = (epoch_min, epoch_max)
545+
546+
# post-process to replace -1's with their real values
547+
epochs_list = list(stage_max_min.values())
548+
for i, (stage, epochs) in enumerate(stage_max_min.items()):
549+
# replace start epochs that are -1 with the last epoch of the previous
550+
# stage, or 0 if it's the first stage
551+
if epochs[0] == -1:
552+
stage_max_min[stage][0] = epochs_list[i - 1][1] if i > 0 else 0
553+
# replace end epochs that are -1 with the next stage's start epoch,
554+
# unless it's the last stage
555+
if epochs[1] == -1 and i < len(epochs_list) - 1:
556+
stage_max_min[stage][1] = epochs_list[i + 1][0]
557+
558+
return stage_max_min
559+
560+
def get_last_start_epoch(self) -> float:
561+
"""
562+
Return the start epoch of the last stage in the recipe. Useful for applying
563+
recipes at the correct epoch in a staged run
564+
"""
565+
stage_max_min = self.get_start_end_epochs()
566+
last_stage_epochs = stage_max_min[next(reversed(stage_max_min))]
567+
last_start_epoch = last_stage_epochs[0]
568+
return last_start_epoch if last_start_epoch > -1 else 0
569+
564570
def _info_log_metadata(self):
565571
metadata_str = json.dumps(self._metadata, indent=1)
566572
_LOGGER.debug(f"Created recipe manager with metadata: {metadata_str}")
@@ -586,3 +592,28 @@ def _nested_dict_to_lines(
586592
# reached maximum nesting level.
587593
yaml_str_lines.append(indentation * nesting_depth + f"{key}: {value}")
588594
return yaml_str_lines
595+
596+
597+
def _min_modifier_epoch(modifiers: Iterable[BaseModifier]) -> float:
598+
"""
599+
:return: the minimum epochs required by any of the modifiers provided
600+
"""
601+
vals = [math.floor(mod.start_epoch) for mod in modifiers if mod.start_epoch > -1]
602+
603+
return min(vals) if len(vals) > 0 else -1
604+
605+
606+
def _max_modifier_epoch(modifiers: Iterable[BaseModifier]) -> float:
607+
"""
608+
:return: the maximum number of epochs required by any of the modifiers provided
609+
"""
610+
# save modifiers as list so it can iterated over multiple times
611+
modifiers = [mod for mod in modifiers]
612+
613+
vals = []
614+
vals.extend(
615+
[math.ceil(mod.start_epoch) for mod in modifiers if mod.start_epoch > -1]
616+
)
617+
vals.extend([math.ceil(mod.end_epoch) for mod in modifiers if mod.end_epoch > -1])
618+
619+
return max(vals) if len(vals) > 0 else -1

0 commit comments

Comments
 (0)