Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ prq postgresql://localhost/mydb
prq --query "how many users in Italy" postgresql://localhost/mydb # JSON to stdout
prq --query "top 10 orders by total" --out csv postgresql://... > out.csv
prq --query "..." --out table postgresql://... # rich-formatted table
prq --query "..." --anonymize postgresql://... # hide schema names from LLMs
```

Exit codes: `0` success · `1` LLM/connection error · `2` safety-guard rejection · `3` execution error.
Expand Down Expand Up @@ -150,6 +151,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design
| `--select` | 15 | Tables the LLM selector picks from those candidates |
| `--max-tables` | 25 | Cap after FK expansion — what the SQL generator actually sees |
| `--no-selector` | — | Skip the LLM selector (v0.1 behaviour: TF-IDF + FK only) |
| `--anonymize` | — | Send opaque table/column names to LLMs, then map generated SQL back locally before validation/execution |
| `-y, --yes` | — | Skip the confirmation prompt before running |

### Environment
Expand All @@ -172,6 +174,10 @@ PromptQuery has **two independent layers** so a write is impossible, even if one

Every query is also shown to you before it runs. Confirm with `y`.

### Schema anonymisation

`--anonymize` replaces table and column names with deterministic opaque tokens before schema context is sent to either the table selector or SQL generator. PromptQuery still performs TF-IDF retrieval, FK expansion, SQL de-anonymisation, safety validation, and execution locally against the real schema. Table comments are omitted in anonymised prompts so they do not reintroduce business-specific names.

---

## How it compares
Expand Down Expand Up @@ -255,7 +261,7 @@ python3.12 -m venv .venv
.venv/bin/python -m eval.retrieval
```

37 tests, all pure-Python — no live database or API key required for the core suite.
The core suite is pure Python — no live database or API key required.

---

Expand Down
194 changes: 194 additions & 0 deletions src/promptquery/anonymize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from __future__ import annotations

import re
from dataclasses import dataclass

import sqlglot
from sqlglot import errors as sqlglot_errors
from sqlglot import exp
from sqlglot.tokens import Tokenizer

from .schema import Column, ForeignKey, Schema, Table


TableKey = tuple[str, str]

_SAFE_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]*$")


@dataclass(frozen=True)
class _TableMapping:
table: Table
token: str
columns_by_real: dict[str, str]
columns_by_token: dict[str, str]


class SchemaAnonymizer:
"""Map schema identifiers to opaque tokens and generated SQL back again."""

def __init__(self, schema: Schema):
self._by_real_key: dict[TableKey, _TableMapping] = {}
self._by_token: dict[str, _TableMapping] = {}

for table_index, table in enumerate(schema.tables, start=1):
table_token = f"table_{table_index:03d}"
columns_by_real = {
column.name: f"column_{column_index:03d}"
for column_index, column in enumerate(table.columns, start=1)
}
mapping = _TableMapping(
table=table,
token=table_token,
columns_by_real=columns_by_real,
columns_by_token={v: k for k, v in columns_by_real.items()},
)
key = (table.schema, table.name)
self._by_real_key[key] = mapping
self._by_token[table_token.lower()] = mapping

def anonymize_tables(self, tables: list[Table]) -> list[Table]:
return [self.anonymize_table(table) for table in tables]

def anonymize_table(self, table: Table) -> Table:
mapping = self._mapping_for_table(table)
return Table(
schema="public",
name=mapping.token,
comment=None,
columns=[
Column(
name=mapping.columns_by_real[column.name],
data_type=column.data_type,
nullable=column.nullable,
is_primary_key=column.is_primary_key,
)
for column in table.columns
if column.name in mapping.columns_by_real
],
foreign_keys=[
anonymized_fk
for fk in table.foreign_keys
if (anonymized_fk := self._anonymize_fk(mapping, fk)) is not None
],
)

def real_tables_from_anonymized(self, tables: list[Table]) -> list[Table]:
real_tables: list[Table] = []
seen: set[TableKey] = set()
for table in tables:
token = table.name.lower()
mapping = self._by_token.get(token)
if mapping is None:
continue
key = (mapping.table.schema, mapping.table.name)
if key in seen:
continue
real_tables.append(mapping.table)
seen.add(key)
return real_tables

def deanonymize_sql(self, sql: str) -> str:
try:
statements = [
stmt
for stmt in sqlglot.parse(sql, read="postgres")
if stmt is not None
]
except sqlglot_errors.SqlglotError:
return sql

