Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 9778882

Browse files
authored
add pathway for no applied recipe during transformers eval (#530) (#531)
1 parent 30ca1bd commit 9778882

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/sparseml/transformers/utils/trainer.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -109,9 +109,14 @@ def __init__(
109109
recipe_args = {}
110110

111111
# initialize manager and override num epochs if available
112-
self.manager = ScheduledModifierManager.from_yaml(recipe, **recipe_args)
112+
self.manager = (
113+
ScheduledModifierManager.from_yaml(recipe, **recipe_args)
114+
if recipe
115+
else None
116+
)
113117
if (
114-
self.manager.max_epochs
118+
self.manager
119+
and self.manager.max_epochs
115120
and "args" in kwargs
116121
and (hasattr(kwargs["args"], "num_train_epochs"))
117122
):
@@ -178,7 +183,7 @@ def create_optimizer(self):
178183
Create optimizer customized using SparseML
179184
"""
180185
super().create_optimizer()
181-
if not self.recipe:
186+
if not self.recipe or not self.manager:
182187
return
183188
total_batch_size = (
184189
self.args.per_device_train_batch_size
@@ -234,7 +239,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
234239
"""
235240
Computing loss using teacher/student distillation
236241
"""
237-
if not self.recipe or self.teacher is None:
242+
if not self.recipe or self.manager is None or self.teacher is None:
238243
return super().compute_loss(model, inputs, return_outputs=return_outputs)
239244

240245
student_outputs = model(**inputs)
@@ -273,6 +278,9 @@ def _save_arch_modifiers(self, output_dir: Optional[str] = None):
273278
Save modifiers that change the model's architecture, which is to be applied
274279
later on whenever the model is loaded
275280
"""
281+
if not self.manager:
282+
return
283+
276284
output_dir = output_dir if output_dir is not None else self.args.output_dir
277285
output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
278286
saved_mods = [

0 commit comments

Comments
 (0)