Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ credentials.ini

# MLFlow data directory
mlruns/
mlflow.db

#JSON data files used by the 3d visualizer
*.json
24 changes: 24 additions & 0 deletions tests/hyrax/test_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
from pathlib import Path
from unittest.mock import patch

from hyrax.config_utils import find_most_recent_results_dir


def test_mlflow_tracking_uri_set_to_root_results_dir(loopback_hyrax):
"""
Verify that MLflow tracking URI is set to the root results directory,
not the per-run results directory. This is a regression test for the fix
that ensures MLflow stores its backend in the correct location.
"""
h, _ = loopback_hyrax

# Get the expected root results directory
expected_root_dir = Path(h.config["general"]["results_dir"]).expanduser().resolve()
expected_tracking_uri = "file://" + str(expected_root_dir / "mlflow")

# Mock mlflow.set_tracking_uri to capture the call
# Since mlflow is imported inside the run() method, we need to patch it at the mlflow module level
with patch("mlflow.set_tracking_uri") as mock_set_tracking_uri:
h.train()

# Verify set_tracking_uri was called with the root results directory
mock_set_tracking_uri.assert_called_once_with(expected_tracking_uri)


def test_train(loopback_hyrax):
"""
Simple test that training succeeds with the loopback
Expand Down