if not statements:
return sql

for statement in statements:
self._deanonymize_statement(statement)

return ";\n".join(statement.sql(dialect="postgres") for statement in statements)

def _mapping_for_table(self, table: Table) -> _TableMapping:
key = (table.schema, table.name)
try:
return self._by_real_key[key]
except KeyError as exc:
raise ValueError(
f"table is not part of the anonymized schema: {table.qualified_name}"
) from exc

def _anonymize_fk(self, mapping: _TableMapping, fk: ForeignKey) -> ForeignKey | None:
local_column = mapping.columns_by_real.get(fk.column)
target_mapping = self._by_real_key.get(
(fk.referenced_schema or "public", fk.referenced_table)
)
if local_column is None or target_mapping is None:
return None

target_column = target_mapping.columns_by_real.get(fk.referenced_column)
if target_column is None:
return None

return ForeignKey(
column=local_column,
referenced_schema="public",
referenced_table=target_mapping.token,
referenced_column=target_column,
)

def _deanonymize_statement(self, statement: exp.Expression) -> None:
table_by_qualifier: dict[str, _TableMapping] = {}
referenced_mappings: list[_TableMapping] = []

for table_expr in statement.find_all(exp.Table):
mapping = self._by_token.get(table_expr.name.lower())
if mapping is None:
continue

referenced_mappings.append(mapping)
table_by_qualifier[mapping.token.lower()] = mapping
table_by_qualifier[mapping.table.name.lower()] = mapping
if table_expr.alias:
table_by_qualifier[table_expr.alias.lower()] = mapping

table_expr.set("this", _identifier(mapping.table.name))
if mapping.table.schema == "public":
table_expr.set("db", None)
else:
table_expr.set("db", _identifier(mapping.table.schema))

for column_expr in statement.find_all(exp.Column):
column_token = column_expr.name.lower()
table_qualifier = column_expr.table.lower() if column_expr.table else ""

if table_qualifier:
mapping = table_by_qualifier.get(table_qualifier)
if mapping is None:
continue
real_column = mapping.columns_by_token.get(column_token)
if real_column is None:
continue
column_expr.set("this", _identifier(real_column))
if table_qualifier == mapping.token.lower():
column_expr.set("table", _identifier(mapping.table.name))
if mapping.table.schema == "public":
column_expr.set("db", None)
else:
column_expr.set("db", _identifier(mapping.table.schema))
continue

matches = [
(mapping, mapping.columns_by_token[column_token])
for mapping in referenced_mappings
if column_token in mapping.columns_by_token
]
if len(matches) != 1:
continue
_, real_column = matches[0]
column_expr.set("this", _identifier(real_column))


def _identifier(name: str) -> exp.Identifier:
return exp.to_identifier(name, quoted=_needs_quotes(name))


def _needs_quotes(name: str) -> bool:
return not _SAFE_IDENTIFIER_RE.match(name) or name.upper() in Tokenizer.KEYWORDS
44 changes: 35 additions & 9 deletions src/promptquery/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rich.console import Console

