@@ -1,16 +1,12 @@
 import os
 import weakref
-from collections.abc import Generator
 from functools import wraps
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
     ModelCompressor,
     SparsityCompressionConfig,
-    delete_offload_parameter,
-    has_offloaded_params,
-    register_offload_parameter,
 )
 from compressed_tensors.config import CompressionFormat
 from loguru import logger
@@ -24,7 +20,7 @@
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
 from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path
 
-__all__ = ["modify_save_pretrained", "untie_word_embeddings"]
+__all__ = ["modify_save_pretrained"]
 
 
 def modify_save_pretrained(model: PreTrainedModel):
@@ -117,119 +113,6 @@ def save_pretrained_wrapper( |
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
 
 
-def untie_word_embeddings(model: PreTrainedModel):
-    """
-    Patches bug where HF transformers will fail to untie weights under specific
-    circumstances (https://github.com/huggingface/transformers/issues/33689).
-
-    This function detects those cases and unties the tensors if applicable.
-
-    :param model: model to fix
-    """
-    try:
-        input_embed = model.get_input_embeddings()
-        output_embed = model.get_output_embeddings()
-    except NotImplementedError as e:
-        logger.warning(
-            f"cannot untie model of type {model.__class__} which doesn't have "
-            f"get_input_embeddings and get_output_embeddings implemented\n{e}"
-        )
-        return
-
-    for module in (input_embed, output_embed):
-        if module is None or not hasattr(module, "weight"):
-            logger.warning(f"Cannot untie {module} which does not have weight param")
-            continue
-
-        # this could be replaced by a `get_offloaded_parameter` util
-        if not has_offloaded_params(module):
-            untied_data = module.weight.data.clone()
-        else:
-            untied_data = module._hf_hook.weights_map["weight"].clone()
-
-        requires_grad = module.weight.requires_grad
-        new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad)
-        delete_offload_parameter(module, "weight")
-        register_offload_parameter(module, "weight", new_parameter)
-
-    if hasattr(model.config, "tie_word_embeddings"):
-        model.config.tie_word_embeddings = False
-
-
-def _get_embeddings_or_warn(
-    model: torch.nn.Module,
-) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
-    if not (
-        hasattr(model, "get_input_embeddings")
-        and hasattr(model, "get_output_embeddings")
-    ):
-        logger.warning(
-            f"{model.__class__} doesn't have attribute get_input_embeddings and"
-            " get_output_embeddings implemented."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-        )
-        return None, None
-
-    try:
-        input_embeddings, output_embeddings = (
-            model.get_input_embeddings(),
-            model.get_output_embeddings(),
-        )
-    except NotImplementedError as e:
-        logger.warning(
-            f"{model.__class__} doesn't have get_input_embeddings and "
-            "get_output_embeddings implemented."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-            f"\n{e}"
-        )
-        return None, None
-
-    if not (
-        isinstance(input_embeddings, torch.nn.Module)
-        and isinstance(output_embeddings, torch.nn.Module)
-    ):
-        logger.warning(
-            f"expected modules from {model.__class__} get_input_embeddings and"
-            f" get_output_embeddings but got {type(input_embeddings)}"
-            f" and {type(output_embeddings)}."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-        )
-        return None, None
-    return input_embeddings, output_embeddings
-
-
-def untie_if_target_shared_embedding(
-    model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module]
-):
-    """
-    Helper method that checks for shared input/output embedding and unties them
-    if either shows up in the matched_module_generator
-
-    :param model: model to untie if embeddings are shared and targeted by
-        matched_module_generator
-    :param matched_module_generator: Generator of all modules (not names) which
-        will be modified by quantization or transformation
-    """
-    input_embeddings, output_embeddings = _get_embeddings_or_warn(model)
-
-    if None in (input_embeddings, output_embeddings):  # if couldn't find embeddings
-        return
-
-    if (
-        input_embeddings.weight is not output_embeddings.weight
-    ):  # if not shared, can ignore
-        return
-
-    # if shared, check if either is targeted
-    for module in matched_module_generator:
-        if module in (input_embeddings, output_embeddings):
-            untie_word_embeddings(model)
-            return
-
-
 def get_model_compressor(
     model: torch.nn.Module,
     sparsity_config: SparsityCompressionConfig | None = None,