|
21 | 21 | from compressed_tensors.utils import ( |
22 | 22 | InternalModule, |
23 | 23 | is_match, |
| 24 | + is_narrow_match, |
24 | 25 | match_modules_set, |
25 | 26 | match_named_modules, |
26 | 27 | match_named_parameters, |
@@ -500,6 +501,86 @@ class InternalLinear(InternalModule, nn.Linear): |
500 | 501 | assert len(matches) == 0 |
501 | 502 |
|
502 | 503 |
|
| 504 | +class TestIsNarrowMatch: |
| 505 | + def test_narrow_match_true_child_only(self): |
| 506 | + """ |
| 507 | + Target matches the child module name but NOT its parent name. |
| 508 | + Should return True. |
| 509 | + """ |
| 510 | + model = DummyModel() |
| 511 | + name = "transformer.layers.0.self_attn.q_proj" |
| 512 | + # Matches "...q_proj" but not "...self_attn" |
| 513 | + target = r"re:.*q_proj$" |
| 514 | + |
| 515 | + assert is_narrow_match(model, target, name) |
| 516 | + |
| 517 | + def test_narrow_match_false_when_parent_also_matches(self): |
| 518 | + """ |
| 519 | + Target matches both the child and its parent name. |
| 520 | + Should return False because it's not a 'narrow' match. |
| 521 | + """ |
| 522 | + model = DummyModel() |
| 523 | + name = "transformer.layers.0.self_attn.q_proj" |
| 524 | + # Broad target that also matches the parent "transformer.layers.0.self_attn" |
| 525 | + target = r"re:transformer\.layers\.0\..*" |
| 526 | + |
| 527 | + assert not is_narrow_match(model, target, name) |
| 528 | + |
| 529 | + def test_narrow_match_false_when_neither_matches(self): |
| 530 | + """ |
| 531 | + Target matches neither the child nor the parent. |
| 532 | + Should return False. |
| 533 | + """ |
| 534 | + model = DummyModel() |
| 535 | + name = "transformer.layers.0.self_attn.q_proj" |
| 536 | + target = r"re:this_does_not_exist$" |
| 537 | + |
| 538 | + assert not is_narrow_match(model, target, name) |
| 539 | + |
| 540 | + def test_narrow_match_iterable_targets_any_true(self): |
| 541 | + """ |
| 542 | + With multiple targets: if any target narrowly matches the child, |
| 543 | + the function should return True. |
| 544 | + """ |
| 545 | + model = DummyModel() |
| 546 | + name = "transformer.layers.0.self_attn.q_proj" |
| 547 | + # First target is broad (matches both child & parent -> narrow False), |
| 548 | + # second target is narrow (matches child only -> narrow True). |
| 549 | + targets = [ |
| 550 | + r"re:transformer\.layers\.0\..*", |
| 551 | + r"re:.*q_proj$", |
| 552 | + ] |
| 553 | + |
| 554 | + assert is_narrow_match(model, targets, name) |
| 555 | + |
| 556 | + def test_narrow_match_with_explicit_module_argument(self): |
| 557 | + """ |
| 558 | + Passing the module explicitly should behave the same as when it's |
| 559 | + retrieved from the model by name. |
| 560 | + """ |
| 561 | + model = DummyModel() |
| 562 | + name = "transformer.layers.0.self_attn.q_proj" |
| 563 | + module = model.get_submodule(name) |
| 564 | + target = r"re:.*q_proj$" |
| 565 | + |
| 566 | + # Both ways should be True |
| 567 | + assert is_narrow_match(model, target, name) |
| 568 | + assert is_narrow_match(model, target, name, module=module) |
| 569 | + |
| 570 | + def test_narrow_match_top_level_behavior_documented(self): |
| 571 | + """ |
| 572 | + (Behavior check) For a top-level module name without a dot, the current |
| 573 | + implementation derives parent_name == name, so parent==child. |
| 574 | + Then 'narrow' cannot be True because parent match mirrors child match. |
| 575 | + This test documents current behavior to guard against regressions. |
| 576 | + """ |
| 577 | + model = DummyModel() |
| 578 | + name = "layer1" # top-level module in DummyModel |
| 579 | + target = r"re:^layer1$" |
| 580 | + |
| 581 | + assert not is_narrow_match(model, target, name) |
| 582 | + |
| 583 | + |
503 | 584 | class TestIntegration: |
504 | 585 | """Integration tests combining multiple functions""" |
505 | 586 |
|
|
0 commit comments