diff --git a/configure_finetuning.py b/configure_finetuning.py
index 044e816..50d5986 100644
--- a/configure_finetuning.py
+++ b/configure_finetuning.py
@@ -128,6 +128,7 @@ def __init__(self, model_name, data_dir, **kwargs):
         "{:}_{:}_{:}_predictions.pkl").format
 
     # update defaults with passed-in hyperparameters
+    self.tasks = {}
     self.update(kwargs)
 
     # default hyperparameters for single-task models
diff --git a/finetune/classification/classification_tasks.py b/finetune/classification/classification_tasks.py
index 170fd11..0536363 100644
--- a/finetune/classification/classification_tasks.py
+++ b/finetune/classification/classification_tasks.py
@@ -437,3 +437,23 @@ def _create_examples(self, lines, split):
     examples += self._load_glue(
         lines, split, -3, -2, -1, True, len(examples), True)
     return examples
+
+class StandardTSV(ClassificationTask):
+  def __init__(self, config: configure_finetuning.FinetuningConfig,
+               task_name: str, task_config: dict, tokenizer):
+    super(StandardTSV, self).__init__(config, task_name, tokenizer,
+                                      task_config["labels"])
+    self.task_config = task_config
+
+  def get_examples(self, split):
+    return self._create_examples(read_tsv(
+        os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"),
+        quotechar="\"",
+        max_lines=100 if self.config.debug else None), split)
+
+  def _create_examples(self, lines, split):
+    text_column_2 = self.task_config.get("text_column_2", None)
+    header = self.task_config.get("header", False)
+    return self._load_glue(lines, split, self.task_config["text_column"],
+                           text_column_2, self.task_config["label_column"],
+                           skip_first_line=header)
diff --git a/finetune/task_builder.py b/finetune/task_builder.py
index 978b270..ca02f64 100644
--- a/finetune/task_builder.py
+++ b/finetune/task_builder.py
@@ -66,5 +66,12 @@ def get_task(config: configure_finetuning.FinetuningConfig, task_name,
     return qa_tasks.SearchQA(config, tokenizer)
   elif task_name == "chunk":
     return tagging_tasks.Chunking(config, tokenizer)
+  elif (task_name in config.tasks):
+    if config.tasks[task_name]["type"] == "classification":
+      return classification_tasks.StandardTSV(config, task_name,
+                                              config.tasks[task_name],
+                                              tokenizer)
+    else:
+      raise ValueError("Unknown task type: " + config.tasks[task_name]["type"])
   else:
     raise ValueError("Unknown task " + task_name)
diff --git a/run_finetuning.py b/run_finetuning.py
index 5d20354..920ddf8 100644
--- a/run_finetuning.py
+++ b/run_finetuning.py
@@ -309,11 +309,21 @@ def main():
                       help="The name of the model being fine-tuned.")
   parser.add_argument("--hparams", default="{}",
                       help="JSON dict of model hyperparameters.")
+  parser.add_argument("--task-config", default="{}",
+                      help="JSON dict of custom fine-tuning task parameters")
   args = parser.parse_args()
   if args.hparams.endswith(".json"):
     hparams = utils.load_json(args.hparams)
   else:
     hparams = json.loads(args.hparams)
+
+  if args.task_config.endswith(".json"):
+    task_config = utils.load_json(args.task_config)
+  else:
+    task_config = json.loads(args.task_config)
+  if len(task_config.keys()) > 0:
+    hparams["tasks"] = task_config
+
   tf.logging.set_verbosity(tf.logging.ERROR)
   run_finetuning(configure_finetuning.FinetuningConfig(
       args.model_name, args.data_dir, **hparams))