Skip to content

Commit 728b8c0

Browse files
committed
format
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent dea5eab commit 728b8c0

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
320320
repeat for model.layer.1 and so on
321321
"""
322322
resolved_mappings: list[ResolvedMapping] = []
323-
323+
324324
module_to_name = {}
325325
for name, module in model.named_modules():
326326
if module in module_to_name:
@@ -331,14 +331,11 @@ def _set_resolved_mappings(self, model: Module) -> None:
331331
)
332332
module_to_name[module] = name
333333

334-
335-
336334
for mapping in self.mappings:
337-
338335
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
339336

340-
for smooth_layer, *balance_layers in (
341-
match_modules_set(model, target_patterns, self.ignore)
337+
for smooth_layer, *balance_layers in match_modules_set(
338+
model, target_patterns, self.ignore
342339
):
343340
smooth_name = module_to_name.get(smooth_layer)
344341
balance_names = [
@@ -353,10 +350,11 @@ def _set_resolved_mappings(self, model: Module) -> None:
353350
# skip mapping if any of the balance layers are incompatible
354351
if not all_compatible or len(balance_layers) == 0:
355352
logger.info(
356-
f"skipping AWQ for {smooth_name} for mapping {mapping}" + (
357-
" because found incompatible balance layers"
358-
if not all_compatible else
359-
f" because no balance layers were found"
353+
f"skipping AWQ for {smooth_name} for mapping {mapping}"
354+
+ (
355+
" because found incompatible balance layers"
356+
if not all_compatible
357+
else " because no balance layers were found"
360358
)
361359
)
362360

@@ -812,7 +810,7 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod
812810
Implementation is a small alteration of os.path.commonprefix
813811
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
814812
"""
815-
# adding "." before and after allows for handling a lot of corner
813+
# adding "." before and after allows for handling a lot of corner
816814
# cases which were previously mishandled ([case]->prefix->result)
817815
# case 0: single module: [.abc.] -> .abc. -> abc
818816
# case 1: substring modules: [.abc., .ab.] -> .ab -> ""
@@ -829,9 +827,9 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod
829827

830828
# 2) throw away module name fragment and leading dot
831829
# ".keep.thro" -> "keep"
832-
parent_name = parent_name[1:parent_name.rfind(".")]
830+
parent_name = parent_name[1 : parent_name.rfind(".")]
833831

834-
# 3) return first parent that is not a module list
832+
# 3) return first common module that is not a module list
835833
while True:
836834
if parent_name == "":
837835
return "", module

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
44
from pydantic import ValidationError
55
from torch.nn import Linear
6+
67
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
78
from llmcompressor.modifiers.awq.base import get_lowest_common_module
89
from llmcompressor.modifiers.factory import ModifierFactory
@@ -250,9 +251,9 @@ def test_get_lowest_common_module():
250251

251252
m = torch.nn.ModuleDict(
252253
{
253-
"abc": Linear(3,3),
254-
"ab": torch.nn.ModuleDict({"a": Linear(3,3)}),
255-
"z": Linear(3,3)
254+
"abc": Linear(3, 3),
255+
"ab": torch.nn.ModuleDict({"a": Linear(3, 3)}),
256+
"z": Linear(3, 3),
256257
}
257258
)
258259
parent_name, parent = get_lowest_common_module(["abc", "ab"], m)
@@ -261,4 +262,3 @@ def test_get_lowest_common_module():
261262
assert parent_name == "ab"
262263
parent_name, parent = get_lowest_common_module(["z"], m)
263264
assert parent_name == "z"
264-

0 commit comments

Comments
 (0)