 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_leaves

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -202,7 +202,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         module_to_name = get_module_to_name_dict(model)
         for mapping in self.mappings:
             for *nested_balance_layers, smooth_layers in match_modules_set(
-                model, tree_flatten(mapping)[0], self.ignore
+                model, tree_leaves(mapping), self.ignore
             ):
                 assert len(smooth_layers) == 1, (
                     "SmoothQuant mappings must match a single smooth layer for each "
@@ -211,7 +211,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
                 )
                 smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layers[0])
-                balance_layers = tree_flatten(nested_balance_layers)[0]
+                balance_layers = tree_leaves(nested_balance_layers)
                 resolved_mappings.append(
                     SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
                 )
0 commit comments