Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 2d0a6a9

Browse files
authored
[Cherry-Pick][1.4.1] Fix divergence between student and teacher label mapping (#1400)
* initial commit * bump version
1 parent ed9673c commit 2d0a6a9

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

src/sparseml/transformers/token_classification.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import sys
2929
from contextlib import nullcontext
3030
from dataclasses import dataclass, field
31-
from typing import Any, Dict, List, Optional, Union
31+
from typing import Any, Dict, List, Optional, Tuple, Union
3232

3333
import datasets
3434
import numpy as np
@@ -382,6 +382,10 @@ def main(**kwargs):
382382
},
383383
)
384384

385+
if teacher:
386+
# check whether teacher and student have the corresponding outputs
387+
label_to_id, label_list = _check_teacher_student_outputs(teacher, label_to_id)
388+
385389
tokenizer_src = (
386390
model_args.tokenizer_name
387391
if model_args.tokenizer_name
@@ -575,6 +579,39 @@ def compute_metrics(p):
575579
)
576580

577581

582+
def _check_teacher_student_outputs(
    teacher: Module, label_to_id: Dict[str, int]
) -> Tuple[Dict[str, int], List[str]]:
    """
    Reconcile the student's label mapping with the distillation teacher's.

    If teacher and student do not share the same set of labels, a warning is
    logged and the student's mapping is returned unchanged. If they share the
    same labels but assign at least one label a different id, a warning is
    logged and the teacher's mapping is returned instead, so that both models
    agree on what every output index means.

    :param teacher: the teacher model; only ``teacher.config.label2id`` is read
    :param label_to_id: the student's mapping from label name to class id
    :return: a tuple of the label-to-id mapping to use and the label names of
        that same mapping ordered by ascending id
    """
    teacher_label_to_id = teacher.config.label2id

    teacher_labels = list(teacher_label_to_id.keys())
    student_labels = list(label_to_id.keys())

    if set(teacher_labels) != set(student_labels):
        _LOGGER.warning(
            f"Teacher labels {teacher_labels} do not match "
            f"student labels {student_labels}. Ignore this warning "
            "if this is expected behavior."
        )
    elif label_to_id != teacher_label_to_id:
        # Compare the mappings label by label (dict equality). Comparing the
        # two lists of ids depends on insertion order only: it misses a
        # mapping whose ids are swapped between labels and misreports
        # identical mappings that were merely built in a different order.
        _LOGGER.warning(
            "Teacher and student labels match, but the mapping "
            "between teachers labels and ids does not match the "
            "mapping between student labels and ids. "
            "The student's mapping will be overwritten "
            "by the teacher's mapping."
        )
        label_to_id = teacher_label_to_id

    # Build the label list from the mapping that is actually returned, on
    # every path, so the two return values can never disagree. Sorting by id
    # guarantees label_list[i] is the label whose id is i for a contiguous
    # 0..n-1 mapping, independent of dict insertion order.
    label_list = sorted(label_to_id, key=label_to_id.get)

    return label_to_id, label_list
613+
614+
578615
def _get_label_list(labels):
579616
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go
580617
# through the dataset to get the unique labels.

src/sparseml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import date
2020

2121

22-
version_base = "1.4.0"
22+
version_base = "1.4.1"
2323
is_release = False # change to True to set the generated version as a release version
2424

2525

0 commit comments

Comments
 (0)