From 97d333bbe30f7e7f503cfda38db58fbec9bb89ee Mon Sep 17 00:00:00 2001 From: Wondr Date: Fri, 5 Jun 2026 04:12:20 +0100 Subject: [PATCH] feat: add schema anonymization mode --- README.md | 8 +- src/promptquery/anonymize.py | 194 +++++++++++++++++++++++++++++++++++ src/promptquery/cli.py | 44 ++++++-- tests/test_anonymize.py | 96 +++++++++++++++++ tests/test_cli_oneshot.py | 55 ++++++++++ 5 files changed, 387 insertions(+), 10 deletions(-) create mode 100644 src/promptquery/anonymize.py create mode 100644 tests/test_anonymize.py diff --git a/README.md b/README.md index 7d59d92..1e3adc0 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ prq postgresql://localhost/mydb prq --query "how many users in Italy" postgresql://localhost/mydb # JSON to stdout prq --query "top 10 orders by total" --out csv postgresql://... > out.csv prq --query "..." --out table postgresql://... # rich-formatted table +prq --query "..." --anonymize postgresql://... # hide schema names from LLMs ``` Exit codes: `0` success · `1` LLM/connection error · `2` safety-guard rejection · `3` execution error. @@ -150,6 +151,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design | `--select` | 15 | Tables the LLM selector picks from those candidates | | `--max-tables` | 25 | Cap after FK expansion — what the SQL generator actually sees | | `--no-selector` | — | Skip the LLM selector (v0.1 behaviour: TF-IDF + FK only) | +| `--anonymize` | — | Send opaque table/column names to LLMs, then map generated SQL back locally before validation/execution | | `-y, --yes` | — | Skip the confirmation prompt before running | ### Environment @@ -172,6 +174,10 @@ PromptQuery has **two independent layers** so a write is impossible, even if one Every query is also shown to you before it runs. Confirm with `y`. +### Schema anonymisation + +`--anonymize` replaces table and column names with deterministic opaque tokens before schema context is sent to either the table selector or SQL generator. PromptQuery still performs TF-IDF retrieval, FK expansion, SQL de-anonymisation, safety validation, and execution locally against the real schema. Table comments are omitted in anonymised prompts so they do not reintroduce business-specific names. + --- ## How it compares @@ -255,7 +261,7 @@ python3.12 -m venv .venv .venv/bin/python -m eval.retrieval ``` -37 tests, all pure-Python — no live database or API key required for the core suite. +The core suite is pure Python — no live database or API key required. --- diff --git a/src/promptquery/anonymize.py b/src/promptquery/anonymize.py new file mode 100644 index 0000000..04b72ab --- /dev/null +++ b/src/promptquery/anonymize.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +import sqlglot +from sqlglot import errors as sqlglot_errors +from sqlglot import exp +from sqlglot.tokens import Tokenizer + +from .schema import Column, ForeignKey, Schema, Table + + +TableKey = tuple[str, str] + +_SAFE_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]*$") + + +@dataclass(frozen=True) +class _TableMapping: + table: Table + token: str + columns_by_real: dict[str, str] + columns_by_token: dict[str, str] + + +class SchemaAnonymizer: + """Map schema identifiers to opaque tokens and generated SQL back again.""" + + def __init__(self, schema: Schema): + self._by_real_key: dict[TableKey, _TableMapping] = {} + self._by_token: dict[str, _TableMapping] = {} + + for table_index, table in enumerate(schema.tables, start=1): + table_token = f"table_{table_index:03d}" + columns_by_real = { + column.name: f"column_{column_index:03d}" + for column_index, column in enumerate(table.columns, start=1) + } + mapping = _TableMapping( + table=table, + token=table_token, + columns_by_real=columns_by_real, + columns_by_token={v: k for k, v in columns_by_real.items()}, + ) + key = (table.schema, table.name) + self._by_real_key[key] = mapping + self._by_token[table_token.lower()] = mapping + + def anonymize_tables(self, tables: list[Table]) -> list[Table]: + return [self.anonymize_table(table) for table in tables] + + def anonymize_table(self, table: Table) -> Table: + mapping = self._mapping_for_table(table) + return Table( + schema="public", + name=mapping.token, + comment=None, + columns=[ + Column( + name=mapping.columns_by_real[column.name], + data_type=column.data_type, + nullable=column.nullable, + is_primary_key=column.is_primary_key, + ) + for column in table.columns + if column.name in mapping.columns_by_real + ], + foreign_keys=[ + anonymized_fk + for fk in table.foreign_keys + if (anonymized_fk := self._anonymize_fk(mapping, fk)) is not None + ], + ) + + def real_tables_from_anonymized(self, tables: list[Table]) -> list[Table]: + real_tables: list[Table] = [] + seen: set[TableKey] = set() + for table in tables: + token = table.name.lower() + mapping = self._by_token.get(token) + if mapping is None: + continue + key = (mapping.table.schema, mapping.table.name) + if key in seen: + continue + real_tables.append(mapping.table) + seen.add(key) + return real_tables + + def deanonymize_sql(self, sql: str) -> str: + try: + statements = [ + stmt + for stmt in sqlglot.parse(sql, read="postgres") + if stmt is not None + ] + except sqlglot_errors.SqlglotError: + return sql + + if not statements: + return sql + + for statement in statements: + self._deanonymize_statement(statement) + + return ";\n".join(statement.sql(dialect="postgres") for statement in statements) + + def _mapping_for_table(self, table: Table) -> _TableMapping: + key = (table.schema, table.name) + try: + return self._by_real_key[key] + except KeyError as exc: + raise ValueError( + f"table is not part of the anonymized schema: {table.qualified_name}" + ) from exc + + def _anonymize_fk(self, mapping: _TableMapping, fk: ForeignKey) -> ForeignKey | None: + local_column = mapping.columns_by_real.get(fk.column) + target_mapping = self._by_real_key.get( + (fk.referenced_schema or "public", fk.referenced_table) + ) + if local_column is None or target_mapping is None: + return None + + target_column = target_mapping.columns_by_real.get(fk.referenced_column) + if target_column is None: + return None + + return ForeignKey( + column=local_column, + referenced_schema="public", + referenced_table=target_mapping.token, + referenced_column=target_column, + ) + + def _deanonymize_statement(self, statement: exp.Expression) -> None: + table_by_qualifier: dict[str, _TableMapping] = {} + referenced_mappings: list[_TableMapping] = [] + + for table_expr in statement.find_all(exp.Table): + mapping = self._by_token.get(table_expr.name.lower()) + if mapping is None: + continue + + referenced_mappings.append(mapping) + table_by_qualifier[mapping.token.lower()] = mapping + table_by_qualifier[mapping.table.name.lower()] = mapping + if table_expr.alias: + table_by_qualifier[table_expr.alias.lower()] = mapping + + table_expr.set("this", _identifier(mapping.table.name)) + if mapping.table.schema == "public": + table_expr.set("db", None) + else: + table_expr.set("db", _identifier(mapping.table.schema)) + + for column_expr in statement.find_all(exp.Column): + column_token = column_expr.name.lower() + table_qualifier = column_expr.table.lower() if column_expr.table else "" + + if table_qualifier: + mapping = table_by_qualifier.get(table_qualifier) + if mapping is None: + continue + real_column = mapping.columns_by_token.get(column_token) + if real_column is None: + continue + column_expr.set("this", _identifier(real_column)) + if table_qualifier == mapping.token.lower(): + column_expr.set("table", _identifier(mapping.table.name)) + if mapping.table.schema == "public": + column_expr.set("db", None) + else: + column_expr.set("db", _identifier(mapping.table.schema)) + continue + + matches = [ + (mapping, mapping.columns_by_token[column_token]) + for mapping in referenced_mappings + if column_token in mapping.columns_by_token + ] + if len(matches) != 1: + continue + _, real_column = matches[0] + column_expr.set("this", _identifier(real_column)) + + +def _identifier(name: str) -> exp.Identifier: + return exp.to_identifier(name, quoted=_needs_quotes(name)) + + +def _needs_quotes(name: str) -> bool: + return not _SAFE_IDENTIFIER_RE.match(name) or name.upper() in Tokenizer.KEYWORDS diff --git a/src/promptquery/cli.py b/src/promptquery/cli.py index 2644c15..44f8199 100644 --- a/src/promptquery/cli.py +++ b/src/promptquery/cli.py @@ -14,6 +14,7 @@ from rich.console import Console from . import __version__ +from .anonymize import SchemaAnonymizer from .db import Database from .llm import LLMClient, LLMError, extract_sql, make_client from .prompts import build_system_prompt @@ -89,6 +90,7 @@ def run_question( confirm: bool, progress: Console | None = None, prompt_for_confirm=None, + anonymizer: SchemaAnonymizer | None = None, ) -> QueryResult: """Run a single NL question end-to-end. Used by both the REPL and --query mode.""" ranked = retriever.rank(question, top_k=top_k) @@ -98,7 +100,17 @@ def run_question( if progress: progress.print(f"[dim]Selecting from {len(candidates)} candidates...[/dim]") try: - selected = llm_select_tables(question, candidates, selector_llm, max_select=select_n) + if anonymizer is None: + selected = llm_select_tables(question, candidates, selector_llm, max_select=select_n) + else: + anonymized_candidates = anonymizer.anonymize_tables(candidates) + selected_anonymized = llm_select_tables( + question, + anonymized_candidates, + selector_llm, + max_select=select_n, + ) + selected = anonymizer.real_tables_from_anonymized(selected_anonymized) if selected: candidates = selected except Exception as e: @@ -111,7 +123,8 @@ def run_question( suffix = "..." if len(relevant) > 5 else "" progress.print(f"[dim]Using {len(relevant)} tables: {preview}{suffix}[/dim]") - system_prompt = build_system_prompt(relevant) + prompt_tables = anonymizer.anonymize_tables(relevant) if anonymizer else relevant + system_prompt = build_system_prompt(prompt_tables) try: if progress: @@ -123,6 +136,8 @@ def run_question( sql = extract_sql(raw) if not sql: return QueryResult(Outcome.EMPTY_SQL, None, [], [], "LLM returned an empty response.") + if anonymizer is not None: + sql = anonymizer.deanonymize_sql(sql) try: validate_select_only(sql) @@ -199,6 +214,11 @@ def run_question( is_flag=True, help="Disable the LLM table-selector and use TF-IDF + FK expansion only (v0.1 behavior).", ) +@click.option( + "--anonymize", + is_flag=True, + help="Send opaque table/column names to LLMs, then map generated SQL back locally.", +) @click.option( "-y", "--yes", @@ -209,7 +229,7 @@ def run_question( def main(dsn: str, model: str | None, selector_model: str | None, query: str | None, out_format: str | None, top_k: int, select_n: int, max_tables: int, - no_selector: bool, yes: bool) -> None: + no_selector: bool, anonymize: bool, yes: bool) -> None: """PromptQuery — natural-language SQL for Postgres. DSN is a libpq connection string, e.g. postgresql://user:pass@host/db. @@ -266,8 +286,11 @@ def main(dsn: str, model: str | None, selector_model: str | None, if selector_llm is not None and selector_llm is not llm else (" (selector: same)" if selector_llm is not None else " (selector: off)") ) + anonymizer = SchemaAnonymizer(schema) if anonymize else None + anonymize_info = ", anonymize: on" if anonymizer is not None else "" progress.print(f"[green]✓[/green] {len(schema.tables)} tables found " - f"[dim](sql: {llm.name}/{llm.model}{selector_info})[/dim]") + f"[dim](sql: {llm.name}/{llm.model}{selector_info}" + f"{anonymize_info})[/dim]") retriever = TfIdfRetriever(schema) @@ -276,14 +299,14 @@ def main(dsn: str, model: str | None, selector_model: str | None, question=query, schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db_ctx, top_k=top_k, select_n=select_n, max_tables=max_tables, - out_fmt=out_fmt, progress=progress, + out_fmt=out_fmt, progress=progress, anonymizer=anonymizer, )) _run_repl( schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db_ctx, top_k=top_k, select_n=select_n, max_tables=max_tables, - out_fmt=out_fmt, yes=yes, progress=progress, + out_fmt=out_fmt, yes=yes, progress=progress, anonymizer=anonymizer, ) finally: db_ctx.close() @@ -291,12 +314,13 @@ def main(dsn: str, model: str | None, selector_model: str | None, def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db, top_k, select_n, max_tables, out_fmt: OutputFormat, - progress: Console) -> int: + progress: Console, + anonymizer: SchemaAnonymizer | None = None) -> int: result = run_question( question=question, schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db, top_k=top_k, select_n=select_n, max_tables=max_tables, - confirm=False, progress=progress, + confirm=False, progress=progress, anonymizer=anonymizer, ) if result.outcome is Outcome.LLM_ERROR: @@ -326,7 +350,8 @@ def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db, def _run_repl(*, schema, retriever, llm, selector_llm, db, top_k, select_n, max_tables, out_fmt: OutputFormat, - yes: bool, progress: Console) -> None: + yes: bool, progress: Console, + anonymizer: SchemaAnonymizer | None = None) -> None: session: PromptSession[str] = PromptSession(history=InMemoryHistory()) progress.print( "\n[bold]PromptQuery[/bold] — ask a question in plain English, " @@ -350,6 +375,7 @@ def _run_repl(*, schema, retriever, llm, selector_llm, db, top_k=top_k, select_n=select_n, max_tables=max_tables, confirm=not yes, progress=progress, prompt_for_confirm=(lambda: session.prompt("Run? [y/N] ")) if not yes else None, + anonymizer=anonymizer, ) if result.outcome is Outcome.LLM_ERROR: diff --git a/tests/test_anonymize.py b/tests/test_anonymize.py new file mode 100644 index 0000000..85e1bc5 --- /dev/null +++ b/tests/test_anonymize.py @@ -0,0 +1,96 @@ +import pytest + +from promptquery.anonymize import SchemaAnonymizer +from promptquery.prompts import format_schema +from promptquery.safety import UnsafeQuery, validate_select_only +from promptquery.schema import Column, ForeignKey, Schema, Table + + +def _schema() -> Schema: + users = Table( + schema="public", + name="users", + comment="customer accounts with private emails", + columns=[ + Column("id", "bigint", False, True), + Column("email", "text", True, False), + ], + ) + orders = Table( + schema="sales", + name="orders", + comment="commercial order history", + columns=[ + Column("id", "bigint", False, True), + Column("user_id", "bigint", False, False), + Column("total", "numeric", False, False), + ], + foreign_keys=[ForeignKey("user_id", "public", "users", "id")], + ) + return Schema(tables=[users, orders]) + + +def test_anonymized_schema_hides_names_and_preserves_structure(): + anonymizer = SchemaAnonymizer(_schema()) + anonymized = anonymizer.anonymize_tables(_schema().tables) + + rendered = format_schema(anonymized) + assert "TABLE table_001" in rendered + assert "TABLE table_002" in rendered + assert "column_001 bigint [PK, NOT NULL]" in rendered + assert "FK column_002 -> table_001(column_001)" in rendered + + for leaked in [ + "users", + "orders", + "email", + "user_id", + "customer accounts", + "commercial order", + ]: + assert leaked not in rendered + + +def test_deanonymize_sql_maps_qualified_aliases_and_non_public_schema(): + anonymizer = SchemaAnonymizer(_schema()) + + sql = anonymizer.deanonymize_sql( + "SELECT o.column_003 " + "FROM table_002 AS o " + "JOIN table_001 AS u ON o.column_002 = u.column_001 " + "WHERE o.column_003 > 0" + ) + + assert sql == ( + "SELECT o.total FROM sales.orders AS o " + "JOIN users AS u ON o.user_id = u.id WHERE o.total > 0" + ) + + +def test_deanonymize_sql_maps_unqualified_column_for_single_table(): + anonymizer = SchemaAnonymizer(_schema()) + + sql = anonymizer.deanonymize_sql("SELECT column_002 FROM table_001") + + assert sql == "SELECT email FROM users" + + +def test_deanonymize_preserves_multiple_statement_rejection(): + anonymizer = SchemaAnonymizer(_schema()) + + sql = anonymizer.deanonymize_sql("SELECT column_001 FROM table_001; DELETE FROM table_001") + + assert "SELECT id FROM users" in sql + assert "DELETE FROM users" in sql + with pytest.raises(UnsafeQuery): + validate_select_only(sql) + + +def test_real_tables_from_anonymized_returns_original_tables(): + schema = _schema() + anonymizer = SchemaAnonymizer(schema) + anonymized = anonymizer.anonymize_tables(schema.tables) + + restored = anonymizer.real_tables_from_anonymized([anonymized[1], anonymized[0], anonymized[1]]) + + assert restored == [schema.tables[1], schema.tables[0]] diff --git a/tests/test_cli_oneshot.py b/tests/test_cli_oneshot.py index d83b211..6b5d6f1 100644 --- a/tests/test_cli_oneshot.py +++ b/tests/test_cli_oneshot.py @@ -11,6 +11,7 @@ import pytest +from promptquery.anonymize import SchemaAnonymizer from promptquery.cli import ( OutputFormat, Outcome, @@ -100,8 +101,10 @@ class _FakeLLM: def __init__(self, response: str): self._response = response + self.calls: list[tuple[str, str]] = [] def generate(self, system: str, user: str) -> str: + self.calls.append((system, user)) return self._response @@ -201,6 +204,57 @@ def test_run_question_exec_error_propagates_as_outcome(): assert "relation users" in (result.error or "") +def test_run_question_anonymizes_prompt_and_maps_sql_back(): + schema = _example_schema() + retriever = TfIdfRetriever(schema) + anonymizer = SchemaAnonymizer(schema) + llm = _FakeLLM("```sql\nSELECT table_001.column_001 FROM table_001\n```") + db = _FakeDB(rows=[(1,)]) + + result = run_question( + question="show user ids", + schema=schema, retriever=retriever, + llm=llm, selector_llm=None, db=db, + top_k=10, select_n=15, max_tables=10, + confirm=False, anonymizer=anonymizer, + ) + + assert result.outcome is Outcome.OK + assert db.last_sql == "SELECT users.id FROM users" + + system_prompt = llm.calls[0][0] + assert "TABLE table_001" in system_prompt + assert "column_001" in system_prompt + for leaked in ["users", "orders", "email", "user_id", "end-user"]: + assert leaked not in system_prompt + + +def test_run_question_anonymizes_selector_candidates(): + schema = _example_schema() + retriever = TfIdfRetriever(schema) + anonymizer = SchemaAnonymizer(schema) + selector_llm = _FakeLLM('["table_002"]') + llm = _FakeLLM("```sql\nSELECT table_002.column_003 FROM table_002\n```") + db = _FakeDB(rows=[(42,)]) + + result = run_question( + question="anything", + schema=schema, retriever=retriever, + llm=llm, selector_llm=selector_llm, db=db, + top_k=10, select_n=1, max_tables=10, + confirm=False, anonymizer=anonymizer, + ) + + assert result.outcome is Outcome.OK + assert db.last_sql == "SELECT orders.total FROM orders" + + selector_prompt = selector_llm.calls[0][0] + assert "table_001" in selector_prompt + assert "table_002" in selector_prompt + for leaked in ["users", "orders", "email", "user_id", "end-user"]: + assert leaked not in selector_prompt + + # -- CLI surface (just argv parsing — no real DB / LLM call) -------------- def test_cli_help_lists_query_flag(): @@ -211,6 +265,7 @@ def test_cli_help_lists_query_flag(): assert result.exit_code == 0 assert "--query" in result.output assert "--out" in result.output + assert "--anonymize" in result.output assert "json" in result.output and "csv" in result.output