This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit cae2a1e

[cherry-pick] fix propagation of mask creator objects in PT pruning modifier (#220) (#221)
* fix propagation of mask creator objects in PT pruning modifier (#220)
* fix propagation of mask creator objects in PT pruning modifier
* quality - need type() to handle inheritance
* update version to 0.3.1 for hotfix
1 parent ba9fed0 commit cae2a1e
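
For context on the fix: the modifier stores the mask creator it was configured with, but before this change it did not forward that object when building its per-parameter mask container, so the container fell back to its default (unstructured) creator. The sketch below is a minimal, hypothetical illustration of that propagation pattern; the class names and constructor signatures are simplified stand-ins, not the actual SparseML API.

# Minimal, hypothetical sketch of the propagation bug fixed in this commit.
# UnstructuredMaskCreatorSketch / BlockMaskCreatorSketch / MaskContainerSketch
# are illustrative stand-ins, not SparseML classes.


class UnstructuredMaskCreatorSketch:
    """Default creator: prunes individual weights."""


class BlockMaskCreatorSketch:
    """Alternative creator: prunes weights in fixed-size blocks."""


class MaskContainerSketch:
    def __init__(self, param_names, mask_creator=None):
        self.param_names = param_names
        # Falls back to the default when no creator is forwarded.
        self.mask_creator = mask_creator or UnstructuredMaskCreatorSketch()


class PruningModifierSketch:
    def __init__(self, mask_creator):
        self._mask_creator = mask_creator

    def initialize_buggy(self, param_names):
        # Bug: the configured creator is never passed along, so a
        # block/structured setting is silently ignored downstream.
        return MaskContainerSketch(param_names)

    def initialize_fixed(self, param_names):
        # Fix: forward the stored creator, mirroring the added
        # `mask_creator=self._mask_creator` argument in the diff below.
        return MaskContainerSketch(param_names, mask_creator=self._mask_creator)


modifier = PruningModifierSketch(mask_creator=BlockMaskCreatorSketch())
assert isinstance(
    modifier.initialize_buggy(["weight"]).mask_creator, UnstructuredMaskCreatorSketch
)
assert isinstance(
    modifier.initialize_fixed(["weight"]).mask_creator, BlockMaskCreatorSketch
)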

File tree

2 files changed: +6 -0 lines changed


src/sparseml/pytorch/optim/modifier_pruning.py

Lines changed: 1 addition & 0 deletions
@@ -608,6 +608,7 @@ def initialize(self, module: Module, optimizer: Optimizer):
             layers,
             param_names,
             layer_names=layer_names,
+            mask_creator=self._mask_creator,
             global_sparsity=self._global_sparsity,
         )
 
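
The one-line change above means that when a modifier is configured with a non-default mask type (for example, block pruning), the resolved creator now reaches the object that actually builds the sparsity masks. The test below exercises this via a load_mask_creator helper that maps the modifier's mask_type string to a creator instance; the following is a hedged, simplified sketch of that kind of string-to-object mapping, not SparseML's implementation (the names and the block shape are assumptions).

# Hypothetical sketch of a mask_type -> mask creator lookup, loosely
# mirroring the role of the load_mask_creator helper imported in the
# test below. Not SparseML's implementation; names are illustrative.


class UnstructuredMaskCreatorSketch:
    """Prunes individual weights."""


class BlockMaskCreatorSketch:
    """Prunes weights in fixed-size blocks."""

    def __init__(self, block_shape):
        self.block_shape = block_shape


def load_mask_creator_sketch(mask_type):
    # Map a recipe-level string to the object that builds sparsity masks.
    if mask_type == "unstructured":
        return UnstructuredMaskCreatorSketch()
    if mask_type == "block":
        # The block shape here is an arbitrary example value.
        return BlockMaskCreatorSketch(block_shape=[1, 4])
    raise ValueError(f"unknown mask_type: {mask_type}")


# A modifier resolves the string once at construction time; the commit's
# fix ensures the resulting object is then forwarded to the mask container
# instead of being dropped in favor of the default.
creator = load_mask_creator_sketch("block")
print(type(creator).__name__)  # BlockMaskCreatorSketch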

tests/sparseml/pytorch/optim/test_modifier_pruning.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,7 @@
     GlobalMagnitudePruningModifier,
     GMPruningModifier,
     MagnitudePruningModifier,
+    load_mask_creator,
 )
 from tests.sparseml.pytorch.helpers import LinearNet
 from tests.sparseml.pytorch.optim.test_modifier import (
@@ -262,6 +263,10 @@ def test_lifecycle(
         optimizer = optim_lambda(model)
         self.initialize_helper(modifier, model, optimizer)
         assert modifier.applied_sparsity is None
+        assert type(load_mask_creator(modifier._mask_type)) == type(  # noqa: E721
+            modifier._mask_creator
+        )
+        assert modifier._mask_creator == modifier._module_masks._mask_creator
 
         # check sparsity is not set before
         for epoch in range(int(modifier.start_epoch)):
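
The added assertions compare creator classes with type(...) == type(...) rather than isinstance; as the commit message notes, type() is needed to handle inheritance, since isinstance would also accept a subclass of the expected creator and could hide a mismatch (the # noqa: E721 silences flake8's warning about exact-type comparisons). A small, self-contained illustration of that distinction, using hypothetical classes:

# Why the test uses type(a) == type(b) instead of isinstance: with an
# inheritance relationship between creators, isinstance accepts subclasses
# and so cannot detect that a different concrete class was produced.
# BaseCreatorSketch / DerivedCreatorSketch are hypothetical examples.


class BaseCreatorSketch:
    pass


class DerivedCreatorSketch(BaseCreatorSketch):
    pass


expected = BaseCreatorSketch()
actual = DerivedCreatorSketch()

# isinstance passes even though the concrete classes differ.
assert isinstance(actual, type(expected))

# Exact-type comparison catches the mismatch; flake8 flags this style as
# E721, hence the `# noqa: E721` in the test above.
assert type(actual) != type(expected)  # noqa: E721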

0 commit comments
