|
7 | 7 | from compressed_tensors.utils import ( |
8 | 8 | align_modules, |
9 | 9 | get_execution_device, |
| 10 | + match_modules_set, |
10 | 11 | match_named_modules, |
11 | 12 | update_offload_parameter, |
12 | 13 | ) |
@@ -312,68 +313,78 @@ def _set_resolved_mappings(self, model: Module) -> None: |
312 | 313 | into ResolvedMapping objects, resolving regular expressions. |
313 | 314 | Result is stored in _resolved_mappings. |
314 | 315 |
|
315 | | - For each activation in the mapping list, we find the corresponding weight to |
316 | | - balance by searching for the longest substring. For instance, if our balance |
317 | | - weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we |
318 | | - would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and |
319 | | - repeat for model.layer.1 and so on |
| 316 | + Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) |
| 317 | + that belong together in the model architecture. |
320 | 318 | """ |
| 319 | + # Build a module-to-name mapping for efficient lookups |
| 320 | + module_to_name = {module: name for name, module in model.named_modules()} |
| 321 | + |
321 | 322 | resolved_mappings: list[ResolvedMapping] = [] |
322 | 323 | for mapping_idx, mapping in enumerate(self.mappings): |
323 | 324 | num_skipped_mappings = 0 |
324 | 325 |
|
325 | | - for smooth_name, smooth_layer in ( |
| 326 | + # Use match_modules_set to find coherent sets of modules |
| 327 | + target_patterns = (mapping.smooth_layer, *mapping.balance_layers) |
| 328 | + |
| 329 | + for modules_set in ( |
326 | 330 | pbar := tqdm( |
327 | | - match_named_modules(model, [mapping.smooth_layer], self.ignore) |
| 331 | + match_modules_set(model, target_patterns, self.ignore) |
328 | 332 | ) |
329 | 333 | ): |
330 | 334 | pbar.set_description( |
331 | 335 | f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" |
332 | 336 | f" ({num_skipped_mappings} skipped)" |
333 | 337 | ) |
334 | 338 |
|
335 | | - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) |
336 | | - smooth_parent = get_layer_by_name(smooth_parent_name, model) |
| 339 | + # Unpack the matched set: first is smooth_layer, rest are balance_layers |
| 340 | + smooth_layer = modules_set[0] |
| 341 | + all_balance_layers = list(modules_set[1:]) |
337 | 342 |
|
338 | | - balance_layers, balance_names = [], [] |
339 | | - for balance_regex in mapping.balance_layers: |
340 | | - # find the submodules that match the activation layer |
341 | | - for balance_suffix, balance_layer in match_named_modules( |
342 | | - smooth_parent, [balance_regex], self.ignore |
343 | | - ): |
344 | | - balance_name = f"{smooth_parent_name}.{balance_suffix}" |
345 | | - |
346 | | - # exclude v_proj->o_proj mappings whose shapes are incompatible |
347 | | - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 |
348 | | - if ( |
349 | | - isinstance(smooth_layer, torch.nn.Linear) |
350 | | - and isinstance(balance_layer, torch.nn.Linear) |
351 | | - and balance_name.endswith(".o_proj") |
352 | | - and ( |
353 | | - ( |
354 | | - smooth_name.endswith(".v_proj") |
355 | | - and smooth_layer.out_features |
356 | | - != balance_layer.in_features |
357 | | - ) |
358 | | - or ( |
359 | | - smooth_name.endswith(".qkv_proj") |
360 | | - and smooth_layer.out_features |
361 | | - != 3 * balance_layer.in_features |
362 | | - ) |
| 343 | + # Get names using the pre-built mapping |
| 344 | + smooth_name = module_to_name.get(smooth_layer) |
| 345 | + if smooth_name is None: |
| 346 | + continue |
| 347 | + |
| 348 | + # Filter balance layers, skipping incompatible ones |
| 349 | + balance_layers = [] |
| 350 | + balance_names = [] |
| 351 | + |
| 352 | + for balance_layer in all_balance_layers: |
| 353 | + balance_name = module_to_name.get(balance_layer) |
| 354 | + if balance_name is None: |
| 355 | + continue |
| 356 | + |
| 357 | + # exclude v_proj->o_proj mappings whose shapes are incompatible |
| 358 | + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 |
| 359 | + if ( |
| 360 | + isinstance(smooth_layer, torch.nn.Linear) |
| 361 | + and isinstance(balance_layer, torch.nn.Linear) |
| 362 | + and balance_name.endswith(".o_proj") |
| 363 | + and ( |
| 364 | + ( |
| 365 | + smooth_name.endswith(".v_proj") |
| 366 | + and smooth_layer.out_features |
| 367 | + != balance_layer.in_features |
| 368 | + ) |
| 369 | + or ( |
| 370 | + smooth_name.endswith(".qkv_proj") |
| 371 | + and smooth_layer.out_features |
| 372 | + != 3 * balance_layer.in_features |
363 | 373 | ) |
364 | | - ): |
365 | | - num_skipped_mappings += 1 |
366 | | - continue |
| 374 | + ) |
| 375 | + ): |
| 376 | + num_skipped_mappings += 1 |
| 377 | + continue |
367 | 378 |
|
368 | | - balance_layers.append(balance_layer) |
369 | | - balance_names.append(balance_name) |
| 379 | + balance_layers.append(balance_layer) |
| 380 | + balance_names.append(balance_name) |
370 | 381 |
|
371 | 382 | if len(balance_layers) == 0: |
372 | 383 | continue |
373 | 384 |
|
374 | | - elif len(balance_layers) == 1: |
| 385 | + if len(balance_layers) == 1: |
375 | 386 | # for single balance layer, parent is the balance layer |
376 | | - parent_name, parent = balance_name, balance_layer |
| 387 | + parent_name, parent = balance_names[0], balance_layers[0] |
377 | 388 | else: |
378 | 389 | # for multiple balance layers, find lowest common parent |
379 | 390 | parent_name, parent = get_lowest_common_parent(balance_names, model) |
|
0 commit comments