Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
token: ${{ secrets.GITHUB_TOKEN }}
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
Expand Down
8 changes: 8 additions & 0 deletions llmsql/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def evaluate(
metrics = {
"total": 0,
"matches": 0,
"exact_string_matches": 0,
"pred_none": 0,
"gold_none": 0,
"sql_errors": 0,
Expand All @@ -95,6 +96,7 @@ def evaluate(
metrics["pred_none"] += m["pred_none"]
metrics["gold_none"] += m["gold_none"]
metrics["sql_errors"] += m["sql_error"]
metrics["exact_string_matches"] += m["exact_string_match"]

if mismatch_info:
mismatches.append(mismatch_info)
Expand All @@ -107,12 +109,18 @@ def evaluate(
metrics["pred_none"],
metrics["gold_none"],
metrics["sql_errors"],
metrics["exact_string_matches"],
)

# --- Build report structure ---
report = {
**metrics,
"accuracy": metrics["matches"] / metrics["total"] if metrics["total"] else 0,
"exact_string_match_accuracy": (
metrics["exact_string_matches"] / metrics["total"]
if metrics["total"]
else 0
),
"mismatches": mismatches,
"timestamp": datetime.now(timezone.utc).isoformat(),
"input_mode": input_mode,
Expand Down
18 changes: 11 additions & 7 deletions llmsql/utils/evaluation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def evaluate_sample(
gold_results = execute_sql(conn, gold_sql)

# Initialize counters for this sample
pred_none = gold_none = sql_error = 0
pred_none = gold_none = sql_error = exact_string_match = 0

# Track if gold query returned a NULL-equivalent result
if gold_results == [(None,)]:
Expand All @@ -136,6 +136,9 @@ def evaluate_sample(
pred_res = execute_sql(conn, pred_sql_fixed)
last_pred_res = pred_res

if pred_sql_fixed.strip() == gold_sql.strip():
exact_string_match = 1

# Update metrics
if pred_res is None: # execution failed
sql_error += 1
Expand Down Expand Up @@ -165,15 +168,16 @@ def evaluate_sample(
return (
is_match,
mismatch_info,
{"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error},
{
"pred_none": pred_none,
"gold_none": gold_none,
"sql_error": sql_error,
"exact_string_match": exact_string_match,
},
)


def download_benchmark_file(
repo_id: str,
filename: str,
local_dir: Path
) -> str:
def download_benchmark_file(repo_id: str, filename: str, local_dir: Path) -> str:
"""Download a benchmark file from HuggingFace Hub."""
file_path = hf_hub_download(
repo_id=repo_id,
Expand Down
5 changes: 5 additions & 0 deletions llmsql/utils/rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def print_summary(
pred_none: int,
gold_none: int,
sql_errors: int,
exact_string_matches: int,
) -> None:
"""Pretty-print summary with Rich."""
table = Table(title="[green]Evaluation Summary[/green]", show_lines=True)
Expand All @@ -44,6 +45,10 @@ def print_summary(

table.add_row("Total Samples", str(total))
table.add_row("Correct Results", f"{matches} ({matches / total:.2%})")
table.add_row(
"Exact String Match",
f"{exact_string_matches} ({exact_string_matches / total:.2%})",
)
table.add_row("Prediction None", f"{pred_none}/{total}")
table.add_row("Ground Truth None", f"{gold_none}/{total}")
table.add_row("SQL Errors", str(sql_errors))
Expand Down
4 changes: 3 additions & 1 deletion llmsql/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
from pathlib import Path

from transformers import AutoTokenizer

from llmsql.loggers.logging_config import log
from llmsql.prompts.prompts import (
build_prompt_0shot,
Expand Down Expand Up @@ -65,7 +67,7 @@ def build_all_requests(
questions: list[dict],
tables: dict,
prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str],
tokenizer=None,
tokenizer: AutoTokenizer = None,
use_chat_template: bool = True,
) -> list[str]:
"""
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def mock_utils(mocker, tmp_path):
# evaluate_sample → always correct prediction
mocker.patch(
"llmsql.evaluation.evaluate.evaluate_sample",
return_value=(1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
return_value=(
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 0},
),
)

# rich logging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_evaluate_runs_with_valid_versions(
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0},
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
Expand Down Expand Up @@ -83,7 +83,7 @@ async def test_evaluate_raises_with_invalid_version(
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0},
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
Expand Down
26 changes: 21 additions & 5 deletions tests/evaluation/test_evaluator_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file):
# Monkeypatch dependencies
monkeypatch.setattr(
"llmsql.utils.evaluation_utils.evaluate_sample",
lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
Expand Down Expand Up @@ -68,7 +72,11 @@ async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file):
# Mock dependencies
monkeypatch.setattr(
"llmsql.utils.evaluation_utils.evaluate_sample",
lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
Expand Down Expand Up @@ -106,7 +114,11 @@ async def test_evaluate_with_jsonl_file(monkeypatch, temp_dir, dummy_db_file):
# Mock dependencies
monkeypatch.setattr(
"llmsql.utils.evaluation_utils.evaluate_sample",
lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
Expand Down Expand Up @@ -146,7 +158,11 @@ async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file):
# Mock dependencies
monkeypatch.setattr(
"llmsql.utils.evaluation_utils.evaluate_sample",
lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
lambda *a, **k: (
1,
None,
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)
monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
Expand Down Expand Up @@ -245,7 +261,7 @@ def test_mismatch_handling(mock_utils, mocker):
return_value=(
0,
{"info": "bad"},
{"pred_none": 0, "gold_none": 0, "sql_error": 0},
{"pred_none": 0, "gold_none": 0, "sql_error": 0, "exact_string_match": 1},
),
)

Expand Down
8 changes: 6 additions & 2 deletions tests/utils/test_evaluation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_multiple_table_references(self) -> None:
sql = "SELECT * FROM 'Table' JOIN 'Table' ON Table.id = Table.parent_id"
result = fix_table_name(sql, "my_table")
# Placeholder in FROM should be swapped out for the real table name
assert 'FROM Table' not in result
assert "FROM Table" not in result
assert "FROM 'Table'" not in result
assert 'FROM "Table"' not in result
assert 'FROM "my_table"' in result
Expand All @@ -137,7 +137,7 @@ def test_case_sensitive_table_keyword(self) -> None:
sql = "SELECT * FROM Table WHERE table_name = 'other_table'"
result = fix_table_name(sql, "my_table")
# Should only replace the FROM Table, not 'other_table'
assert result == 'SELECT * FROM "my_table" WHERE table_name = \'other_table\''
assert result == "SELECT * FROM \"my_table\" WHERE table_name = 'other_table'"

def test_complex_query(self) -> None:
"""Test complex query with joins and subqueries."""
Expand Down Expand Up @@ -203,6 +203,7 @@ def test_matching_prediction(self, eval_db, questions_dict) -> None:
assert metrics["pred_none"] == 0
assert metrics["gold_none"] == 0
assert metrics["sql_error"] == 0
assert metrics["exact_string_match"] == 0

def test_non_matching_prediction(self, eval_db, questions_dict) -> None:
"""Test when prediction does not match gold SQL."""
Expand Down Expand Up @@ -299,9 +300,11 @@ def test_metrics_counters(self, eval_db, questions_dict) -> None:
assert "pred_none" in metrics
assert "gold_none" in metrics
assert "sql_error" in metrics
assert "exact_string_match" in metrics
assert isinstance(metrics["pred_none"], int)
assert isinstance(metrics["gold_none"], int)
assert isinstance(metrics["sql_error"], int)
assert isinstance(metrics["exact_string_match"], int)

def test_mismatch_info_structure(self, eval_db, questions_dict) -> None:
"""Test structure of mismatch_info when prediction fails."""
Expand Down Expand Up @@ -343,3 +346,4 @@ def test_null_results_metrics(self, eval_db) -> None:
assert metrics["gold_none"] == 1
assert metrics["pred_none"] == 1
assert metrics["sql_error"] == 0
assert metrics["exact_string_match"] == 1
9 changes: 8 additions & 1 deletion tests/utils/test_rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ def test_print_summary_includes_metrics(monkeypatch) -> None:
recording_console = Console(record=True)
monkeypatch.setattr(rich_utils, "console", recording_console)

rich_utils.print_summary(total=5, matches=3, pred_none=1, gold_none=0, sql_errors=2)
rich_utils.print_summary(
total=5,
matches=3,
pred_none=1,
gold_none=0,
sql_errors=2,
exact_string_matches=1,
)

output = recording_console.export_text()
assert "Evaluation Summary" in output
Expand Down
Loading