from . import __version__
from .anonymize import SchemaAnonymizer
from .db import Database
from .llm import LLMClient, LLMError, extract_sql, make_client
from .prompts import build_system_prompt
Expand Down Expand Up @@ -89,6 +90,7 @@ def run_question(
confirm: bool,
progress: Console | None = None,
prompt_for_confirm=None,
anonymizer: SchemaAnonymizer | None = None,
) -> QueryResult:
"""Run a single NL question end-to-end. Used by both the REPL and --query mode."""
ranked = retriever.rank(question, top_k=top_k)
Expand All @@ -98,7 +100,17 @@ def run_question(
if progress:
progress.print(f"[dim]Selecting from {len(candidates)} candidates...[/dim]")
try:
selected = llm_select_tables(question, candidates, selector_llm, max_select=select_n)
if anonymizer is None:
selected = llm_select_tables(question, candidates, selector_llm, max_select=select_n)
else:
anonymized_candidates = anonymizer.anonymize_tables(candidates)
selected_anonymized = llm_select_tables(
question,
anonymized_candidates,
selector_llm,
max_select=select_n,
)
selected = anonymizer.real_tables_from_anonymized(selected_anonymized)
if selected:
candidates = selected
except Exception as e:
Expand All @@ -111,7 +123,8 @@ def run_question(
suffix = "..." if len(relevant) > 5 else ""
progress.print(f"[dim]Using {len(relevant)} tables: {preview}{suffix}[/dim]")

system_prompt = build_system_prompt(relevant)
prompt_tables = anonymizer.anonymize_tables(relevant) if anonymizer else relevant
system_prompt = build_system_prompt(prompt_tables)

try:
if progress:
Expand All @@ -123,6 +136,8 @@ def run_question(
sql = extract_sql(raw)
if not sql:
return QueryResult(Outcome.EMPTY_SQL, None, [], [], "LLM returned an empty response.")
if anonymizer is not None:
sql = anonymizer.deanonymize_sql(sql)

try:
validate_select_only(sql)
Expand Down Expand Up @@ -199,6 +214,11 @@ def run_question(
is_flag=True,
help="Disable the LLM table-selector and use TF-IDF + FK expansion only (v0.1 behavior).",
)
@click.option(
"--anonymize",
is_flag=True,
help="Send opaque table/column names to LLMs, then map generated SQL back locally.",
)
@click.option(
"-y",
"--yes",
Expand All @@ -209,7 +229,7 @@ def run_question(
def main(dsn: str, model: str | None, selector_model: str | None,
query: str | None, out_format: str | None,
top_k: int, select_n: int, max_tables: int,
no_selector: bool, yes: bool) -> None:
no_selector: bool, anonymize: bool, yes: bool) -> None:
"""PromptQuery — natural-language SQL for Postgres.

DSN is a libpq connection string, e.g. postgresql://user:pass@host/db.
Expand Down Expand Up @@ -266,8 +286,11 @@ def main(dsn: str, model: str | None, selector_model: str | None,
if selector_llm is not None and selector_llm is not llm
else (" (selector: same)" if selector_llm is not None else " (selector: off)")
)
anonymizer = SchemaAnonymizer(schema) if anonymize else None
anonymize_info = ", anonymize: on" if anonymizer is not None else ""
progress.print(f"[green]✓[/green] {len(schema.tables)} tables found "
f"[dim](sql: {llm.name}/{llm.model}{selector_info})[/dim]")
f"[dim](sql: {llm.name}/{llm.model}{selector_info}"
f"{anonymize_info})[/dim]")

retriever = TfIdfRetriever(schema)

Expand All @@ -276,27 +299,28 @@ def main(dsn: str, model: str | None, selector_model: str | None,
question=query, schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db_ctx,
top_k=top_k, select_n=select_n, max_tables=max_tables,
out_fmt=out_fmt, progress=progress,
out_fmt=out_fmt, progress=progress, anonymizer=anonymizer,
))

_run_repl(
schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db_ctx,
top_k=top_k, select_n=select_n, max_tables=max_tables,
out_fmt=out_fmt, yes=yes, progress=progress,
out_fmt=out_fmt, yes=yes, progress=progress, anonymizer=anonymizer,
)
finally:
db_ctx.close()


def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db,
top_k, select_n, max_tables, out_fmt: OutputFormat,
progress: Console) -> int:
progress: Console,
anonymizer: SchemaAnonymizer | None = None) -> int:
result = run_question(
question=question, schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db,
top_k=top_k, select_n=select_n, max_tables=max_tables,
confirm=False, progress=progress,
confirm=False, progress=progress, anonymizer=anonymizer,
)

if result.outcome is Outcome.LLM_ERROR:
Expand Down Expand Up @@ -326,7 +350,8 @@ def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db,

def _run_repl(*, schema, retriever, llm, selector_llm, db,
top_k, select_n, max_tables, out_fmt: OutputFormat,
yes: bool, progress: Console) -> None:
yes: bool, progress: Console,
anonymizer: SchemaAnonymizer | None = None) -> None:
session: PromptSession[str] = PromptSession(history=InMemoryHistory())
progress.print(
"\n[bold]PromptQuery[/bold] — ask a question in plain English, "
Expand All @@ -350,6 +375,7 @@ def _run_repl(*, schema, retriever, llm, selector_llm, db,
top_k=top_k, select_n=select_n, max_tables=max_tables,
confirm=not yes, progress=progress,
prompt_for_confirm=(lambda: session.prompt("Run? [y/N] ")) if not yes else None,
anonymizer=anonymizer,
)

if result.outcome is Outcome.LLM_ERROR:
Expand Down
Loading