|
16 | 16 | Utility / helper functions |
17 | 17 | """ |
18 | 18 |
|
| 19 | +import logging |
19 | 20 | import random |
20 | 21 | import re |
21 | 22 | import warnings |
|
84 | 85 | "get_layer_param", |
85 | 86 | "set_deterministic_seeds", |
86 | 87 | "torch_distributed_zero_first", |
| 88 | + "thin_model_from_checkpoint", |
87 | 89 | ] |
88 | 90 |
|
89 | 91 |
|
| 92 | +_LOGGER = logging.getLogger(__name__) |
| 93 | + |
| 94 | + |
90 | 95 | ############################## |
91 | 96 | # |
92 | 97 | # pytorch device helpers |
@@ -957,3 +962,99 @@ def torch_distributed_zero_first(local_rank: int): |
957 | 962 | yield |
958 | 963 | if local_rank == 0: |
959 | 964 | torch.distributed.barrier() |
| 965 | + |
| 966 | + |
| 967 | +def thin_model_from_checkpoint(model: Module, state_dict: Dict[str, Any]): |
| 968 | + """ |
| 969 | + Updates any Linear/Conv/BN layers in the given model to match their |
| 970 | + respective shapes in the given state dict. Intended for compatibility |
| 971 | + when loading weights for a model from a checkpoint of the same architecture |
| 972 | + that may have had structured thinning (pruning) applied. Note that this |
| 973 | + function makes no guarantee of accuracy; it only resizes model parameters |
| 974 | + for loading compatibility. All adjustments are made in place |
| 975 | +
|
| 976 | + :param model: model to potentially adjust parameter shapes of |
| 977 | + :param state_dict: state dict to infer parameter shapes from |
| 978 | + """ |
| 979 | + first_thinned = True |
| 980 | + for param_name, checkpoint_tens in state_dict.items(): |
| 981 | + if not param_name.endswith(".weight"): |
| 982 | + continue # only deal with weight params of modules |
| 983 | + layer_name = param_name[:-7] # strip the trailing ".weight" |
| 984 | + layer = get_layer(layer_name, model) |
| 985 | + |
| 986 | + if not hasattr(layer, "weight") or ( |
| 987 | + layer.weight.shape == checkpoint_tens.shape |
| 988 | + ): |
| 989 | + continue # skip if there is no update to shape |
| 990 | + |
| 991 | + # quick check that target layer is some flavor of FC/Conv/BN |
| 992 | + layer_type = layer.__class__.__name__ |
| 993 | + if not ( |
| 994 | + "Linear" in layer_type |
| 995 | + or "Conv" in layer_type |
| 996 | + or "BatchNorm" in layer_type |
| 997 | + ): |
| 998 | + continue |
| 999 | + |
| 1000 | + orig_shape = layer.weight.shape |
| 1001 | + target_shape = checkpoint_tens.shape |
| 1002 | + |
| 1003 | + # update weight param + grad |
| 1004 | + if len(target_shape) > 1: |
| 1005 | + layer.weight.data = layer.weight.data[ |
| 1006 | + : target_shape[0], : target_shape[1], ... |
| 1007 | + ] |
| 1008 | + if layer.weight.grad is not None: |
| 1009 | + layer.weight.grad = layer.weight.grad[ |
| 1010 | + : target_shape[0], : target_shape[1], ... |
| 1011 | + ] |
| 1012 | + else: |
| 1013 | + layer.weight.data = layer.weight.data[: target_shape[0]] |
| 1014 | + if layer.weight.grad is not None: |
| 1015 | + layer.weight.grad = layer.weight.grad[: target_shape[0]] |
| 1016 | + |
| 1017 | + # update bias param + grad |
| 1018 | + if hasattr(layer, "bias") and layer.bias is not None: |
| 1019 | + # target output channels should be the first dim of target shape |
| 1020 | + layer.bias.data = layer.bias.data[: target_shape[0]] |
| 1021 | + if layer.bias.grad is not None: |
| 1022 | + layer.bias.grad = layer.bias.grad[: target_shape[0]] |
| 1023 | + |
| 1024 | + # update layer attributes |
| 1025 | + if "BatchNorm" in layer_type: |
| 1026 | + if hasattr(layer, "num_features"): |
| 1027 | + layer.num_features = layer.weight.size(0) |
| 1028 | + # BN running mean/var are buffers, not Parameters; only their size matters here |
| 1029 | + if hasattr(layer, "running_mean"): |
| 1030 | + layer.running_mean = torch.zeros_like(layer.running_mean)[ |
| 1031 | + : target_shape[0] |
| 1032 | + ] |
| 1033 | + if hasattr(layer, "running_var"): |
| 1034 | + layer.running_var = torch.zeros_like(layer.running_var)[ |
| 1035 | + : target_shape[0] |
| 1036 | + ] |
| 1037 | + |
| 1038 | + if "Linear" in layer_type: |
| 1039 | + if hasattr(layer, "out_features"): |
| 1040 | + layer.out_features = layer.weight.shape[0] |
| 1041 | + if hasattr(layer, "in_features"): |
| 1042 | + layer.in_features = layer.weight.shape[1] |
| 1043 | + |
| 1044 | + if "Conv" in layer_type: |
| 1045 | + if hasattr(layer, "out_channels"): |
| 1046 | + layer.out_channels = layer.weight.shape[0] |
| 1047 | + if hasattr(layer, "in_channels"): |
| 1048 | + layer.in_channels = layer.weight.shape[1] |
| 1049 | + if hasattr(layer, "groups") and layer.groups > 1: |
| 1050 | + layer.groups = layer.weight.shape[0] // layer.weight.shape[1] # assumes depthwise-style grouping |
| 1051 | + |
| 1052 | + if first_thinned: |
| 1053 | + _LOGGER.info( |
| 1054 | + "Thinning module layers for compatibility with given state dict:" |
| 1055 | + ) |
| 1056 | + first_thinned = False |
| 1057 | + _LOGGER.info( |
| 1058 | + f"Thinned layer {layer_name} from shape {orig_shape} to " |
| 1059 | + f"{layer.weight.shape}" |
| 1060 | + ) |
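
A minimal usage sketch for the new helper (the torchvision architecture, checkpoint path, and import path below are illustrative assumptions, not part of this change): resize the model's Linear/Conv/BN layers to match a structurally thinned checkpoint, then load the weights as usual.

import torch
from torchvision.models import resnet50  # placeholder architecture

# import path assumed; use wherever this utilities module lives in your package
from sparseml.pytorch.utils import thin_model_from_checkpoint

model = resnet50()
# hypothetical checkpoint saved from a thinned model of the same architecture
state_dict = torch.load("thinned_checkpoint.pth", map_location="cpu")

# adjust parameter shapes in place so the checkpoint can be loaded
thin_model_from_checkpoint(model, state_dict)
model.load_state_dict(state_dict)

The helper only resolves shape mismatches; missing or unexpected keys between the model and the checkpoint are not addressed and will still surface from load_state_dict.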