diff --git a/app/api/routers/evaluation.py b/app/api/routers/evaluation.py index 97ff305..6cc49a5 100644 --- a/app/api/routers/evaluation.py +++ b/app/api/routers/evaluation.py @@ -57,11 +57,26 @@ async def get_evaluation_with_trainer_export(request: Request, data_file.flush() data_file.seek(0) evaluation_id = tracking_id or str(uuid.uuid4()) - evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)) + evaluation_accepted, experiment_id, run_id = model_service.train_supervised( + data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names) + ) if evaluation_accepted: - return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your evaluation started successfully.", + "evaluation_id": evaluation_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. Please retry later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) @router.post("/sanity-check", diff --git a/app/api/routers/metacat_training.py b/app/api/routers/metacat_training.py index 1e16e20..5e60adb 100644 --- a/app/api/routers/metacat_training.py +++ b/app/api/routers/metacat_training.py @@ -2,7 +2,7 @@ import uuid import json import logging -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File @@ -53,7 +53,7 @@ async def train_metacat(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_metacat(data_file, + training_response = model_service.train_metacat(data_file, epochs, log_frequency, training_id, @@ -65,13 +65,27 @@ async def train_metacat(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry your training later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index fd66443..9a49c60 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -2,7 +2,7 @@ import uuid import json import logging -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form @@ -55,7 +55,7 @@ async def train_supervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_supervised(data_file, + training_response = model_service.train_supervised(data_file, epochs, log_frequency, training_id, @@ -69,13 +69,27 @@ async def train_supervised(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry your training later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/api/routers/unsupervised_training.py b/app/api/routers/unsupervised_training.py index c3925aa..831cd64 100644 --- a/app/api/routers/unsupervised_training.py +++ b/app/api/routers/unsupervised_training.py @@ -5,7 +5,7 @@ import logging import datasets import zipfile -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File @@ -65,7 +65,7 @@ async def train_unsupervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_unsupervised(data_file, + training_response = model_service.train_unsupervised(data_file, epochs, log_frequency, training_id, @@ -79,7 +79,7 @@ async def train_unsupervised(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) @router.post("/train_unsupervised_with_hf_hub_dataset", @@ -133,7 +133,7 @@ async def train_unsupervised_with_hf_dataset(request: Request, hf_dataset.save_to_disk(data_dir.name) training_id = tracking_id or str(uuid.uuid4()) - training_accepted = model_service.train_unsupervised(data_dir, + training_response = model_service.train_unsupervised(data_dir, epochs, log_frequency, training_id, @@ -143,13 +143,27 @@ async def train_unsupervised_with_hf_dataset(request: Request, lr_override=lr_override, test_size=test_size, description=description) - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/model_services/base.py b/app/model_services/base.py index fceb9a8..b431eff 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -56,11 +56,11 @@ def batch_annotate(self, texts: List[str]) -> List[List[Dict[str, Any]]]: def init_model(self) -> None: raise NotImplementedError - def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError - def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError - def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index 2c8b0d4..afd8e24 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -156,7 +156,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) @@ -170,7 +170,7 @@ def train_unsupervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._unsupervised_trainer is None: raise ConfigurationException("The unsupervised trainer is not enabled") return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index 9452c99..92bd22e 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -121,7 +121,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) @@ -135,7 +135,7 @@ def train_unsupervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._unsupervised_trainer is None: raise ConfigurationException("The unsupervised trainer is not enabled") return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, 
synchronised, **hyperparams) @@ -149,7 +149,7 @@ def train_metacat(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._metacat_trainer is None: raise ConfigurationException("The metacat trainer is not enabled") return self._metacat_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index deba5fe..43ed794 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -2,7 +2,7 @@ import inspect import threading import torch -from typing import Dict, List, TextIO, Optional, Any, final, Callable +from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable from functools import partial from transformers import pipeline from medcat.cat import CAT @@ -147,7 +147,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("Trainers are not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/trainers/base.py b/app/trainers/base.py index 2cc22fe..da4f795 100644 --- a/app/trainers/base.py +++ b/app/trainers/base.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import TextIO, Callable, Dict, Optional, Any, List, Union, final +from typing import TextIO, Callable, Dict, Tuple, Optional, Any, List, Union, final from config import Settings from management.tracker_client import TrackerClient from data import doc_dataset, anno_dataset @@ -26,6 +26,8 @@ def __init__(self, config: Settings, model_name: str) -> None: self._model_name = model_name self._training_lock = threading.Lock() self._training_in_progress = False + self._experiment_id = None + self._run_id = None self._tracker_client = TrackerClient(self._config.MLFLOW_TRACKING_URI) self._executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1) @@ -37,6 +39,14 @@ def model_name(self) -> str: def model_name(self, model_name: str) -> None: self._model_name = model_name + @property + def experiment_id(self) -> str: + return self._experiment_id or "" + + @property + def run_id(self) -> str: + return self._run_id or "" + @final def start_training(self, run: Callable, @@ -48,13 +58,13 @@ def start_training(self, input_file_name: str, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, - synchronised: bool = False) -> bool: + synchronised: bool = False) -> Tuple[bool, str, str]: with self._training_lock: if self._training_in_progress: - return False + return False, self.experiment_id, self.run_id else: loop = asyncio.get_event_loop() - experiment_id, run_id = self._tracker_client.start_tracking( + self._experiment_id, self._run_id = self._tracker_client.start_tracking( model_name=self._model_name, input_file_name=input_file_name, base_model_original=self._config.BASE_MODEL_FULL_PATH, @@ -101,15 +111,15 @@ def start_training(self, else: 
raise ValueError(f"Unknown training type: {training_type}") - logger.info("Starting training job: %s with experiment ID: %s", training_id, experiment_id) + logger.info("Starting training job: %s with experiment ID: %s", training_id, self.experiment_id) self._training_in_progress = True training_task = asyncio.ensure_future(loop.run_in_executor(self._executor, - partial(run, self, training_params, data_file, log_frequency, run_id, description))) + partial(run, self, training_params, data_file, log_frequency, self.run_id, description))) if synchronised: loop.run_until_complete(training_task) - return True + return True, self.experiment_id, self.run_id @staticmethod def _make_model_file_copy(model_file_path: str, run_id: str) -> str: @@ -161,7 +171,7 @@ def train(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: training_type = TrainingType.SUPERVISED.value training_params = { "data_path": data_file.name, @@ -204,7 +214,7 @@ def train(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: training_type = TrainingType.UNSUPERVISED.value training_params = { "nepochs": epochs, diff --git a/tests/app/api/test_serving_common.py b/tests/app/api/test_serving_common.py index b0be036..91b156c 100644 --- a/tests/app/api/test_serving_common.py +++ b/tests/app/api/test_serving_common.py @@ -11,7 +11,7 @@ from utils import get_settings from model_services.medcat_model import MedCATModel from management.model_manager import ModelManager -from unittest.mock import create_autospec +from unittest.mock import create_autospec, patch config = get_settings() config.ENABLE_TRAINING_APIS = "true" @@ -258,13 +258,14 @@ def test_preview_trainer_export_on_missing_project_or_document(pid, did, client) def test_train_supervised(model_service, client): + model_service.train_supervised.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/train_supervised", files=[("trainer_export", f)]) model_service.train_supervised.assert_called() assert response.status_code == 202 assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -278,13 +279,14 @@ def test_train_supervised(model_service, client): def test_train_unsupervised(model_service, client): + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") with tempfile.TemporaryFile("r+b") as f: f.write(str.encode("[\"Spinal stenosis\"]")) response = client.post("/train_unsupervised", files=[("training_data", f)]) model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." 
- assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with tempfile.TemporaryFile("r+b") as f: @@ -305,12 +307,13 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): "model_card": None, }) model_service.info.return_value = model_card + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID response = client.post(f"/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb&tracking_id={TRACKING_ID}") @@ -322,13 +325,14 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): def test_train_metacat(model_service, client): + model_service.train_metacat.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/train_metacat", files=[("trainer_export", f)]) model_service.train_metacat.assert_called() assert response.status_code == 202 assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -341,7 +345,8 @@ def test_train_metacat(model_service, client): assert response.json().get("training_id") == TRACKING_ID -def test_evaluate_with_trainer_export(client): +def test_evaluate_with_trainer_export(model_service, client): + model_service.train_supervised.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/evaluate", files=[("trainer_export", f)]) diff --git a/tests/app/api/test_serving_hf_ner.py b/tests/app/api/test_serving_hf_ner.py index c6ac825..c0f8dfa 100644 --- a/tests/app/api/test_serving_hf_ner.py +++ b/tests/app/api/test_serving_hf_ner.py @@ -44,9 +44,10 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): "model_card": None, }) model_service.info.return_value = model_card + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all([key in response.json() for key in ["training_id", "experiment_id", "run_id"]])