Skip to content

Commit 056daee

Browse files
committed
tree_flatten -> tree_leaves
Summary: Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent 1f10f7d commit 056daee

File tree

2 files changed

+5
-5
lines changed
  • src/llmcompressor/modifiers

2 files changed

+5
-5
lines changed

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from loguru import logger
77
from pydantic import ConfigDict, Field
88
from torch.nn import Module
9-
from torch.utils._pytree import tree_flatten
9+
from torch.utils._pytree import tree_leaves
1010

1111
from llmcompressor.core import Event, EventType, State
1212
from llmcompressor.modifiers import Modifier
@@ -202,7 +202,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
202202
module_to_name = get_module_to_name_dict(model)
203203
for mapping in self.mappings:
204204
for *nested_balance_layers, smooth_layers in match_modules_set(
205-
model, tree_flatten(mapping)[0], self.ignore
205+
model, tree_leaves(mapping), self.ignore
206206
):
207207
assert len(smooth_layers) == 1, (
208208
"SmoothQuant mappings must match a single smooth layer for each "
@@ -211,7 +211,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
211211
)
212212
smooth_layer = smooth_layers[0]
213213
smooth_name = module_to_name.get(smooth_layers[0])
214-
balance_layers = tree_flatten(nested_balance_layers)[0]
214+
balance_layers = tree_leaves(nested_balance_layers)
215215
resolved_mappings.append(
216216
SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
217217
)

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from compressed_tensors.utils import TorchDtype, get_head_dim
1313
from pydantic import Field, ValidationInfo, field_validator
14-
from torch.utils._pytree import tree_flatten
14+
from torch.utils._pytree import tree_leaves
1515
from transformers import PreTrainedModel
1616

1717
from llmcompressor.core import Event, EventType, State
@@ -208,7 +208,7 @@ def _fuse_norms(self, model: PreTrainedModel):
208208
):
209209
# match_modules_set returns a list of lists
210210
assert len(norm) == 1
211-
fuse_norm_linears(norm[0], tree_flatten(linears)[0])
211+
fuse_norm_linears(norm[0], tree_leaves(linears))
212212

213213
def _create_r1_scheme(self) -> TransformScheme:
214214
return TransformScheme(

0 commit comments

Comments
 (0)