This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 01e1fb1

rahul-tuli and markurtz authored
[BugFix]: Exporting when --do_train and --do_eval are not set (#843) (#853)
* [BugFixes]: Support exporting when `--do_train` and `--do_eval` are not set
* Use `get_eval_dataloader()` instead of `get_val_dataloader()` in `question_answering.py`
* Always create an eval dataset when `num_export_samples > 0`
* Address suggestions from initial code review

TODO:
* Propagate the same fixes to all transformer tasks
* Propagate changes to MLM
* Propagate changes to text classification
* Propagate changes to token classification

Co-authored-by: Mark Kurtz <mark.kurtz@neuralmagic.com>
1 parent c8aa89a commit 01e1fb1
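
The fix boils down to one gating flag shared by the task scripts: an eval dataset is built whenever it is needed for sample export, not only when `--do_eval` is passed. The sketch below is a minimal, self-contained illustration of that logic with simplified stand-in dataclasses (the real scripts use HuggingFace `TrainingArguments` and sparseml's data arguments), not the sparseml source:

from dataclasses import dataclass


@dataclass
class TrainingArgs:
    do_train: bool = False
    do_eval: bool = False


@dataclass
class DataArgs:
    num_export_samples: int = 0


def needs_eval_dataset(training_args: TrainingArgs, data_args: DataArgs) -> bool:
    # Build the eval split for --do_eval OR whenever export samples are
    # requested, so export works without --do_train/--do_eval being set.
    return training_args.do_eval or data_args.num_export_samples > 0


# Export-only run: neither flag set, but samples requested -> eval data is built.
assert needs_eval_dataset(TrainingArgs(), DataArgs(num_export_samples=20))
# Plain training run with no export: no eval data is needed.
assert not needs_eval_dataset(TrainingArgs(do_train=True), DataArgs())

Note that in the diffs below `compute_metrics` (and, for question answering, `eval_examples`) stays gated on `training_args.do_eval`, so metric computation remains opt-in even when the eval dataset exists only for export.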

File tree

5 files changed, +14 -10 lines changed


src/sparseml/transformers/masked_language_modeling.py

Lines changed: 3 additions & 2 deletions
@@ -635,7 +635,8 @@ def group_texts(examples):
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
 
     compute_metrics = None
-    if training_args.do_eval:
+    make_eval_dataset = training_args.do_eval or data_args.num_export_samples > 0
+    if make_eval_dataset:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = tokenized_datasets["validation"]

@@ -687,7 +688,7 @@ def compute_metrics(eval_preds):
         args=training_args,
         data_args=data_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_dataset=eval_dataset if make_eval_dataset else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.do_eval else None,

src/sparseml/transformers/question_answering.py

Lines changed: 4 additions & 3 deletions
@@ -475,9 +475,10 @@ def main():
 
     # Preprocessing the datasets.
     # Preprocessing is slighlty different for training and evaluation.
+    make_eval_dataset = training_args.do_eval or data_args.num_export_samples > 0
     if training_args.do_train:
         column_names = raw_datasets["train"].column_names
-    elif training_args.do_eval:
+    elif make_eval_dataset:
         column_names = raw_datasets["validation"].column_names
     else:
         column_names = raw_datasets["test"].column_names

@@ -666,7 +667,7 @@ def prepare_validation_features(examples):
 
        return tokenized_examples
 
-    if training_args.do_eval:
+    if make_eval_dataset:
         if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_examples = raw_datasets["validation"]

@@ -777,7 +778,7 @@ def compute_metrics(p: EvalPrediction):
         args=training_args,
         data_args=data_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_dataset=eval_dataset if make_eval_dataset else None,
         eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ def save_sample_inputs_outputs(
         device = self.model.device
 
         try:
-            dataloader = self.get_val_dataloader()
+            dataloader = self.get_eval_dataloader()
         except Exception:
             dataloader = self.get_train_dataloader()
 
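
For context, `get_eval_dataloader()` is the actual `transformers.Trainer` method, and the surrounding `try`/`except` already falls back to the train dataloader when no eval dataloader can be built. A minimal sketch of that fallback pattern, using a hypothetical helper name rather than the sparseml method:

from torch.utils.data import DataLoader
from transformers import Trainer


def pick_export_dataloader(trainer: Trainer) -> DataLoader:
    # Prefer the eval dataloader for collecting sample inputs/outputs;
    # fall back to the train dataloader if no eval dataset is available
    # (Trainer.get_eval_dataloader raises when eval_dataset is None).
    try:
        return trainer.get_eval_dataloader()
    except Exception:
        return trainer.get_train_dataloader()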

src/sparseml/transformers/text_classification.py

Lines changed: 3 additions & 2 deletions
@@ -649,7 +649,8 @@ def preprocess_function(examples):
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
 
-    if training_args.do_eval:
+    make_eval_dataset = training_args.do_eval or data_args.num_export_samples > 0
+    if make_eval_dataset:
         if (
             "validation" not in raw_datasets
             and "validation_matched" not in raw_datasets

@@ -725,7 +726,7 @@ def compute_metrics(p: EvalPrediction):
         args=training_args,
         data_args=data_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_dataset=eval_dataset if make_eval_dataset else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,

src/sparseml/transformers/token_classification.py

Lines changed: 3 additions & 2 deletions
@@ -560,7 +560,8 @@ def tokenize_and_align_labels(examples):
                desc="Running tokenizer on train dataset",
            )
 
-    if training_args.do_eval:
+    make_eval_dataset = training_args.do_eval or data_args.num_export_samples > 0
+    if make_eval_dataset:
         if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = raw_datasets["validation"]

@@ -648,7 +649,7 @@ def compute_metrics(p):
         args=training_args,
         data_args=data_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_dataset=eval_dataset if make_eval_dataset else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
