
Commit 589f5f9

add pass to thin from checkpoint while loading torch models (#690)
* add pass to thin from checkpoint while loading torch models
* Update src/sparseml/pytorch/utils/helpers.py
1 parent 054bb61 commit 589f5f9

File tree

3 files changed: +149 -1 lines changed

src/sparseml/pytorch/utils/helpers.py

Lines changed: 101 additions & 0 deletions
@@ -16,6 +16,7 @@
 Utility / helper functions
 """

+import logging
 import random
 import re
 import warnings
@@ -84,9 +85,13 @@
     "get_layer_param",
     "set_deterministic_seeds",
     "torch_distributed_zero_first",
+    "thin_model_from_checkpoint",
 ]


+_LOGGER = logging.getLogger(__name__)
+
+
 ##############################
 #
 # pytorch device helpers
@@ -957,3 +962,99 @@ def torch_distributed_zero_first(local_rank: int):
     yield
     if local_rank == 0:
         torch.distributed.barrier()
+
+
+def thin_model_from_checkpoint(model: Module, state_dict: Dict[str, Any]):
+    """
+    Updates any Linear/Conv/BN layers in the given model to match their
+    respective shapes in the given state dict. Intended for compatibility
+    when loading weights for a model from a checkpoint of the same architecture
+    but with potentially structured thinning applied. Note that this function
+    has no guarantees on accuracy; it will only resize model parameters for
+    loading compatibility. All adjustments are done in place.
+
+    :param model: model to potentially adjust parameter shapes of
+    :param state_dict: state dict to infer parameter shapes from
+    """
+    first_thinned = True
+    for param_name, checkpoint_tens in state_dict.items():
+        if not param_name.endswith(".weight"):
+            continue  # only deal with weight params of modules
+        layer_name = param_name[:-7]
+        layer = get_layer(layer_name, model)
+
+        if not hasattr(layer, "weight") or (
+            layer.weight.shape == checkpoint_tens.shape
+        ):
+            continue  # skip if there is no update to shape
+
+        # quick check that target layer is some flavor of FC/Conv/BN
+        layer_type = layer.__class__.__name__
+        if not (
+            "Linear" in layer_type
+            or "Conv" in layer_type
+            or ("BatchNorm" in layer_type)
+        ):
+            continue
+
+        orig_shape = layer.weight.shape
+        target_shape = checkpoint_tens.shape
+
+        # update weight param + grad
+        if len(target_shape) > 1:
+            layer.weight.data = layer.weight.data[
+                : target_shape[0], : target_shape[1], ...
+            ]
+            if layer.weight.grad is not None:
+                layer.weight.grad = layer.weight.grad[
+                    : target_shape[0], : target_shape[1], ...
+                ]
+        else:
+            layer.weight.data = layer.weight.data[: target_shape[0]]
+            if layer.weight.grad is not None:
+                layer.weight.grad = layer.weight.grad[: target_shape[0]]
+
+        # update bias param + grad
+        if hasattr(layer, "bias") and layer.bias is not None:
+            # target output channels should be the first dim of target shape
+            layer.bias.data = layer.bias.data[: target_shape[0]]
+            if layer.bias.grad is not None:
+                layer.bias.grad = layer.bias.grad[: target_shape[0]]
+
+        # update layer attributes
+        if "BatchNorm" in layer_type:
+            if hasattr(layer, "num_features"):
+                layer.num_features = layer.weight.size(0)
+            # BN running mean and var are not stored as Parameters
+            if hasattr(layer, "running_mean"):
+                layer.running_mean = torch.zeros_like(layer.running_mean)[
+                    : target_shape[0]
+                ]
+            if hasattr(layer, "running_var"):
+                layer.running_var = torch.zeros_like(layer.running_var)[
+                    : target_shape[0]
+                ]
+
+        if "Linear" in layer_type:
+            if hasattr(layer, "out_features"):
+                layer.out_features = layer.weight.shape[0]
+            if hasattr(layer, "in_features"):
+                layer.in_features = layer.weight.shape[1]
+
+        if "Conv" in layer_type:
+            if hasattr(layer, "out_channels"):
+                layer.out_channels = layer.weight.shape[0]
+            if hasattr(layer, "in_channels"):
+                layer.in_channels = layer.weight.shape[1]
+            if hasattr(layer, "groups") and layer.groups > 1:
+                layer.groups = layer.weight.shape[0] // layer.weight.shape[1]
+
+        if first_thinned:
+            _LOGGER.info(
+                "Thinning module layers for compatibility with given state dict:"
+            )
+            first_thinned = False
+        _LOGGER.info(
+            f"Thinned layer {layer_name} from shape {orig_shape} to "
+            f"{layer.weight.shape}"
+        )

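For reference, a minimal usage sketch of the new helper, mirroring the Linear case from the test added in this commit; the module, checkpoint shapes, and variable names are illustrative only:

import torch
from torch.nn import Linear, Sequential

from sparseml.pytorch.utils.helpers import thin_model_from_checkpoint

# model built at its original (un-thinned) width
model = Sequential(Linear(8, 12), Linear(12, 16))

# state dict saved from a structurally thinned copy of the same architecture
thinned_state_dict = {
    "0.weight": torch.randn(7, 8),
    "0.bias": torch.randn(7),
    "1.weight": torch.randn(9, 7),
    "1.bias": torch.randn(9),
}

# a direct strict load would raise a RuntimeError because the shapes no longer
# match; thin_model_from_checkpoint resizes the model's parameters in place first
thin_model_from_checkpoint(model, thinned_state_dict)
model.load_state_dict(thinned_state_dict, strict=True)

# layer attributes are updated along with the weights
assert model[1].in_features == 7 and model[1].out_features == 9

out = model(torch.randn(5, 8))  # forward pass now runs at the thinned width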
src/sparseml/pytorch/utils/model.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,7 @@
 from torch.nn import DataParallel, Module
 from torch.optim.optimizer import Optimizer

+from sparseml.pytorch.utils.helpers import thin_model_from_checkpoint
 from sparseml.utils.helpers import create_parent_dirs
 from sparsezoo import Zoo

@@ -117,6 +118,9 @@ def load_model(
         elif ignore in model_dict and ignore not in current_dict:
             del model_dict[ignore]

+    # safety pass for updating layer param shapes when loading a thinned model
+    thin_model_from_checkpoint(model, model_dict)
+
     model.load_state_dict(model_dict, strict)

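With this hook in place, loading a thinned checkpoint through load_model should no longer require resizing the model by hand. A rough sketch of the intended call follows; the checkpoint path is a placeholder, and the assumption that load_model takes the path and model as its first two arguments and accepts a strict flag is not shown in the hunk above:

from torch.nn import Linear, Sequential

from sparseml.pytorch.utils.model import load_model

# architecture definition at the original width; "thinned_checkpoint.pth" is a
# hypothetical path to a checkpoint saved from a structurally thinned run
model = Sequential(Linear(8, 12), Linear(12, 16))
load_model("thinned_checkpoint.pth", model, strict=True)
# load_model now runs thin_model_from_checkpoint on the loaded state dict
# before model.load_state_dict, so mismatched shapes are resized in place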
tests/sparseml/pytorch/utils/test_helpers.py

Lines changed: 44 additions & 1 deletion
@@ -21,7 +21,7 @@
 import pytest
 import torch
 from torch import Tensor
-from torch.nn import Linear, Module, ReLU, Sequential
+from torch.nn import BatchNorm2d, Conv2d, Linear, Module, ReLU, Sequential
 from torch.optim import SGD
 from torch.utils.data import DataLoader

@@ -43,6 +43,7 @@
     tensors_module_forward,
     tensors_to_device,
     tensors_to_precision,
+    thin_model_from_checkpoint,
 )
 from tests.sparseml.pytorch.helpers import LinearNet

@@ -837,3 +838,45 @@ def test_tensor_sample_cuda(tensor, size, dim, expected_shape):
 def test_mask_difference(old_mask, new_mask, expected_diff):
     diff = mask_difference(old_mask, new_mask)
     assert torch.sum((diff - expected_diff).abs()) < sys.float_info.epsilon
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
+    reason="Skipping pytorch tests",
+)
+@pytest.mark.parametrize(
+    "model,state_dict,test_input",
+    [
+        (
+            Sequential(Conv2d(3, 16, (1, 1)), BatchNorm2d(16), Conv2d(16, 16, (1, 1))),
+            {
+                "0.weight": torch.randn(8, 3, 1, 1),
+                "0.bias": torch.randn(8),
+                "1.weight": torch.randn(8),
+                "1.bias": torch.randn(8),
+                "1.running_mean": torch.randn(8),
+                "1.running_var": torch.randn(8),
+                "2.weight": torch.randn(12, 8, 1, 1),
+                "2.bias": torch.randn(12),
+            },
+            torch.randn(2, 3, 16, 16),
+        ),
+        (
+            Sequential(Linear(8, 12), Linear(12, 16)),
+            {
+                "0.weight": torch.randn(7, 8),
+                "0.bias": torch.randn(7),
+                "1.weight": torch.randn(9, 7),
+                "1.bias": torch.randn(9),
+            },
+            torch.randn(5, 8),
+        ),
+    ],
+)
+def test_thin_model_from_checkpoint(model, state_dict, test_input):
+    with pytest.raises(RuntimeError):
+        model.load_state_dict(state_dict)
+
+    thin_model_from_checkpoint(model, state_dict)
+    model.load_state_dict(state_dict, strict=True)
+    assert isinstance(model(test_input), Tensor)
