|
28 | 28 | import sys |
29 | 29 | from contextlib import nullcontext |
30 | 30 | from dataclasses import dataclass, field |
31 | | -from typing import Any, Dict, List, Optional, Union |
| 31 | +from typing import Any, Dict, List, Optional, Tuple, Union |
32 | 32 |
|
33 | 33 | import datasets |
34 | 34 | import numpy as np |
@@ -382,6 +382,10 @@ def main(**kwargs): |
382 | 382 | }, |
383 | 383 | ) |
384 | 384 |
|
| 385 | + if teacher: |
| 386 | + # check whether teacher and student have the corresponding outputs |
| 387 | + label_to_id, label_list = _check_teacher_student_outputs(teacher, label_to_id) |
| 388 | + |
385 | 389 | tokenizer_src = ( |
386 | 390 | model_args.tokenizer_name |
387 | 391 | if model_args.tokenizer_name |
@@ -575,6 +579,39 @@ def compute_metrics(p): |
575 | 579 | ) |
576 | 580 |
|
577 | 581 |
|
| 582 | +def _check_teacher_student_outputs( |
| 583 | + teacher: Module, label_to_id: Dict[str, int] |
| 584 | +) -> Tuple[Dict[str, int], List[str]]: |
| 585 | + # Check that the teacher and student have the same labels and if they do, |
| 586 | + # check that the mapping between labels and ids is the same. |
| 587 | + |
| 588 | + teacher_labels = list(teacher.config.label2id.keys()) |
| 589 | + teacher_ids = list(teacher.config.label2id.values()) |
| 590 | + |
| 591 | + student_labels = list(label_to_id.keys()) |
| 592 | + student_ids = list(label_to_id.values()) |
| 593 | + |
| 594 | + if set(teacher_labels) != set(student_labels): |
| 595 | + _LOGGER.warning( |
| 596 | + f"Teacher labels {teacher_labels} do not match " |
| 597 | + f"student labels {student_labels}. Ignore this warning " |
| 598 | + "if this is expected behavior." |
| 599 | + ) |
| 600 | + else: |
| 601 | + if student_ids != teacher_ids: |
| 602 | + _LOGGER.warning( |
| 603 | + "Teacher and student labels match, but the mapping " |
| 604 | + "between teachers labels and ids does not match the " |
| 605 | + "mapping between student labels and ids. " |
| 606 | + "The student's mapping will be overwritten " |
| 607 | + "by the teacher's mapping." |
| 608 | + ) |
| 609 | + label_to_id = teacher.config.label2id |
| 610 | + label_list = teacher_labels |
| 611 | + |
| 612 | + return label_to_id, label_list |
| 613 | + |
| 614 | + |
578 | 615 | def _get_label_list(labels): |
579 | 616 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go |
580 | 617 | # through the dataset to get the unique labels. |
|
0 commit comments