Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 833ec73

Browse files
sparse BERT + Distilation (#150)
* adding distillation shell * add WANDB for exps and add layer perf for README * add line to remeber to remove spacemanidols's wandb stuff * small implementations in model distillation * cleaning up code because I realized the teacher model only needs to run once * distilation code implemented * cleaned up code for public usage. awaiting numbers * remove unused variable * adding recipes * improve layer dropping * improve layer dropping * clean up layer dropping * updating recipes * Updaing results in README * updating readme * fixing epoch config issue and pushing recipe * forgot to remvoe exit(0) * minor update to distillation code to remove logging of epoch * removing unneeded print * removing wandb * added REAME updates for model perf Co-authored-by: Mark Kurtz <mark@neuralmagic.com>
1 parent 3051fcf commit 833ec73

20 files changed

+11117
-16
lines changed

integrations/transformers/README.md

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ python run_qa.py \
3333
--model_name_or_path bert-base-uncased \
3434
--dataset_name squad \
3535
--do_train \
36-
--per_device_train_batch_size 12 \
36+
--per_device_train_batch_size 16 \
3737
--learning_rate 3e-5 \
3838
--max_seq_length 384 \
3939
--doc_stride 128 \
@@ -44,6 +44,7 @@ python run_qa.py \
4444
--seed 42 \
4545
--num_train_epochs 2 \
4646
--nm_prune_config recipes/90sparsity1shot.yaml
47+
--fp16
4748
```
4849

4950
#### Evaluation
@@ -52,7 +53,7 @@ python run_qa.py \
5253
--model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ \
5354
--dataset_name squad \
5455
--do_eval \
55-
--per_device_eval_batch_size 12 \
56+
--per_device_eval_batch_size 16 \
5657
--output_dir bert-base-uncased-99sparsity-10total8gmp/ \
5758
--overwrite_output_dir \
5859
--cache_dir cache \
@@ -75,7 +76,7 @@ To demostrate the effect that various pruning regimes and techniques can have we
7576

7677
| base model name | sparsity | total train epochs | prunned | one shot |pruning epochs| F1 Score | EM Score |
7778
|-----------------------|---------- |-----------------------|---------|----------|--------------|---------- |-----------|
78-
| bert-base-uncased |0 |1 |no |no |0 |09.685 |3.614 |
79+
| bert-base-uncased |0 |1 |no |no |0 |09.685 |03.614 |
7980
| bert-base-uncased |0 |2 |no |no |0 |88.002 |80.634 |
8081
| bert-base-uncased |0 |10 |no |no |0 |87.603 |79.130 |
8182
| bert-base-uncased |80 |1 |yes |yes |0 |25.141 |15.998 |
@@ -84,14 +85,103 @@ To demostrate the effect that various pruning regimes and techniques can have we
8485
| bert-base-uncased |90 |1 |yes |yes |0 |16.064 |07.786 |
8586
| bert-base-uncased |90 |2 |yes |no |0 |64.185 |50.946 |
8687
| bert-base-uncased |90 |10 |yes |no |8 |79.091 |68.184 |
87-
| bert-base-uncased |95 |1 |yes |yes |0 |10.501 |4.929 |
88+
| bert-base-uncased |95 |1 |yes |yes |0 |10.501 |04.929 |
8889
| bert-base-uncased |95 |2 |yes |no |0 |24.445 |14.437 |
8990
| bert-base-uncased |95 |10 |yes |no |8 |72.761 |60.407 |
91+
| bert-base-uncased |97 |10 |yes |no |6 |70.260 |57.021 |
9092
| bert-base-uncased |99 |1 |yes |yes |0 |09.685 |03.614 |
9193
| bert-base-uncased |99 |2 |yes |no |0 |17.433 |07.871 |
9294
| bert-base-uncased |99 |10 |yes |no |8 |47.306 |32.564 |
9395

96+
## Training With distillation
97+
In addition to a simple QA model we provide implementation which can leverage teacher-student distillation. The usage of the distillation code is virually identical to the non distilled model but commands are as follow.
9498

99+
#### Training
100+
```bash
101+
python run_distill_qa.py \
102+
--teacher_model_name_or_path spacemanidol/neuralmagic-bert-squad-12layer-0sparse\
103+
--student_model_name_or_path bert-base-uncased \
104+
--dataset_name squad \
105+
--do_train \
106+
--per_device_train_batch_size 16 \
107+
--learning_rate 3e-5 \
108+
--max_seq_length 384 \
109+
--doc_stride 128 \
110+
--output_dir distill_2epoch/ \
111+
--overwrite_output_dir \
112+
--cache_dir cache \
113+
--preprocessing_num_workers 4 \
114+
--seed 42 \
115+
--num_train_epochs 2 \
116+
--nm_prune_config recipes/noprune2epoch.yaml
117+
--fp16
118+
```
119+
120+
#### Evaluation
121+
```bash
122+
python run_qa.py \
123+
--model_name_or_path bert-base-uncased-99sparsity-10total8gmp/ \
124+
--dataset_name squad \
125+
--do_eval \
126+
--per_device_eval_batch_size 16 \
127+
--output_dir bert-base-uncased-99sparsity-10total8gmp/ \
128+
--overwrite_output_dir \
129+
--cache_dir cache \
130+
--preprocessing_num_workers 4 \
131+
```
132+
#### ONNX Export
133+
```bash
134+
python run_qa.py \
135+
--model_name_or_path bert-base-uncased-99sparsity-10total8gmp/
136+
--do_eval \
137+
--dataset_name squad \
138+
--do_onnx_export \
139+
--onnx_export_path bert-base-uncased-99sparsity-10total8gmp/ \
140+
--cache_dir cache \
141+
--preprocessing_num_workers 4 \
142+
```
143+
### Distillation Results
144+
Sparsity 80, 90, 97
145+
| base model name | sparsity |Distilled| prunned |train epochs|pruning epochs| F1 Score | EM Score |
146+
|-----------------------|---------- |---------|---------|------------|--------------|----------|----------|
147+
| bert-base-uncased |0 |no |no |2 |0 |88.32442 |81.10690 |
148+
| bert-base-uncased |80 |no |no |30 |18 |84.06276 |74.63576 |
149+
| bert-base-uncased |90 |no |no |30 |18 |79.64549 |68.50520 |
150+
| bert-base-uncased |97 |no |no |30 |18 |70.42570 |57.29423 |
151+
| bert-base-uncased |0 |yes |no |2 |0 |89.02277 |82.03406 |
152+
| bert-base-uncased |80 |yes |yes |30 |18 |88.03192 |80.81362 |
153+
| bert-base-uncased |90 |yes |yes |30 |18 |85.63751 |77.41721 |
154+
| bert-base-uncased |97 |yes |yes |30 |18 | | |
155+
156+
### Distillation, Pruning, Layer Dropping
157+
To explore the effect of model pruning compared to layer dropping we train models to sparsity to match the amount of parameters in models with layers droppend. Results feature both with and without distillation. For distillation we use hard distillation and a a trained teacher model which is trained on SQUAD for 2 epochs and achieves an 88.32442/81.10690 F1/EM. A 9 layer model is roughly equivalent to 20% sparsity, 6 layer to 40%, 3 layer to 60%, 1 layer to 72%.
158+
159+
| base model name | sparsity | params |Distilled| prunned | layers |pruning epochs| F1 Score | EM Score |
160+
|-----------------------|---------- |-----------------------|---------|---------|----------|--------------|----------|-----------|
161+
| bert-base-uncased |0 |108,893,186 |no |no |12 |0 |88.32442 |81.10690 |
162+
| bert-base-uncased |0 |87,629,570 |no |no |9 |0 |86.70732 |78.81740 |
163+
| bert-base-uncased |0 |66,365,954 |no |no |6 |0 |81.63629 |72.66793 |
164+
| bert-base-uncased |0 |45,102,338 |no |no |3 |0 |51.75267 |39.11069 |
165+
| bert-base-uncased |0 |30,926,594 |no |no |1 |0 |26.22600 |17.32261 |
166+
| bert-base-uncased |20 |108,893,186 |no |yes |12 |8 |87.19622 |79.16746 |
167+
| bert-base-uncased |40 |108,893,186 |no |yes |12 |8 |86.27294 |78.07947 |
168+
| bert-base-uncased |60 |108,893,186 |no |yes |12 |8 |86.4412 |77.94702 |
169+
| bert-base-uncased |72 |108,893,186 |no |yes |12 |8 |85.49873 |76.43330 |
170+
| bert-base-uncased |80 |66,365,954 |no |yes |6 |8 |77.86777 |67.07663 |
171+
| bert-base-uncased |90 |66,365,954 |no |yes |6 |8 |73.51963 |61.22044 |
172+
| bert-base-uncased |97 |66,365,954 |no |yes |6 |8 |67.27468 |53.85998 |
173+
| bert-base-uncased |0 |108,893,186 |yes |no |12 |0 |89.02277 |82.03406 |
174+
| bert-base-uncased |0 |87,629,570 |yes |no |9 |0 |87.94176 |80.46358 |
175+
| bert-base-uncased |0 |66,365,954 |yes |no |6 |0 |83.4553 |75.03311 |
176+
| bert-base-uncased |0 |45,102,338 |yes |no |3 |0 |43.82823 |33.05581 |
177+
| bert-base-uncased |0 |30,926,594 |yes |no |1 |0 |28.10105 |18.5052 |
178+
| bert-base-uncased |20 |108,893,186 |yes |yes |12 |18 | | |
179+
| bert-base-uncased |40 |108,893,186 |yes |yes |12 |18 | | |
180+
| bert-base-uncased |60 |108,893,186 |yes |yes |12 |18 | | |
181+
| bert-base-uncased |72 |108,893,186 |yes |yes |12 |18 | | |
182+
| bert-base-uncased |80 |66,365,954 |yes |yes |6 |8 | | |
183+
| bert-base-uncased |90 |66,365,954 |yes |yes |6 |8 | | |
184+
| bert-base-uncased |97 |66,365,954 |yes |yes |6 |8 | | |
95185

96186
## Script origin and how to integrate sparseml with other Transformers projects
97187
This script is based on the example BERT-QA implementation in transformers found [here](https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_qa.py).
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# neuralmagic: no copyright
2+
# flake8: noqa
3+
# fmt: off
4+
# isort: skip_file
5+
#!/usr/bin/env python
6+
# coding=utf-8
7+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing,
16+
# software distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
from typing import Union
21+
22+
import torch
23+
from torch import nn
24+
import torch.nn.functional as F
25+
from torch import Tensor
26+
27+
from transformers import Trainer, is_datasets_available, is_torch_tpu_available
28+
from transformers.trainer_utils import PredictionOutput
29+
30+
from trainer_qa import QuestionAnsweringTrainer
31+
32+
class DistillQuestionAnsweringTrainer(QuestionAnsweringTrainer):
33+
def __init__(self, *args, eval_examples=None, post_process_function=None, teacher=None, loss=None, batch_size=8, max_sequence_length=384,distill_hardness =0.5, temperature=2.0, **kwargs):
34+
super().__init__(*args, **kwargs)
35+
self.eval_examples = eval_examples
36+
self.post_process_function = post_process_function
37+
self.loss = loss
38+
self.teacher = teacher
39+
self.batch_size = batch_size
40+
self.temperature = temperature
41+
self.distill_hardness = distill_hardness
42+
self.criterion = nn.CrossEntropyLoss()
43+
self.max_sequence_length = max_sequence_length
44+
45+
def compute_loss(self, model, inputs, return_outputs=False):
46+
"""
47+
How the loss is computed by Trainer. Modified for Distilation using student teacher framework modified for distilation.
48+
"""
49+
input_device = inputs["input_ids"].device
50+
outputs = model(**inputs)
51+
start_logits_student = outputs["start_logits"]
52+
end_logits_student = outputs["end_logits"]
53+
start_logits_label = inputs["start_positions"]
54+
end_logits_label = inputs["start_positions"]
55+
self.teacher = self.teacher.to(input_device)
56+
with torch.no_grad():
57+
teacher_output = self.teacher(
58+
input_ids=inputs["input_ids"],
59+
token_type_ids=inputs["token_type_ids"],
60+
attention_mask=inputs["attention_mask"],
61+
)
62+
start_logits_teacher = teacher_output["start_logits"]
63+
end_logits_teacher = teacher_output["end_logits"]
64+
loss_start = (
65+
F.kl_div(
66+
input=F.log_softmax(start_logits_student / self.temperature, dim=-1),
67+
target=F.softmax(start_logits_teacher / self.temperature, dim=-1),
68+
reduction="batchmean",
69+
)
70+
* (self.temperature ** 2)
71+
)
72+
loss_end = (
73+
F.kl_div(
74+
input=F.log_softmax(end_logits_student / self.temperature, dim=-1),
75+
target=F.softmax(end_logits_teacher / self.temperature, dim=-1),
76+
reduction="batchmean",
77+
)
78+
* (self.temperature ** 2)
79+
)
80+
teacher_loss = (loss_start + loss_end) / 2.0
81+
loss_start = self.criterion(start_logits_student, start_logits_label)
82+
loss_end = self.criterion(end_logits_student, end_logits_label)
83+
label_loss = (loss_start + loss_end) / 2.0
84+
loss = ((1-self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
85+
return loss

0 commit comments

Comments
 (0)