22import torch
33from compressed_tensors .quantization import QuantizationArgs , QuantizationScheme
44from pydantic import ValidationError
5-
5+ from torch . nn import Linear
66from llmcompressor .modifiers .awq import AWQMapping , AWQModifier
7- from llmcompressor .modifiers .awq .base import get_lowest_common_parent
7+ from llmcompressor .modifiers .awq .base import get_lowest_common_module
88from llmcompressor .modifiers .factory import ModifierFactory
99
1010
@@ -40,16 +40,16 @@ def test_set_resolved_mappings():
4040 )
4141 self_attn = torch .nn .ModuleDict (
4242 {
43- "q_proj" : torch . nn . Linear (4 , 4 ),
44- "k_proj" : torch . nn . Linear (4 , 4 ),
45- "v_proj" : torch . nn . Linear (4 , 4 ),
46- "o_proj" : torch . nn . Linear (4 , 4 ),
43+ "q_proj" : Linear (4 , 4 ),
44+ "k_proj" : Linear (4 , 4 ),
45+ "v_proj" : Linear (4 , 4 ),
46+ "o_proj" : Linear (4 , 4 ),
4747 }
4848 )
4949 mlp = torch .nn .ModuleDict (
5050 {
51- "up_proj" : torch . nn . Linear (4 , 10 ),
52- "down_proj" : torch . nn . Linear (10 , 4 ),
51+ "up_proj" : Linear (4 , 10 ),
52+ "down_proj" : Linear (10 , 4 ),
5353 }
5454 )
5555 model = torch .nn .ModuleDict (
@@ -100,11 +100,11 @@ def test_set_resolved_mappings():
100100 {
101101 "self_attn" : torch .nn .ModuleDict (
102102 {
103- "q_proj" : torch . nn . Linear (4 , 2 ),
104- "k_proj" : torch . nn . Linear (4 , 2 ),
105- "v_proj" : torch . nn . Linear (4 , 2 ),
106- "z_proj" : torch . nn . Linear (2 , 4 ),
107- "o_proj" : torch . nn . Linear (4 , 4 ),
103+ "q_proj" : Linear (4 , 2 ),
104+ "k_proj" : Linear (4 , 2 ),
105+ "v_proj" : Linear (4 , 2 ),
106+ "z_proj" : Linear (2 , 4 ),
107+ "o_proj" : Linear (4 , 4 ),
108108 }
109109 )
110110 }
@@ -192,15 +192,15 @@ def test_validate():
192192
193193
194194@pytest .mark .unit
195- def test_get_lowest_common_parent ():
195+ def test_get_lowest_common_module ():
196196 mlp = torch .nn .ModuleDict (
197197 {
198198 "experts" : torch .nn .ModuleList (
199199 [
200200 torch .nn .ModuleDict (
201201 {
202- "gate_proj" : torch . nn . Linear (4 , 2 ),
203- "down_proj" : torch . nn . Linear (4 , 2 ),
202+ "gate_proj" : Linear (4 , 2 ),
203+ "down_proj" : Linear (4 , 2 ),
204204 }
205205 )
206206 for _ in range (10 )
@@ -210,15 +210,15 @@ def test_get_lowest_common_parent():
210210 )
211211 self_attn = torch .nn .ModuleDict (
212212 {
213- "q_proj" : torch . nn . Linear (4 , 2 ),
214- "k_proj" : torch . nn . Linear (4 , 2 ),
215- "v_proj" : torch . nn . Linear (4 , 2 ),
216- "o_proj" : torch . nn . Linear (4 , 4 ),
213+ "q_proj" : Linear (4 , 2 ),
214+ "k_proj" : Linear (4 , 2 ),
215+ "v_proj" : Linear (4 , 2 ),
216+ "o_proj" : Linear (4 , 4 ),
217217 }
218218 )
219219 model = torch .nn .ModuleDict (
220220 {
221- "embed_tokens" : torch . nn . Linear (4 , 2 ),
221+ "embed_tokens" : Linear (4 , 2 ),
222222 "decoder" : torch .nn .ModuleDict (
223223 {
224224 "self_attn" : self_attn ,
@@ -228,22 +228,37 @@ def test_get_lowest_common_parent():
228228 }
229229 )
230230
231- parent_name , parent = get_lowest_common_parent (
231+ parent_name , parent = get_lowest_common_module (
232232 ["decoder.mlp.experts.1.gate_proj" , "decoder.mlp.experts.4.down_proj" ], model
233233 )
234234 assert parent_name == "decoder.mlp" and parent == mlp
235235
236- parent_name , parent = get_lowest_common_parent (
236+ parent_name , parent = get_lowest_common_module (
237237 ["decoder.self_attn.q_proj" , "decoder.self_attn.v_proj" ], model
238238 )
239239 assert parent_name == "decoder.self_attn" and parent == self_attn
240240
241- parent_name , parent = get_lowest_common_parent (
241+ parent_name , parent = get_lowest_common_module (
242242 ["decoder.mlp.experts.1.gate_proj" , "decoder.self_attn.v_proj" ], model
243243 )
244244 assert parent_name == "decoder" and parent == model ["decoder" ]
245245
246- parent_name , parent = get_lowest_common_parent (
246+ parent_name , parent = get_lowest_common_module (
247247 ["embed_tokens" , "decoder.self_attn.v_proj" ], model
248248 )
249249 assert parent_name == "" and parent == model
250+
251+ m = torch .nn .ModuleDict (
252+ {
253+ "abc" : Linear (3 ,3 ),
254+ "ab" : torch .nn .ModuleDict ({"a" : Linear (3 ,3 )}),
255+ "z" : Linear (3 ,3 )
256+ }
257+ )
258+ parent_name , parent = get_lowest_common_module (["abc" , "ab" ], m )
259+ assert parent_name == ""
260+ parent_name , parent = get_lowest_common_module (["ab" , "ab.a" ], m )
261+ assert parent_name == "ab"
262+ parent_name , parent = get_lowest_common_module (["z" ], m )
263+ assert parent_name == "z"
264+
0 commit comments