This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit ee2ee13

jeanniefinks, eldarkurtic, bfineran, and markurtz authored
Jfinks release edits (#207)
* Update annotate.py: minor edits, thx
* Update annotate.py: getting additional edits at top of file to match ones below
* Fix for integrations/timm checkpoint path (#198)
  This PR fixes issue #197
  Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
* Fix steps_per_epoch calculation (#201)
  Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
* YOLO webcam example - add assert for webcam load (#202)
  * YOLO webcam example - add assert for webcam load
  * update readme to note other options for annotate
  * formatting

Co-authored-by: Eldar Kurtic <eldar.ciki@gmail.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: Mark Kurtz <mark@neuralmagic.com>
1 parent 4387a23 commit ee2ee13
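
The webcam-load assertion from #202 lands in the YOLO example, whose diff is not shown in this excerpt. Purely as an illustration of the idea (not the repository's actual change), a guard of that kind with OpenCV might look like the sketch below; the source index and messages are hypothetical.

```python
# Illustrative sketch only -- not the diff from #202.
# Fail fast if the webcam cannot be opened or does not produce frames.
import cv2

source = 0  # hypothetical default webcam index
cap = cv2.VideoCapture(source)
assert cap.isOpened(), f"Unable to open webcam source {source}"

ok, frame = cap.read()
assert ok and frame is not None, "Webcam opened but returned no frame"
cap.release()
```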

File tree

4 files changed: +65 lines, -63 lines


integrations/timm/train.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -389,7 +389,7 @@ def main():
     elif args.initial_checkpoint.startswith("zoo:"):
         # Load weights from a SparseZoo model stub
         zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
-        args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"])
+        args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"])[0]
     ####################################################################################
     # End - SparseML optional load weights from SparseZoo
     ####################################################################################
```
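
The one-line change above matters because `download_framework_files` returns a list of local file paths, while timm's checkpoint handling expects a single path string. A minimal sketch of the corrected flow, assuming the pre-1.0 `Zoo` API used by this integration and an illustrative SparseZoo stub:

```python
# Sketch under assumptions: the stub below is illustrative and the pre-1.0
# sparsezoo `Zoo` API (as used in this integration) is installed.
from sparsezoo import Zoo

stub = "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate"
zoo_model = Zoo.load_model_from_stub(stub)

# Returns a list of downloaded framework files matching the extension filter.
framework_files = zoo_model.download_framework_files(extensions=[".pth"])

# Take the first entry so downstream code gets one checkpoint path, not a list
# (the bug fixed by #198).
initial_checkpoint = framework_files[0]
print(initial_checkpoint)
```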

integrations/transformers/run_distill_qa.py

Lines changed: 24 additions & 23 deletions (many of the paired removals/additions below differ only in trailing whitespace, so the old and new lines read the same)

```diff
@@ -19,8 +19,8 @@
 # limitations under the License.
 
 """
-Example script for integrating spaseml with the transformers library to perform model distillation.
-This script is addopted from hugging face's implementation for Question Answering on the SQUAD Dataset.
+Example script for integrating spaseml with the transformers library to perform model distillation.
+This script is addopted from hugging face's implementation for Question Answering on the SQUAD Dataset.
 Hugging Face's original implementation is regularly updated and can be found at https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_qa.py
 This script will:
 - Load transformer based models
@@ -54,12 +54,12 @@
     [--onnx_export_path] \
     [--layers_to_keep] \
 
-Train, prune, and evaluate a transformer base question answering model on squad.
+Train, prune, and evaluate a transformer base question answering model on squad.
 -h, --help show this help message and exit
 --teacher_model_name_or_path The name or path of model which will be used for distilation.
     Note, this model needs to be trained for QA task already.
 --student_model_name_or_path The name or path of the model wich will be trained using distilation.
---temperature Hyperparameter which controls model distilation
+--temperature Hyperparameter which controls model distilation
 --distill_hardness Hyperparameter which controls how much of the loss comes from teacher vs training labels
 --model_name_or_path The path to the transformers model you wish to train
     or the name of the pretrained language model you wish
@@ -72,21 +72,21 @@
     or not. Default is false.
 --do_eval Boolean denoting if the model should be evaluated
     or not. Default is false.
---per_device_train_batch_size Size of each training batch based on samples per GPU.
+--per_device_train_batch_size Size of each training batch based on samples per GPU.
     12 will fit in a 11gb GPU, 16 in a 16gb.
---per_device_eval_batch_size Size of each training batch based on samples per GPU.
+--per_device_eval_batch_size Size of each training batch based on samples per GPU.
     12 will fit in a 11gb GPU, 16 in a 16gb.
 --learning_rate Learning rate initial float value. ex: 3e-5.
---max_seq_length Int for the max sequence length to be parsed as a context
+--max_seq_length Int for the max sequence length to be parsed as a context
     window. ex: 384 tokens.
 --output_dir Path which model checkpoints and paths should be saved.
---overwrite_output_dir Boolean to define if the
+--overwrite_output_dir Boolean to define if the
 --cache_dir Directiory which cached transformer files(datasets, models
-    , tokenizers) are saved for fast loading.
+    , tokenizers) are saved for fast loading.
 --preprocessing_num_workers The amount of cpu workers which are used to process datasets
 --seed Int which determines what random seed is for training/shuffling
 --nm_prune_config Path to the neural magic prune configuration file. examples can
-    be found in prune_config_files but are customized for bert-base-uncased.
+    be found in prune_config_files but are customized for bert-base-uncased.
 --do_onnx_export Boolean denoting if the model should be exported to onnx
 --onnx_export_path Path where onnx model path will be exported. ex: onnx-export
 --layers_to_keep Number of layers to keep from original model. Layers are dropped before training
@@ -611,7 +611,7 @@ def prepare_validation_features(examples):
         ]
         return tokenized_examples
 
-    transformers.utils.logging.set_verbosity_info()
+    transformers.utils.logging.set_verbosity_info()
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
@@ -639,7 +639,7 @@ def prepare_validation_features(examples):
     )
 
     logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
 
@@ -690,10 +690,10 @@ def prepare_validation_features(examples):
 
     student_model_parameters = filter(lambda p: p.requires_grad, student_model.parameters())
     params = sum([np.prod(p.size()) for p in student_model_parameters])
-    logger.info("Student Model has %s parameters", params)
+    logger.info("Student Model has %s parameters", params)
     teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
     params = sum([np.prod(p.size()) for p in teacher_model_parameters])
-    logger.info("Teacher Model has %s parameters", params)
+    logger.info("Teacher Model has %s parameters", params)
     # Tokenizer check: this script requires a fast tokenizer.
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         raise ValueError(
@@ -710,7 +710,7 @@ def prepare_validation_features(examples):
     context_column_name = "context" if "context" in column_names else column_names[1]
     answer_column_name = "answers" if "answers" in column_names else column_names[2]
 
-    pad_on_right = tokenizer.padding_side == "right"
+    pad_on_right = tokenizer.padding_side == "right"
 
     data_collator = (
         default_data_collator
@@ -744,15 +744,16 @@ def prepare_validation_features(examples):
     )
     ####################################################################################
     # Start SparseML Integration
-    ####################################################################################
-    optim = load_optimizer(student_model, TrainingArguments)
-    steps_per_epoch = math.ceil(len(datasets["train"]) / (training_args.per_device_train_batch_size*training_args._n_gpu))
-    manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
-    training_args.num_train_epochs = float(manager.modifiers[0].end_epoch)
-    optim = ScheduledOptimizer(optim, student_model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
+    ####################################################################################
+    if training_args.do_train:
+        optim = load_optimizer(student_model, TrainingArguments)
+        steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size * training_args._n_gpu))
+        manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
+        training_args.num_train_epochs = float(manager.modifiers[0].end_epoch)
+        optim = ScheduledOptimizer(optim, student_model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
     ####################################################################################
     # End SparseML Integration
-    ####################################################################################
+    ####################################################################################
     # Initialize our Trainer
     trainer = DistillQuestionAnsweringTrainer(
         model=student_model,
@@ -764,7 +765,7 @@ def prepare_validation_features(examples):
         data_collator=data_collator,
         post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
-        optimizers=(optim, None),
+        optimizers=(optim, None) if training_args.do_train else (None, None),
         teacher=teacher_model,
         distill_hardness = model_args.distill_hardness,
         temperature = model_args.temperature,
```
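
The steps_per_epoch fix (#201), visible in the SparseML hunk above, computes the per-epoch step count from `train_dataset` (the tokenized features actually fed to the dataloader) rather than the raw `datasets["train"]` split, and only when `--do_train` is set. A worked sketch of the arithmetic with made-up numbers:

```python
# Worked sketch with made-up numbers; real values come from TrainingArguments
# and the tokenized train_dataset in the script.
import math

num_train_features = 88_524          # hypothetical count of tokenized training features
per_device_train_batch_size = 12     # fits an 11 GB GPU per the script's help text
n_gpu = 2                            # hypothetical multi-GPU run

effective_batch = per_device_train_batch_size * n_gpu
steps_per_epoch = math.ceil(num_train_features / effective_batch)
print(steps_per_epoch)  # 3689 optimizer steps per epoch for these numbers
```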

integrations/transformers/run_qa.py

Lines changed: 23 additions & 22 deletions (several paired removals/additions below are trailing-whitespace-only cleanups)

```diff
@@ -19,8 +19,8 @@
 # limitations under the License.
 
 """
-Example script for integrating spaseml with the transformers library.
-This script is addopted from hugging face's implementation for Question Answering on the SQUAD Dataset.
+Example script for integrating spaseml with the transformers library.
+This script is addopted from hugging face's implementation for Question Answering on the SQUAD Dataset.
 Hugging Face's original implementation is regularly updated and can be found at https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_qa.py
 This script will:
 - Load transformer based modesl
@@ -50,7 +50,7 @@
     [--do_onnx_export]
     [--onnx_export_path]
 
-Train, prune, and evaluate a transformer base question answering model on squad.
+Train, prune, and evaluate a transformer base question answering model on squad.
 -h, --help show this help message and exit
 --model_name_or_path MODEL The path to the transformers model you wish to train
     or the name of the pretrained language model you wish
@@ -63,21 +63,21 @@
     or not. Default is false.
 --do_eval Boolean denoting if the model should be evaluated
     or not. Default is false.
---per_device_train_batch_size Size of each training batch based on samples per GPU.
+--per_device_train_batch_size Size of each training batch based on samples per GPU.
     12 will fit in a 11gb GPU, 16 in a 16gb.
---per_device_eval_batch_size Size of each training batch based on samples per GPU.
+--per_device_eval_batch_size Size of each training batch based on samples per GPU.
     12 will fit in a 11gb GPU, 16 in a 16gb.
 --learning_rate Learning rate initial float value. ex: 3e-5.
---max_seq_length Int for the max sequence length to be parsed as a context
+--max_seq_length Int for the max sequence length to be parsed as a context
     window. ex: 384 tokens.
 --output_dir Path which model checkpoints and paths should be saved.
---overwrite_output_dir Boolean to define if the
+--overwrite_output_dir Boolean to define if the
 --cache_dir Directiory which cached transformer files(datasets, models
-    , tokenizers) are saved for fast loading.
+    , tokenizers) are saved for fast loading.
 --preprocessing_num_workers The amount of cpu workers which are used to process datasets
 --seed Int which determines what random seed is for training/shuffling
 --nm_prune_config Path to the neural magic prune configuration file. examples can
-    be found in prune_config_files but are customized for bert-base-uncased.
+    be found in prune_config_files but are customized for bert-base-uncased.
 --do_onnx_export Boolean denoting if the model should be exported to onnx
 --onnx_export_path Path where onnx model path will be exported. ex: onnx-export
 
@@ -101,7 +101,7 @@
     --seed 42 \
     --nm_prune_config prune_config_files/95sparsity1epoch.yaml \
     --do_onnx_export \
-    --onnx_export_path 95sparsity1epoch/
+    --onnx_export_path 95sparsity1epoch/
 """
 import collections
 import json
@@ -590,7 +590,7 @@ def prepare_validation_features(examples):
 
         return tokenized_examples
 
-    transformers.utils.logging.set_verbosity_info()
+    transformers.utils.logging.set_verbosity_info()
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
@@ -618,7 +618,7 @@ def prepare_validation_features(examples):
     )
 
     logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
 
@@ -663,7 +663,7 @@ def prepare_validation_features(examples):
 
     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
     params = sum([np.prod(p.size()) for p in model_parameters])
-    logger.info("Model has %s parameters", params)
+    logger.info("Model has %s parameters", params)
     # Tokenizer check: this script requires a fast tokenizer.
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         raise ValueError(
@@ -679,7 +679,7 @@ def prepare_validation_features(examples):
     question_column_name = "question" if "question" in column_names else column_names[0]
     context_column_name = "context" if "context" in column_names else column_names[1]
     answer_column_name = "answers" if "answers" in column_names else column_names[2]
-    pad_on_right = tokenizer.padding_side == "right"
+    pad_on_right = tokenizer.padding_side == "right"
 
     if training_args.do_train:
         train_dataset = datasets["train"].map(
@@ -714,12 +714,13 @@ def prepare_validation_features(examples):
 
     ####################################################################################
     # Start SparseML Integration
-    ####################################################################################
-    optim = load_optimizer(model, TrainingArguments)
-    steps_per_epoch = math.ceil(len(datasets["train"]) / (training_args.per_device_train_batch_size*training_args._n_gpu))
-    manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
-    training_args.num_train_epochs = float(manager.max_epochs)
-    optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
+    ####################################################################################
+    if training_args.do_train:
+        optim = load_optimizer(model, TrainingArguments)
+        steps_per_epoch = math.ceil(len(train_dataset) / (training_args.per_device_train_batch_size * training_args._n_gpu))
+        manager = ScheduledModifierManager.from_yaml(data_args.nm_prune_config)
+        training_args.num_train_epochs = float(manager.max_epochs)
+        optim = ScheduledOptimizer(optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None)
     ####################################################################################
     # End SparseML Integration
     ####################################################################################
```
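
Guarding the SparseML setup behind `training_args.do_train`, together with passing `(None, None)` to the Trainer in the next hunk, keeps evaluation-only runs from building an optimizer they never use. A condensed sketch of the same wrapping pattern, written as a hypothetical helper (`build_sparseml_optimizer` is not in the script, and AdamW stands in for the script's own `load_optimizer` helper):

```python
# Hypothetical helper mirroring the guarded setup above; the recipe path,
# AdamW choice, and function name are illustrative, not the script's own code.
import math

import torch
from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer


def build_sparseml_optimizer(model, train_dataset, training_args, recipe_path):
    """Return a recipe-driven optimizer, or None for evaluation-only runs."""
    if not training_args.do_train:
        return None

    # Plain PyTorch optimizer (the script builds its own via load_optimizer).
    optim = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)

    # Steps per epoch over the tokenized training features, all GPUs combined.
    steps_per_epoch = math.ceil(
        len(train_dataset)
        / (training_args.per_device_train_batch_size * training_args._n_gpu)
    )

    # The pruning recipe also decides how many epochs to run.
    manager = ScheduledModifierManager.from_yaml(recipe_path)
    training_args.num_train_epochs = float(manager.max_epochs)

    # Wrap the optimizer so SparseML modifiers are applied on every step.
    return ScheduledOptimizer(
        optim, model, manager, steps_per_epoch=steps_per_epoch, loggers=None
    )
```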
```diff
@@ -734,7 +735,7 @@ def prepare_validation_features(examples):
         data_collator=data_collator,
         post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
-        optimizers=(optim, None),
+        optimizers=(optim, None) if training_args.do_train else (None, None),
     )
 
     # Training
@@ -765,7 +766,7 @@ def prepare_validation_features(examples):
     ####################################################################################
     if data_args.do_onnx_export:
         logger.info("*** Export to ONNX ***")
-        print("Exporting onnx model")
+        print("Exporting onnx model")
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
         exporter = ModuleExporter(
             model, output_dir='onnx-export'
```
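
The export branch above hands the trained model to SparseML's `ModuleExporter`. A minimal sketch of that call pattern, with a toy module standing in for the QA transformer and an illustrative sample input (shapes and paths are assumptions):

```python
# Minimal sketch of the ModuleExporter usage above. The Linear module is a
# stand-in for the trained QA model; shapes and the output path are illustrative.
import torch
from sparseml.pytorch.utils import ModuleExporter

model = torch.nn.Linear(384, 2)        # toy stand-in for the transformer QA head
sample_batch = torch.randn(1, 384)     # one example at an assumed max_seq_length of 384

exporter = ModuleExporter(model, output_dir="onnx-export")
exporter.export_onnx(sample_batch=sample_batch)  # writes onnx-export/model.onnx
```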
