Skip to content

Commit 5229f86

Browse files
committed
is_narrow_match tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f46dadf commit 5229f86

File tree

2 files changed

+103
-4
lines changed

2 files changed

+103
-4
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,31 @@ def is_match(
261261
)
262262

263263

264-
def is_narrow_match(model: torch.nn.Module, targets: Iterable[str], name: str) -> bool:
265-
module = model.get_submodule(name)
264+
def is_narrow_match(
265+
model: torch.nn.Module,
266+
targets: Union[str, Iterable[str]],
267+
name: str,
268+
module: Optional[torch.nn.Module] = None,
269+
) -> bool:
270+
"""
271+
Checks if any of the targets narrowly match the module. A target narrowly matches
272+
a module if the target matches the module, but does not match the module's parent
273+
274+
:param model: model containing both module and its parent
275+
:param targets: target strings, potentially containing "re:" prefixes
276+
:param name: name of module to match
277+
:param module: module to match. If none is provided, then get module from model
278+
:return: True if any of the targets narrow match the module
279+
"""
280+
targets = [targets] if isinstance(targets, str) else targets
281+
module = module if module is not None else model.get_submodule(name)
282+
266283
parent_name = name.rsplit(".", 1)[0]
267284
parent = model.get_submodule(parent_name)
268285

269-
return is_match(name, module, targets) and not is_match(
270-
parent_name, parent, targets
286+
return any(
287+
is_match(name, module, target) and not is_match(parent_name, parent, target)
288+
for target in targets
271289
)
272290

273291

tests/test_utils/test_match.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from compressed_tensors.utils import (
2222
InternalModule,
2323
is_match,
24+
is_narrow_match,
2425
match_modules_set,
2526
match_named_modules,
2627
match_named_parameters,
@@ -500,6 +501,86 @@ class InternalLinear(InternalModule, nn.Linear):
500501
assert len(matches) == 0
501502

502503

504+
class TestIsNarrowMatch:
505+
def test_narrow_match_true_child_only(self):
506+
"""
507+
Target matches the child module name but NOT its parent name.
508+
Should return True.
509+
"""
510+
model = DummyModel()
511+
name = "transformer.layers.0.self_attn.q_proj"
512+
# Matches "...q_proj" but not "...self_attn"
513+
target = r"re:.*q_proj$"
514+
515+
assert is_narrow_match(model, target, name)
516+
517+
def test_narrow_match_false_when_parent_also_matches(self):
518+
"""
519+
Target matches both the child and its parent name.
520+
Should return False because it's not a 'narrow' match.
521+
"""
522+
model = DummyModel()
523+
name = "transformer.layers.0.self_attn.q_proj"
524+
# Broad target that also matches the parent "transformer.layers.0.self_attn"
525+
target = r"re:transformer\.layers\.0\..*"
526+
527+
assert not is_narrow_match(model, target, name)
528+
529+
def test_narrow_match_false_when_neither_matches(self):
530+
"""
531+
Target matches neither the child nor the parent.
532+
Should return False.
533+
"""
534+
model = DummyModel()
535+
name = "transformer.layers.0.self_attn.q_proj"
536+
target = r"re:this_does_not_exist$"
537+
538+
assert not is_narrow_match(model, target, name)
539+
540+
def test_narrow_match_iterable_targets_any_true(self):
541+
"""
542+
With multiple targets: if any target narrowly matches the child,
543+
the function should return True.
544+
"""
545+
model = DummyModel()
546+
name = "transformer.layers.0.self_attn.q_proj"
547+
# First target is broad (matches both child & parent -> narrow False),
548+
# second target is narrow (matches child only -> narrow True).
549+
targets = [
550+
r"re:transformer\.layers\.0\..*",
551+
r"re:.*q_proj$",
552+
]
553+
554+
assert is_narrow_match(model, targets, name)
555+
556+
def test_narrow_match_with_explicit_module_argument(self):
557+
"""
558+
Passing the module explicitly should behave the same as when it's
559+
retrieved from the model by name.
560+
"""
561+
model = DummyModel()
562+
name = "transformer.layers.0.self_attn.q_proj"
563+
module = model.get_submodule(name)
564+
target = r"re:.*q_proj$"
565+
566+
# Both ways should be True
567+
assert is_narrow_match(model, target, name)
568+
assert is_narrow_match(model, target, name, module=module)
569+
570+
def test_narrow_match_top_level_behavior_documented(self):
571+
"""
572+
(Behavior check) For a top-level module name without a dot, the current
573+
implementation derives parent_name == name, so parent==child.
574+
Then 'narrow' cannot be True because parent match mirrors child match.
575+
This test documents current behavior to guard against regressions.
576+
"""
577+
model = DummyModel()
578+
name = "layer1" # top-level module in DummyModel
579+
target = r"re:^layer1$"
580+
581+
assert not is_narrow_match(model, target, name)
582+
583+
503584
class TestIntegration:
504585
"""Integration tests combining multiple functions"""
505586

0 commit comments

Comments
 (0)