Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def register_model(
model_name=model_name,
model_path=model_path,
model_manager=ModelManager(model_service_type, config),
model_type=model_type.value,
training_type=t_type,
run_name=run_name,
model_config=m_config,
Expand Down
51 changes: 45 additions & 6 deletions app/management/tracker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,36 @@ def log_model_config(config: Dict[str, str]) -> None:

mlflow.log_params(config)

@staticmethod
def _set_model_version_tags(
client: MlflowClient,
model_name: str,
version: str,
model_type: str,
validation_status: Optional[str] = None,
) -> None:
"""
Sets standard tags on a model version for serving and discovery.

Args:
client (MlflowClient): The MLflow client to use for setting tags.
model_name (str): The name of the registered model.
version (str): The version of the model.
model_type (str): The type of the model (e.g., "medcat_snomed").
validation_status (Optional[str]): The status of the model validation (e.g., "pending").
"""
try:
client.set_model_version_tag(
name=model_name, version=version, key="model_uri", value=f"models:/{model_name}/{version}"
)
client.set_model_version_tag(name=model_name, version=version, key="model_type", value=model_type)
if validation_status is not None:
client.set_model_version_tag(
name=model_name, version=version, key="validation_status", value=validation_status
)
except Exception:
logger.warning("Failed to set tags on version %s of model %s", version, model_name)

@staticmethod
def log_model(
model_name: str,
Expand Down Expand Up @@ -381,6 +411,7 @@ def save_pretrained_model(
model_name: str,
model_path: str,
model_manager: ModelManager,
model_type: str,
training_type: Optional[str] = "",
run_name: Optional[str] = "",
model_config: Optional[Dict] = None,
Expand All @@ -394,6 +425,7 @@ def save_pretrained_model(
model_name (str): The name of the model.
model_path (str): The path to the pretrained model.
model_manager (ModelManager): The instance of ModelManager used for model saving.
model_type (str): The type of the model (e.g., "medcat_snomed").
training_type (Optional[str]): The type of training used for the model.
run_name (Optional[str]): The name of the run for identification purposes.
model_config (Optional[Dict]): The configuration of the model to save.
Expand Down Expand Up @@ -423,6 +455,10 @@ def save_pretrained_model(
mlflow.set_tags(tags)
model_name = model_name.replace(" ", "_")
TrackerClient.log_model(model_name, model_path, model_manager, model_name)
client = MlflowClient()
versions = client.search_model_versions(f"name='{model_name}'", order_by=["version_number DESC"])
if versions:
TrackerClient._set_model_version_tags(client, model_name, versions[0].version, model_type)
TrackerClient.end_with_success()
except KeyboardInterrupt:
TrackerClient.end_with_interruption()
Expand Down Expand Up @@ -502,6 +538,7 @@ def save_model(
filepath: str,
model_name: str,
model_manager: ModelManager,
model_type: str,
validation_status: str = "pending",
) -> str:
"""
Expand All @@ -511,6 +548,7 @@ def save_model(
filepath (str): The artifact path of the model to save.
model_name (str): The name of the model.
model_manager (ModelManager): The instance of ModelManager used for model saving.
model_type (str): The type of the model (e.g., "medcat_snomed").
validation_status (str): The status of the model validation (default: "pending").

Returns:
Expand All @@ -523,18 +561,19 @@ def save_model(

if not mlflow.get_tracking_uri().startswith("file:/"):
TrackerClient.log_model(model_name, filepath, model_manager, model_name)
versions = self.mlflow_client.search_model_versions(f"name='{model_name}'")
self.mlflow_client.set_model_version_tag(
name=model_name,
version=versions[0].version,
key="validation_status",
value=validation_status,
versions = self.mlflow_client.search_model_versions(
f"name='{model_name}'", order_by=["version_number DESC"]
)
if versions:
TrackerClient._set_model_version_tags(
self.mlflow_client, model_name, versions[0].version, model_type, validation_status
)
else:
TrackerClient.log_model(model_name, filepath, model_manager)

artifact_uri = mlflow.get_artifact_uri(model_name)
mlflow.set_tag("training.output.model_uri", artifact_uri)
mlflow.set_tag("training.output.model_type", model_type)

return artifact_uri

Expand Down
1 change: 1 addition & 0 deletions app/trainers/huggingface_llm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def run(
retrained_model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info(f"Retrained model saved: {model_uri}")
else:
Expand Down
2 changes: 2 additions & 0 deletions app/trainers/huggingface_ner_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def run(
retrained_model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info(f"Retrained model saved: {model_uri}")
else:
Expand Down Expand Up @@ -664,6 +665,7 @@ def run(
retrained_model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info(f"Retrained model saved: {model_uri}")
else:
Expand Down
7 changes: 6 additions & 1 deletion app/trainers/medcat_deid_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ def run(
)
with open(cdb_config_path, "w") as f:
json.dump(dump_pydantic_object_to_dict(model.config), f)
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
model_uri = self._tracker_client.save_model(
model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info("Retrained model saved: %s", model_uri)
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
else:
Expand Down
14 changes: 12 additions & 2 deletions app/trainers/medcat_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,12 @@ def run(
)
with open(cdb_config_path, "w") as f:
json.dump(dump_pydantic_object_to_dict(model.config), f)
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
model_uri = self._tracker_client.save_model(
model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info("Retrained model saved: %s", model_uri)
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
else:
Expand Down Expand Up @@ -472,7 +477,12 @@ def run(
)
with open(cdb_config_path, "w") as f:
json.dump(dump_pydantic_object_to_dict(model.config), f)
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
model_uri = self._tracker_client.save_model(
model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info(f"Retrained model saved: {model_uri}")
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
else:
Expand Down
7 changes: 6 additions & 1 deletion app/trainers/metacat_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,12 @@ def run(
)
with open(cdb_config_path, "w") as f:
json.dump(dump_pydantic_object_to_dict(model.config), f)
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
model_uri = self._tracker_client.save_model(
model_pack_path,
self._model_name,
self._model_manager,
self._model_service.info().model_type.value,
)
logger.info("Retrained model saved: %s", model_uri)
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
else:
Expand Down
41 changes: 37 additions & 4 deletions tests/app/monitoring/test_tracker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datasets
import pytest
import pandas as pd
from unittest.mock import Mock, call, ANY
from unittest.mock import Mock, call, patch, ANY
from app.management.tracker_client import TrackerClient
from app.data import doc_dataset
from app.domain import TrainerBackend
Expand Down Expand Up @@ -161,15 +161,30 @@ def test_save_model(mlflow_fixture):
mlflow_client.search_model_versions.return_value = [version]
tracker_client.mlflow_client = mlflow_client

artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status")
artifact_uri = tracker_client.save_model(
"path/to/file.zip", "model_name", model_manager, "model_type", "validation_status"
)

assert "artifacts/model_name" in artifact_uri
model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name")
mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status")
mlflow_client.search_model_versions.assert_called_once_with(
"name='model_name'", order_by=["version_number DESC"]
)
assert mlflow_client.set_model_version_tag.call_count == 3
mlflow_client.set_model_version_tag.assert_any_call(
name="model_name", version="1", key="model_uri", value="models:/model_name/1"
)
mlflow_client.set_model_version_tag.assert_any_call(
name="model_name", version="1", key="model_type", value="model_type"
)
mlflow_client.set_model_version_tag.assert_any_call(
name="model_name", version="1", key="validation_status", value="validation_status"
)
mlflow.set_tag.has_calls(
[
call("training.output.package", "file.zip"),
call("training.output.model_uri", artifact_uri),
call("training.output.model_type", "model_type"),
],
any_order=False,
)
Comment on lines 183 to 190
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect assertion method. Should be assert_has_calls instead of has_calls. The current code will not actually perform the assertion, allowing the test to pass even if the calls were not made.

Copilot uses AI. Check for mistakes.
Expand All @@ -184,14 +199,21 @@ def test_save_model_local(mlflow_fixture):
model_manager.save_model.assert_called_once_with("local_dir", "filepath")


def test_save_pretrained_model(mlflow_fixture):
@patch("app.management.tracker_client.MlflowClient")
def test_save_pretrained_model(mock_mlflow_client_class, mlflow_fixture):
tracker_client = TrackerClient("")
model_manager = Mock()
mlflow_client = Mock()
version = Mock()
version.version = "1"
mlflow_client.search_model_versions.return_value = [version]
mock_mlflow_client_class.return_value = mlflow_client

tracker_client.save_pretrained_model(
"model_name",
"model_path",
model_manager,
"model_type",
"training_type",
"run_name",
{"param": "value"},
Expand All @@ -212,6 +234,17 @@ def test_save_pretrained_model(mlflow_fixture):
assert len(mlflow.set_tags.call_args.args[0]["mlflow.source.name"]) > 0
assert mlflow.set_tags.call_args.args[0]["tag_name"] == "tag_value"

mlflow_client.search_model_versions.assert_called_once_with(
"name='model_name'", order_by=["version_number DESC"]
)
assert mlflow_client.set_model_version_tag.call_count == 2
mlflow_client.set_model_version_tag.assert_any_call(
name="model_name", version="1", key="model_uri", value="models:/model_name/1"
)
mlflow_client.set_model_version_tag.assert_any_call(
name="model_name", version="1", key="model_type", value="model_type"
)


def test_log_single_exception(mlflow_fixture):
tracker_client = TrackerClient("")
Expand Down