Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configure_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self, model_name, data_dir, **kwargs):
"{:}_{:}_{:}_predictions.pkl").format

# update defaults with passed-in hyperparameters
self.tasks = {}
self.update(kwargs)

# default hyperparameters for single-task models
Expand Down
20 changes: 20 additions & 0 deletions finetune/classification/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,23 @@ def _create_examples(self, lines, split):
examples += self._load_glue(
lines, split, -3, -2, -1, True, len(examples), True)
return examples

class StandardTSV(ClassificationTask):
  """Classification task driven by a per-task config dict and TSV files.

  Column positions, the header flag and the labels all come from
  `task_config` instead of being hard-coded in the class.
  """

  def __init__(self, config: configure_finetuning.FinetuningConfig,
               task_name: str, task_config: dict, tokenizer):
    """Initializes the task from its config dict.

    Args:
      config: the fine-tuning configuration object.
      task_name: name forwarded to the base class for this task.
      task_config: dict describing the task. `"labels"` is read here;
        `"text_column"`, `"label_column"` and the optional
        `"text_column_2"` / `"header"` keys are read when examples are built.
      tokenizer: forwarded unchanged to the base class.

    Raises:
      KeyError: if `task_config` has no `"labels"` entry.
    """
    labels = task_config["labels"]
    super(StandardTSV, self).__init__(config, task_name, tokenizer, labels)
    self.task_config = task_config

  def get_examples(self, split):
    """Loads `<split>.tsv` from this task's raw data dir as examples."""
    tsv_path = os.path.join(
        self.config.raw_data_dir(self.name), split + ".tsv")
    # Debug runs pass max_lines=100 to the reader; otherwise no limit is given.
    line_limit = 100 if self.config.debug else None
    rows = read_tsv(tsv_path, quotechar="\"", max_lines=line_limit)
    return self._create_examples(rows, split)

  def _create_examples(self, lines, split):
    """Hands the rows to `_load_glue` using the configured column settings."""
    options = self.task_config
    # Optional settings: no second text column and no header line by default.
    second_text_column = options.get("text_column_2", None)
    has_header = options.get("header", False)
    return self._load_glue(
        lines,
        split,
        options["text_column"],
        second_text_column,
        options["label_column"],
        skip_first_line=has_header)
7 changes: 7 additions & 0 deletions finetune/task_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,12 @@ def get_task(config: configure_finetuning.FinetuningConfig, task_name,
return qa_tasks.SearchQA(config, tokenizer)
elif task_name == "chunk":
return tagging_tasks.Chunking(config, tokenizer)
elif (task_name in config.tasks):
if config.tasks[task_name]["type"] == "classification":
return classification_tasks.StandardTSV(config, task_name,
config.tasks[task_name],
tokenizer)
else:
raise ValueError("Unknown task type: " + config.tasks[task_name]["type"])
else:
raise ValueError("Unknown task " + task_name)
10 changes: 10 additions & 0 deletions run_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,21 @@ def main():
help="The name of the model being fine-tuned.")
parser.add_argument("--hparams", default="{}",
help="JSON dict of model hyperparameters.")
parser.add_argument("--task-config", default="{}",
help="JSON dict of custom fine-tuning task parameters")
args = parser.parse_args()
if args.hparams.endswith(".json"):
hparams = utils.load_json(args.hparams)
else:
hparams = json.loads(args.hparams)

if args.task_config.endswith(".json"):
task_config = utils.load_json(args.task_config)
else:
task_config = json.loads(args.task_config)
if len(task_config.keys()) > 0:
hparams["tasks"] = task_config

tf.logging.set_verbosity(tf.logging.ERROR)
run_finetuning(configure_finetuning.FinetuningConfig(
args.model_name, args.data_dir, **hparams))
Expand Down