diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index e2d3d9c..0106fac 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,6 +19,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v4 diff --git a/llmsql/evaluation/evaluate.py b/llmsql/evaluation/evaluate.py index 1fe79c1..bce7241 100644 --- a/llmsql/evaluation/evaluate.py +++ b/llmsql/evaluation/evaluate.py @@ -80,6 +80,7 @@ def evaluate( metrics = { "total": 0, "matches": 0, + "exact_string_matches": 0, "pred_none": 0, "gold_none": 0, "sql_errors": 0, @@ -95,6 +96,7 @@ def evaluate( metrics["pred_none"] += m["pred_none"] metrics["gold_none"] += m["gold_none"] metrics["sql_errors"] += m["sql_error"] + metrics["exact_string_matches"] += m["exact_string_match"] if mismatch_info: mismatches.append(mismatch_info) @@ -107,12 +109,18 @@ def evaluate( metrics["pred_none"], metrics["gold_none"], metrics["sql_errors"], + metrics["exact_string_matches"], ) # --- Build report structure --- report = { **metrics, "accuracy": metrics["matches"] / metrics["total"] if metrics["total"] else 0, + "exact_string_match_accuracy": ( + metrics["exact_string_matches"] / metrics["total"] + if metrics["total"] + else 0 + ), "mismatches": mismatches, "timestamp": datetime.now(timezone.utc).isoformat(), "input_mode": input_mode, diff --git a/llmsql/utils/evaluation_utils.py b/llmsql/utils/evaluation_utils.py index de86815..b28dec5 100644 --- a/llmsql/utils/evaluation_utils.py +++ b/llmsql/utils/evaluation_utils.py @@ -114,7 +114,7 @@ def evaluate_sample( gold_results = execute_sql(conn, gold_sql) # Initialize counters for this sample - pred_none = gold_none = sql_error = 0 + pred_none = gold_none = sql_error = exact_string_match = 0 # Track if gold query returned a NULL-equivalent result if gold_results == [(None,)]: @@ -136,6 +136,9 @@ def evaluate_sample( pred_res = execute_sql(conn, pred_sql_fixed) last_pred_res = pred_res + if pred_sql_fixed.strip() == gold_sql.strip(): + exact_string_match = 1 + # Update metrics if pred_res is None: # execution failed sql_error += 1 @@ -165,15 +168,16 @@ def evaluate_sample( return ( is_match, mismatch_info, - {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error}, + { + "pred_none": pred_none, + "gold_none": gold_none, + "sql_error": sql_error, + "exact_string_match": exact_string_match, + }, ) -def download_benchmark_file( - repo_id: str, - filename: str, - local_dir: Path -) -> str: +def download_benchmark_file(repo_id: str, filename: str, local_dir: Path) -> str: """Download a benchmark file from HuggingFace Hub.""" file_path = hf_hub_download( repo_id=repo_id, diff --git a/llmsql/utils/rich_utils.py b/llmsql/utils/rich_utils.py index 90b2075..3fc41bf 100644 --- a/llmsql/utils/rich_utils.py +++ b/llmsql/utils/rich_utils.py @@ -36,6 +36,7 @@ def print_summary( pred_none: int, gold_none: int, sql_errors: int, + exact_string_matches: int, ) -> None: """Pretty-print summary with Rich.""" table = Table(title="[green]Evaluation Summary[/green]", show_lines=True) @@ -44,6 +45,10 @@ def print_summary( table.add_row("Total Samples", str(total)) table.add_row("Correct Results", f"{matches} ({matches / total:.2%})") + table.add_row( + "Exact String Match", + f"{exact_string_matches} ({exact_string_matches / total:.2%})", + ) table.add_row("Prediction None", f"{pred_none}/{total}") table.add_row("Ground Truth None", f"{gold_none}/{total}") table.add_row("SQL Errors", str(sql_errors)) diff --git a/llmsql/utils/utils.py b/llmsql/utils/utils.py index fc9b1ab..0d62e84 100644 --- a/llmsql/utils/utils.py +++ b/llmsql/utils/utils.py @@ -2,6 +2,8 @@ import json from pathlib import Path +from transformers import AutoTokenizer + from llmsql.loggers.logging_config import log from llmsql.prompts.prompts import ( build_prompt_0shot, @@ -65,7 +67,7 @@ def build_all_requests( questions: list[dict], tables: dict, prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str], - tokenizer=None, + tokenizer: AutoTokenizer = None, use_chat_template: bool = True, ) -> list[str]: """ diff --git a/tests/conftest.py b/tests/conftest.py index 9f4b0d2..2c8b736 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -103,7 +103,11 @@ def mock_utils(mocker, tmp_path): # evaluate_sample → always correct prediction mocker.patch( "llmsql.evaluation.evaluate.evaluate_sample", - return_value=(1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + return_value=( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 0}, + ), ) # rich logging diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index 12f82fa..9eb4487 100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -40,7 +40,7 @@ async def test_evaluate_runs_with_valid_versions( lambda *a, **k: ( 1, None, - {"pred_none": 0, "gold_none": 0, "sql_error": 0}, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) @@ -83,7 +83,7 @@ async def test_evaluate_raises_with_invalid_version( lambda *a, **k: ( 1, None, - {"pred_none": 0, "gold_none": 0, "sql_error": 0}, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) diff --git a/tests/evaluation/test_evaluator_stability.py b/tests/evaluation/test_evaluator_stability.py index 09fb1a4..321b115 100644 --- a/tests/evaluation/test_evaluator_stability.py +++ b/tests/evaluation/test_evaluator_stability.py @@ -31,7 +31,11 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file): # Monkeypatch dependencies monkeypatch.setattr( "llmsql.utils.evaluation_utils.evaluate_sample", - lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, + ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) @@ -68,7 +72,11 @@ async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file): # Mock dependencies monkeypatch.setattr( "llmsql.utils.evaluation_utils.evaluate_sample", - lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, + ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) @@ -106,7 +114,11 @@ async def test_evaluate_with_jsonl_file(monkeypatch, temp_dir, dummy_db_file): # Mock dependencies monkeypatch.setattr( "llmsql.utils.evaluation_utils.evaluate_sample", - lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, + ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) @@ -146,7 +158,11 @@ async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file): # Mock dependencies monkeypatch.setattr( "llmsql.utils.evaluation_utils.evaluate_sample", - lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, + ), ) monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) @@ -245,7 +261,7 @@ def test_mismatch_handling(mock_utils, mocker): return_value=( 0, {"info": "bad"}, - {"pred_none": 0, "gold_none": 0, "sql_error": 0}, + {"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1}, ), ) diff --git a/tests/utils/test_evaluation_utils.py b/tests/utils/test_evaluation_utils.py index 68a1cba..3343017 100644 --- a/tests/utils/test_evaluation_utils.py +++ b/tests/utils/test_evaluation_utils.py @@ -114,7 +114,7 @@ def test_multiple_table_references(self) -> None: sql = "SELECT * FROM 'Table' JOIN 'Table' ON Table.id = Table.parent_id" result = fix_table_name(sql, "my_table") # Placeholder in FROM should be swapped out for the real table name - assert 'FROM Table' not in result + assert "FROM Table" not in result assert "FROM 'Table'" not in result assert 'FROM "Table"' not in result assert 'FROM "my_table"' in result @@ -137,7 +137,7 @@ def test_case_sensitive_table_keyword(self) -> None: sql = "SELECT * FROM Table WHERE table_name = 'other_table'" result = fix_table_name(sql, "my_table") # Should only replace the FROM Table, not 'other_table' - assert result == 'SELECT * FROM "my_table" WHERE table_name = \'other_table\'' + assert result == "SELECT * FROM \"my_table\" WHERE table_name = 'other_table'" def test_complex_query(self) -> None: """Test complex query with joins and subqueries.""" @@ -203,6 +203,7 @@ def test_matching_prediction(self, eval_db, questions_dict) -> None: assert metrics["pred_none"] == 0 assert metrics["gold_none"] == 0 assert metrics["sql_error"] == 0 + assert metrics["exact_string_match"] == 0 def test_non_matching_prediction(self, eval_db, questions_dict) -> None: """Test when prediction does not match gold SQL.""" @@ -299,9 +300,11 @@ def test_metrics_counters(self, eval_db, questions_dict) -> None: assert "pred_none" in metrics assert "gold_none" in metrics assert "sql_error" in metrics + assert "exact_string_match" in metrics assert isinstance(metrics["pred_none"], int) assert isinstance(metrics["gold_none"], int) assert isinstance(metrics["sql_error"], int) + assert isinstance(metrics["exact_string_match"], int) def test_mismatch_info_structure(self, eval_db, questions_dict) -> None: """Test structure of mismatch_info when prediction fails.""" @@ -343,3 +346,4 @@ def test_null_results_metrics(self, eval_db) -> None: assert metrics["gold_none"] == 1 assert metrics["pred_none"] == 1 assert metrics["sql_error"] == 0 + assert metrics["exact_string_match"] == 1 diff --git a/tests/utils/test_rich_utils.py b/tests/utils/test_rich_utils.py index e495ebf..ed42a9c 100644 --- a/tests/utils/test_rich_utils.py +++ b/tests/utils/test_rich_utils.py @@ -34,7 +34,14 @@ def test_print_summary_includes_metrics(monkeypatch) -> None: recording_console = Console(record=True) monkeypatch.setattr(rich_utils, "console", recording_console) - rich_utils.print_summary(total=5, matches=3, pred_none=1, gold_none=0, sql_errors=2) + rich_utils.print_summary( + total=5, + matches=3, + pred_none=1, + gold_none=0, + sql_errors=2, + exact_string_matches=1, + ) output = recording_console.export_text() assert "Evaluation Summary" in output