50 | 50 | get_offloaded_device, |
51 | 51 | get_safetensors_folder, |
52 | 52 | has_offloaded_params, |
53 | | - merge_names, |
54 | 53 | register_offload_parameter, |
55 | 54 | update_parameter_data, |
56 | 55 | ) |
@@ -321,112 +320,6 @@ def __init__( |
321 | 320 | format, config=quantization_config |
322 | 321 | ) |
323 | 322 |
324 | | - # ----- used by hf quantizer ----- # |
325 | | - |
326 | | - def get_missing_module_keys(self, model: Module) -> List[str]: |
327 | | - """ |
328 | | - Identifies the expected missing weight keys in the compressed state_dict. |
329 | | -
330 | | - When a model undergoes sparsity or quantization compression, certain
331 | | - weight tensors are intentionally absent from the compressed checkpoint.
332 | | - This function determines which weight keys are expected to be missing
333 | | - based on the applied compression formats.
334 | | -
335 | | - :param model: The PyTorch model to check for missing keys. |
336 | | - :return: A list of missing keys expected in the compressed state_dict. |
337 | | - """ |
338 | | - missing_keys = set() |
339 | | - |
340 | | - # Determine missing keys due to sparsity compression |
341 | | - if ( |
342 | | - self.sparsity_compressor |
343 | | - and self.sparsity_config.format != CompressionFormat.dense.value |
344 | | - ): |
345 | | - sparse_targets = match_named_modules( |
346 | | - model=model, |
347 | | - targets=self.sparsity_config.targets, |
348 | | - ignore=self.sparsity_config.ignore, |
349 | | - ) |
350 | | - |
351 | | - missing_keys.update( |
352 | | - merge_names(target_name, "weight") |
353 | | - for target_name, _module in sparse_targets |
354 | | - ) |
355 | | - |
356 | | - # Determine missing keys due to pack quantization |
357 | | - if ( |
358 | | - self.quantization_compressor |
359 | | - and self.quantization_config.format |
360 | | - == CompressionFormat.pack_quantized.value |
361 | | - ): |
362 | | - for scheme in self.quantization_config.config_groups.values(): |
363 | | - quant_targets = match_named_modules( |
364 | | - model=model, |
365 | | - targets=scheme.targets, |
366 | | - ignore=self.quantization_config.ignore, |
367 | | - ) |
368 | | - missing_keys.update( |
369 | | - merge_names(target_name, "weight") |
370 | | - for target_name, _module in quant_targets |
371 | | - ) |
372 | | - |
373 | | - return list(missing_keys) |
374 | | - |
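For context, a minimal sketch of how a checkpoint loader could have consumed the removed get_missing_module_keys helper to suppress spurious missing-key warnings; the helper name filter_expected_missing and its variable names are illustrative assumptions, not part of this diff:

def filter_expected_missing(model, reported_missing, compressor):
    # Keys absent only because compression removed them from the checkpoint
    # are expected, and should not be surfaced as load failures.
    expected = set(compressor.get_missing_module_keys(model))
    return [key for key in reported_missing if key not in expected]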
375 | | - def get_unexpected_file_keys(self, model: Module) -> List[str]: |
376 | | - """ |
377 | | - Identifies extra keys introduced by the compression process in the |
378 | | - compressed state_dict that are not expected by the model graph. |
379 | | -
380 | | - During sparsity or quantization compression, additional metadata or
381 | | - auxiliary parameters may be stored in the checkpoint that do not
382 | | - correspond to any parameter in the original model. These keys are
383 | | - typically introduced to support reconstruction of the compressed weights.
384 | | -
385 | | - For example, Sparse24Bitmask compression may introduce keys such as |
386 | | - 'compressed', 'bitmask', and 'shape' in the checkpoint, which are |
387 | | - not part of the original model parameters. |
388 | | -
389 | | - :param model: The PyTorch model to check for unexpected keys. |
390 | | - :return: A list of extra keys introduced by the compression process |
391 | | - that are not expected by the model. |
392 | | - """ |
393 | | - |
394 | | - unexpected_keys = set() |
395 | | - |
396 | | - # Identify unexpected keys from sparsity compression |
397 | | - if ( |
398 | | - self.sparsity_compressor |
399 | | - and self.sparsity_config.format != CompressionFormat.dense.value |
400 | | - ): |
401 | | - sparse_targets = match_named_modules( |
402 | | - model=model, |
403 | | - targets=self.sparsity_config.targets, |
404 | | - ignore=self.sparsity_config.ignore, |
405 | | - ) |
406 | | - unexpected_keys.update( |
407 | | - merge_names(target_name, param) |
408 | | - for target_name, _module in sparse_targets |
409 | | - for param in self.sparsity_compressor.compression_param_names |
410 | | - ) |
411 | | - |
412 | | - # Identify unexpected keys from quantization compression |
413 | | - if self.quantization_compressor: |
414 | | - for scheme in self.quantization_config.config_groups.values(): |
415 | | - quant_targets = match_named_modules( |
416 | | - model=model, |
417 | | - targets=scheme.targets, |
418 | | - ignore=self.quantization_config.ignore, |
419 | | - ) |
420 | | - for quant_compressor in self.quantization_compressor.values(): |
421 | | - unexpected_keys.update( |
422 | | - merge_names(target_name, param) |
423 | | - for target_name, _module in quant_targets |
424 | | - for param in quant_compressor.compression_param_names |
425 | | - if param != "weight" |
426 | | - ) |
427 | | - |
428 | | - return list(unexpected_keys) |
429 | | - |
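As a quick illustration of the key expansion above: merge_names (imported from compressed_tensors.utils, and also dropped from the imports in this commit) joins a module path and a parameter name with a dot, so a Sparse24Bitmask-compressed module contributes unexpected keys like the following. The module path here is a made-up example:

from compressed_tensors.utils import merge_names

params = ("compressed", "bitmask", "shape")  # Sparse24Bitmask compression params
keys = [merge_names("model.layers.0.mlp.down_proj", p) for p in params]
# keys == ["model.layers.0.mlp.down_proj.compressed",
#          "model.layers.0.mlp.down_proj.bitmask",
#          "model.layers.0.mlp.down_proj.shape"]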
430 | 323 | # ----- model memory compression/decompression pathways ----- # |
431 | 324 |
432 | 325 | def compress_model(self, model: Module):