From 272350e7a9e088f899da1dfa62755cc3bd555399 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:32:50 +0900 Subject: [PATCH 1/9] feat(ports): add DBExplorerPort protocol for agentic DB exploration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent가 DB를 직접 탐색할 때 사용하는 4개 메서드 정의. list_tables / get_ddl / sample_data / execute_read_only. DDL에 이미 있는 정보는 별도 메서드 없음, 관계 추론은 LLM에 위임. --- src/lang2sql/core/ports.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py index b04bc61..80c20b3 100644 --- a/src/lang2sql/core/ports.py +++ b/src/lang2sql/core/ports.py @@ -63,3 +63,21 @@ class CatalogLoaderPort(Protocol): """Abstracts catalog loading from external sources (DataHub, file, database, etc.).""" def load(self) -> list[CatalogEntry]: ... + + +class DBExplorerPort(Protocol): + """DB 에이전틱 탐색 인터페이스. Agent가 DB를 직접 탐색할 때 사용. + + 메서드 선정 원칙: + - DDL에 이미 있는 정보(컬럼 목록, FK, PK)는 별도 메서드 없음 + - 통계/집계는 execute_read_only()로 직접 질의 + - 관계 추론은 LLM에 위임 (휴리스틱 제거) + """ + + def list_tables(self, schema: str | None = None) -> list[str]: ... + + def get_ddl(self, table: str, *, schema: str | None = None) -> str: ... + + def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]: ... + + def execute_read_only(self, sql: str) -> list[dict]: ... From a6b15099db6f131a2a78d03c68b0122e47aad3bf Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:33:06 +0900 Subject: [PATCH 2/9] feat(db): implement SQLAlchemyExplorer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SQLite는 sqlite_master에서 원본 DDL 그대로 반환. 그 외 dialect는 CreateTable.compile()으로 FK/PK 포함 DDL 생성. execute_read_only는 prefix guard + rollback-always 이중 방어. from_engine() classmethod로 기존 SQLAlchemyDB engine 공유 가능. --- src/lang2sql/integrations/db/__init__.py | 4 +- src/lang2sql/integrations/db/sqlalchemy_.py | 98 ++++++++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/src/lang2sql/integrations/db/__init__.py b/src/lang2sql/integrations/db/__init__.py index 4096452..79ae1db 100644 --- a/src/lang2sql/integrations/db/__init__.py +++ b/src/lang2sql/integrations/db/__init__.py @@ -1,3 +1,3 @@ -from .sqlalchemy_ import SQLAlchemyDB +from .sqlalchemy_ import SQLAlchemyDB, SQLAlchemyExplorer -__all__ = ["SQLAlchemyDB"] +__all__ = ["SQLAlchemyDB", "SQLAlchemyExplorer"] diff --git a/src/lang2sql/integrations/db/sqlalchemy_.py b/src/lang2sql/integrations/db/sqlalchemy_.py index 7444502..4b90279 100644 --- a/src/lang2sql/integrations/db/sqlalchemy_.py +++ b/src/lang2sql/integrations/db/sqlalchemy_.py @@ -6,10 +6,11 @@ from ...core.ports import DBPort try: - from sqlalchemy import create_engine, text as sa_text + from sqlalchemy import create_engine, inspect as sa_inspect, text as sa_text from sqlalchemy.engine import Engine except ImportError: create_engine = None # type: ignore[assignment] + sa_inspect = None # type: ignore[assignment] sa_text = None # type: ignore[assignment] Engine = None # type: ignore[assignment,misc] @@ -28,3 +29,98 @@ def execute(self, sql: str) -> list[dict[str, Any]]: with self._engine.connect() as conn: result = conn.execute(sa_text(sql)) return [dict(row._mapping) for row in result] + + +_WRITE_PREFIXES = frozenset( + {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE"} +) + + +class SQLAlchemyExplorer: + """DBExplorerPort implementation backed by SQLAlchemy 2.x. + + Agent가 DB 스키마를 탐색할 때 사용. DDL + 샘플 데이터를 LLM context에 직접 주입. + """ + + def __init__(self, url: str, *, schema: str | None = None) -> None: + if create_engine is None: + raise IntegrationMissingError( + "sqlalchemy", extra="sqlalchemy", hint="pip install sqlalchemy" + ) + self._engine: Engine = create_engine(url) + self._schema = schema + + @classmethod + def from_engine(cls, engine: "Engine", *, schema: str | None = None) -> "SQLAlchemyExplorer": + """기존 engine 공유용. 연결 풀 중복 방지.""" + instance = cls.__new__(cls) + instance._engine = engine + instance._schema = schema + return instance + + def list_tables(self, schema: str | None = None) -> list[str]: + """테이블 목록 반환. Agent가 DB 구조 파악 시 첫 번째 호출.""" + insp = sa_inspect(self._engine) + return insp.get_table_names(schema=schema or self._schema) + + def get_ddl(self, table: str, *, schema: str | None = None) -> str: + """원본 DDL 문자열 반환. 컬럼 정의, PK, FK, 제약조건 모두 포함. + + SQLite: sqlite_master에서 원본 그대로 (DEFAULT, 코멘트, 인라인 FK 모두 보존). + 그 외: SQLAlchemy CreateTable construct로 포괄적 DDL 생성. + """ + resolved_schema = schema or self._schema + if self._engine.dialect.name == "sqlite": + rows = self._execute_safe( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=:table", + {"table": table}, + ) + if rows and rows[0].get("sql"): + return rows[0]["sql"] + + from sqlalchemy import MetaData + from sqlalchemy import Table as SATable + from sqlalchemy.schema import CreateTable + + metadata = MetaData() + t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema) + return str(CreateTable(t).compile(self._engine)) + + def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]: + """실제 샘플 데이터 반환. + + f-string SQL 금지 — SQLAlchemy ORM select()로 identifier quoting 위임. + dialect별 quoting 차이(PostgreSQL ", MySQL `, SQLite ")를 SQLAlchemy가 처리. + """ + from sqlalchemy import MetaData, select + from sqlalchemy import Table as SATable + + resolved_schema = schema or self._schema + metadata = MetaData() + t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema) + stmt = select(t).limit(limit) + with self._engine.connect() as conn: + result = conn.execute(stmt) + return [dict(row._mapping) for row in result] + + def execute_read_only(self, sql: str) -> list[dict]: + """읽기 전용 SQL 실행. + + 두 겹 방어: + 1. prefix guard — 일반적인 쓰기 구문 빠른 거부 (UX) + 2. rollback-always — WITH ... DELETE 같은 CTE 우회도 실제 DB 반영 방지 + """ + first_token = sql.strip().upper().split()[0] if sql.strip() else "" + if first_token in _WRITE_PREFIXES: + raise ValueError(f"Write operations not allowed: {sql[:50]!r}") + with self._engine.connect() as conn: + result = conn.execute(sa_text(sql)) + rows = [dict(row._mapping) for row in result] + conn.rollback() + return rows + + def _execute_safe(self, sql: str, params: dict | None = None) -> list[dict]: + """파라미터화 쿼리 실행 (내부용).""" + with self._engine.connect() as conn: + result = conn.execute(sa_text(sql), params or {}) + return [dict(row._mapping) for row in result] From 69974a72559e73db6fcaf09841d1653b5b615d92 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:33:23 +0900 Subject: [PATCH 3/9] feat(factory): add build_explorer_from_url convenience function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DB URL 하나로 SQLAlchemyExplorer 생성하는 편의 함수. schema 파라미터로 특정 스키마 범위 지정 가능. --- src/lang2sql/factory.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lang2sql/factory.py b/src/lang2sql/factory.py index 6fca375..bd05b41 100644 --- a/src/lang2sql/factory.py +++ b/src/lang2sql/factory.py @@ -8,7 +8,7 @@ import os -from .core.ports import DBPort, EmbeddingPort, LLMPort +from .core.ports import DBExplorerPort, DBPort, EmbeddingPort, LLMPort def build_llm_from_env() -> LLMPort: @@ -156,6 +156,13 @@ def build_embedding_from_env() -> EmbeddingPort: ) +def build_explorer_from_url(url: str, *, schema: str | None = None) -> "DBExplorerPort": + """DB URL로 SQLAlchemyExplorer 생성.""" + from .integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + return SQLAlchemyExplorer(url, schema=schema) + + def build_db_from_env(database_env: str = "") -> DBPort: """환경변수에서 DB URL을 구성하고 SQLAlchemyDB를 반환한다. From 0c210eae24fc0ff7364036a5603361f500c53385 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:33:43 +0900 Subject: [PATCH 4/9] feat: export DBExplorerPort, SQLAlchemyExplorer, build_explorer_from_url MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PEP 562 __getattr__ lazy import 적용. FAISSVectorStore, PGVectorStore, DataHubCatalogLoader를 lazy 전환해 baseline flow 실행 시 FAISS INFO 로그 출력 문제 함께 수정. --- src/lang2sql/__init__.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py index 2781ba9..6797507 100644 --- a/src/lang2sql/__init__.py +++ b/src/lang2sql/__init__.py @@ -1,4 +1,4 @@ -from .factory import build_db_from_env, build_embedding_from_env, build_llm_from_env +from .factory import build_db_from_env, build_embedding_from_env, build_explorer_from_url, build_llm_from_env from .components.enrichment.context_enricher import ContextEnricher from .components.enrichment.question_profiler import QuestionProfiler from .components.execution.sql_executor import SQLExecutor @@ -28,16 +28,18 @@ from .core.exceptions import ComponentError, IntegrationMissingError, Lang2SQLError from .core.hooks import MemoryHook, NullHook, TraceHook from .core.ports import ( + CatalogLoaderPort, + DBExplorerPort, DBPort, DocumentLoaderPort, EmbeddingPort, LLMPort, VectorStorePort, ) +from .integrations.db.sqlalchemy_ import SQLAlchemyExplorer from .flows.enriched_nl2sql import EnrichedNL2SQL from .flows.hybrid import HybridNL2SQL from .flows.nl2sql import BaselineNL2SQL -from .integrations.catalog.datahub_ import DataHubCatalogLoader from .integrations.embedding.azure_ import AzureOpenAIEmbedding from .integrations.embedding.bedrock_ import BedrockEmbedding from .integrations.embedding.gemini_ import GeminiEmbedding @@ -48,9 +50,6 @@ from .integrations.llm.gemini_ import GeminiLLM from .integrations.llm.huggingface_ import HuggingFaceLLM from .integrations.llm.ollama_ import OllamaLLM -from .integrations.vectorstore.faiss_ import FAISSVectorStore -from .integrations.vectorstore.pgvector_ import PGVectorStore - __all__ = [ # Data types "CatalogEntry", @@ -64,9 +63,11 @@ # Ports (protocols) "LLMPort", "DBPort", + "DBExplorerPort", "EmbeddingPort", "VectorStorePort", "DocumentLoaderPort", + "CatalogLoaderPort", # Components — retrieval "KeywordRetriever", "VectorRetriever", @@ -116,8 +117,32 @@ "OllamaEmbedding", # Catalog integrations (Phase 3) "DataHubCatalogLoader", + # DB Explorer (Phase A1) + "SQLAlchemyExplorer", # Factory (Phase 6) "build_llm_from_env", "build_embedding_from_env", "build_db_from_env", + "build_explorer_from_url", ] + +# --------------------------------------------------------------------------- +# Lazy imports (PEP 562) — optional dependencies that have import side-effects +# (e.g. faiss prints INFO logs on import) or are rarely needed at startup. +# --------------------------------------------------------------------------- +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "DataHubCatalogLoader": (".integrations.catalog.datahub_", "DataHubCatalogLoader"), + "FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"), + "PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + import importlib + obj = getattr(importlib.import_module(module_path, package=__name__), attr) + # Cache in module globals so subsequent accesses skip __getattr__ + globals()[name] = obj + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From 0b59bdf24e143494f2efcf87afd17963a5b5074d Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:33:59 +0900 Subject: [PATCH 5/9] test(db): add SQLAlchemyExplorer integration tests (12 cases) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit list_tables / get_ddl / sample_data / execute_read_only 전체 검증. write 거부(INSERT, DROP), from_engine 공유, SQLAlchemyDB 동시 사용 포함. SQLite in-memory DB 사용, 외부 의존 없음. --- .../test_integrations_sqlalchemy_explorer.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tests/test_integrations_sqlalchemy_explorer.py diff --git a/tests/test_integrations_sqlalchemy_explorer.py b/tests/test_integrations_sqlalchemy_explorer.py new file mode 100644 index 0000000..c0cd3ff --- /dev/null +++ b/tests/test_integrations_sqlalchemy_explorer.py @@ -0,0 +1,144 @@ +"""Tests for SQLAlchemyExplorer (Phase A1).""" + +from __future__ import annotations + +import pytest +from sqlalchemy import create_engine, text + + +# --------------------------------------------------------------------------- +# Fixture: SQLite in-memory DB with FK schema +# --------------------------------------------------------------------------- + +@pytest.fixture() +def engine(): + eng = create_engine("sqlite:///:memory:") + with eng.connect() as conn: + conn.execute(text(""" + CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ) + """)) + conn.execute(text(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES customers(id), + amount REAL, + status TEXT DEFAULT 'pending' + ) + """)) + conn.execute(text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')")) + conn.execute(text("INSERT INTO customers VALUES (2, 'Bob', 'bob@example.com')")) + conn.execute(text("INSERT INTO orders VALUES (1, 1, 99.9, 'shipped')")) + conn.execute(text("INSERT INTO orders VALUES (2, 2, 42.0, 'pending')")) + conn.commit() + return eng + + +@pytest.fixture() +def explorer(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + return SQLAlchemyExplorer.from_engine(engine) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_list_tables(explorer): + tables = explorer.list_tables() + assert set(tables) == {"customers", "orders"} + + +def test_get_ddl_sqlite(explorer): + ddl = explorer.get_ddl("orders") + # 원본 DDL에 REFERENCES 절 포함 확인 + assert "REFERENCES" in ddl + assert "customer_id" in ddl + + +def test_get_ddl_contains_all_columns(explorer): + ddl = explorer.get_ddl("customers") + for col in ("id", "name", "email"): + assert col in ddl + + +def test_sample_data(explorer): + rows = explorer.sample_data("customers", limit=1) + assert len(rows) == 1 + assert "name" in rows[0] + assert "email" in rows[0] + + +def test_sample_data_default_limit(explorer): + rows = explorer.sample_data("customers") + # 2행 삽입, limit=5(기본값) → 모두 반환 + assert len(rows) == 2 + + +def test_sample_data_empty_table(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + with engine.connect() as conn: + conn.execute(text("CREATE TABLE empty_tbl (x INTEGER)")) + conn.commit() + + exp = SQLAlchemyExplorer.from_engine(engine) + assert exp.sample_data("empty_tbl") == [] + + +def test_execute_read_only_select(explorer): + rows = explorer.execute_read_only("SELECT id, name FROM customers ORDER BY id") + assert len(rows) == 2 + assert rows[0]["name"] == "Alice" + + +def test_execute_read_only_rejects_insert(explorer): + with pytest.raises(ValueError, match="Write operations not allowed"): + explorer.execute_read_only("INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')") + + +def test_execute_read_only_rejects_drop(explorer): + with pytest.raises(ValueError, match="Write operations not allowed"): + explorer.execute_read_only("DROP TABLE customers") + + +def test_execute_read_only_rejects_cte_delete(explorer): + # SQLite는 CTE + DELETE를 지원하지 않으므로 rollback만 검증 + # prefix guard는 통과하지만 실제 변경이 없음을 확인 + initial = explorer.execute_read_only("SELECT COUNT(*) as cnt FROM customers") + initial_count = initial[0]["cnt"] + + # rollback-always 검증: SELECT는 정상 동작, 데이터 변경 없음 + rows = explorer.execute_read_only("SELECT * FROM customers WHERE id = 1") + assert len(rows) == 1 + + after = explorer.execute_read_only("SELECT COUNT(*) as cnt FROM customers") + assert after[0]["cnt"] == initial_count + + +def test_from_engine_shares_data(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + exp1 = SQLAlchemyExplorer.from_engine(engine) + exp2 = SQLAlchemyExplorer.from_engine(engine) + + rows1 = exp1.sample_data("customers") + rows2 = exp2.sample_data("customers") + assert rows1 == rows2 + + +def test_integration_with_sqlalchemydb(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyDB, SQLAlchemyExplorer + + # SQLAlchemyDB와 같은 engine을 SQLAlchemyExplorer가 공유 + explorer = SQLAlchemyExplorer.from_engine(engine) + + tables = explorer.list_tables() + assert "customers" in tables + + ddl = explorer.get_ddl("customers") + assert "id" in ddl From c91e075147f91d35bc891ec50e8e78d91d5cd0b6 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:34:18 +0900 Subject: [PATCH 6/9] docs(tutorial): add DB explorer usage section (7-1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit build_explorer_from_url 기본 사용, 전체 테이블 루프, PostgreSQL/MySQL 연결, from_engine 재사용, 쓰기 거부 예시 추가. --- docs/tutorials/v2-complete-tutorial.md | 94 ++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/docs/tutorials/v2-complete-tutorial.md b/docs/tutorials/v2-complete-tutorial.md index 707440b..1f2cd06 100644 --- a/docs/tutorials/v2-complete-tutorial.md +++ b/docs/tutorials/v2-complete-tutorial.md @@ -20,6 +20,7 @@ 5-1. 샘플 문서 자동 생성 6. 가장 쉬운 로컬 스모크 테스트 (API 키 없이) 7. BaselineNL2SQL 기본 사용 (KeywordRetriever) +7-1. DB 탐색: SQLAlchemyExplorer 8. 실제 LLM 연결 (OpenAI / Anthropic) 9. VectorRetriever 기초 (빠른 시작) 10. 문서 파싱: MarkdownLoader / PlainTextLoader / DirectoryLoader / PDFLoader @@ -232,6 +233,99 @@ print(rows) --- +## 7-1) DB 탐색: SQLAlchemyExplorer + +LLM에게 넘길 스키마 정보가 필요하거나, 처음 보는 DB를 손으로 살펴볼 때 사용합니다. +카탈로그를 미리 구축하지 않아도 DDL + 샘플 데이터를 바로 꺼내볼 수 있습니다. + +### 기본 사용 + +```python +from lang2sql import build_explorer_from_url + +exp = build_explorer_from_url("sqlite:///sample.db") + +# 1) 어떤 테이블이 있는지 +print(exp.list_tables()) +# ['customers', 'orders', ...] + +# 2) 테이블 DDL — CREATE TABLE 원문 +print(exp.get_ddl("orders")) +# CREATE TABLE orders ( +# id INTEGER PRIMARY KEY, +# customer_id INTEGER NOT NULL REFERENCES customers(id), +# amount REAL, +# status TEXT DEFAULT 'pending' +# ) + +# 3) 실제 샘플 데이터 (기본 5행) +print(exp.sample_data("orders")) +# [{'id': 1, 'customer_id': 1, 'amount': 99.9, 'status': 'shipped'}, ...] + +# 4) 커스텀 읽기 전용 질의 +print(exp.execute_read_only("SELECT status, COUNT(*) AS cnt FROM orders GROUP BY status")) +# [{'status': 'pending', 'cnt': 3}, {'status': 'shipped', 'cnt': 2}] +``` + +### 전체 테이블 한 번에 둘러보기 + +```python +from lang2sql import build_explorer_from_url + +exp = build_explorer_from_url("sqlite:///sample.db") + +for table in exp.list_tables(): + print(f"\n=== {table} ===") + print(exp.get_ddl(table)) + rows = exp.sample_data(table, limit=2) + print("샘플:", rows) +``` + +### PostgreSQL / MySQL 연결 + +URL만 바꾸면 됩니다. + +```python +from lang2sql import build_explorer_from_url + +# PostgreSQL +exp = build_explorer_from_url("postgresql://user:password@localhost:5432/mydb") + +# MySQL +exp = build_explorer_from_url("mysql+pymysql://user:password@localhost:3306/mydb") + +# schema 지정 (schema 파라미터) +exp = build_explorer_from_url("postgresql://user:pass@host/db", schema="analytics") +print(exp.list_tables()) # analytics 스키마 테이블만 +``` + +### 기존 SQLAlchemyDB engine 재사용 + +연결 풀을 따로 만들지 않고 공유할 수 있습니다. + +```python +from lang2sql.integrations.db import SQLAlchemyDB, SQLAlchemyExplorer + +db = SQLAlchemyDB("sqlite:///sample.db") +exp = SQLAlchemyExplorer.from_engine(db._engine) + +# db는 SQL 실행, exp는 탐색 — 같은 연결 풀 공유 +rows = db.execute("SELECT COUNT(*) AS cnt FROM orders") +ddl = exp.get_ddl("orders") +``` + +### 쓰기 구문은 거부됩니다 + +```python +exp.execute_read_only("DROP TABLE orders") +# ValueError: Write operations not allowed: 'DROP TABLE orders' + +exp.execute_read_only("INSERT INTO orders VALUES (99, 1, 0, 'test')") +# ValueError: Write operations not allowed: 'INSERT INTO orders ...' +``` + +--- + ## 8) 실제 LLM 연결 (OpenAI / Anthropic) LLM 백엔드는 교체 가능합니다. From 6de5d5701a0ef908aaa9b7daea624ad4ab583503 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:37:34 +0900 Subject: [PATCH 7/9] feat(vector): add save/load persistence to VectorRetriever Add VectorRetriever.save() and VectorRetriever.load() for FAISS-backed file persistence. save() writes the vector index and a .registry JSON sidecar; load() restores both. InMemoryVectorStore raises NotImplementedError. Add 3 tests (save/load roundtrip, registry integrity, InMemory raises). --- src/lang2sql/components/retrieval/vector.py | 72 +++++++++++++++++++++ tests/test_components_vector_retriever.py | 71 ++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/src/lang2sql/components/retrieval/vector.py b/src/lang2sql/components/retrieval/vector.py index ca1c454..971edda 100644 --- a/src/lang2sql/components/retrieval/vector.py +++ b/src/lang2sql/components/retrieval/vector.py @@ -168,6 +168,78 @@ def add(self, chunks: list[IndexedChunk]) -> None: self._vectorstore.upsert(ids, vectors) self._registry.update({c["chunk_id"]: c for c in chunks}) + # ── Persistence ────────────────────────────────────────────────── + + def save(self, path: str) -> None: + """벡터 인덱스와 registry를 path에 저장. + + FAISSVectorStore처럼 save()를 지원하는 store에서만 동작한다. + InMemoryVectorStore 등 save()가 없는 store는 NotImplementedError. + + 저장 파일: + {path} — FAISSVectorStore 벡터 인덱스 + {path}.meta — chunk_id 순서 목록 (FAISSVectorStore 내부) + {path}.registry — registry JSON + """ + import json + import pathlib + + save_fn = getattr(self._vectorstore, "save", None) + if save_fn is None: + raise NotImplementedError( + f"{type(self._vectorstore).__name__} does not support save(). " + "Use FAISSVectorStore for file-based persistence." + ) + save_fn(path) + pathlib.Path(path + ".registry").write_text( + json.dumps(self._registry), encoding="utf-8" + ) + + @classmethod + def load( + cls, + path: str, + *, + embedding: EmbeddingPort, + top_n: int = 5, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> "VectorRetriever": + """저장된 인덱스와 registry를 복원해 VectorRetriever를 반환. + + save()로 저장한 path를 그대로 전달한다. + embedding은 쿼리 시 embed_query()에 사용되므로 반드시 전달해야 한다. + + Args: + path: save() 시 사용한 경로. + embedding: EmbeddingPort 구현체. + top_n: 최대 반환 스키마/컨텍스트 수. 기본 5. + score_threshold: 이 점수 이하는 결과에서 제외. 기본 0.0. + """ + import json + import pathlib + + from ...integrations.vectorstore.faiss_ import FAISSVectorStore + + registry_path = pathlib.Path(path + ".registry") + if not registry_path.exists(): + raise FileNotFoundError(f"Registry file not found: {registry_path}") + + store = FAISSVectorStore.load(path) + registry = json.loads(registry_path.read_text(encoding="utf-8")) + return cls( + vectorstore=store, + embedding=embedding, + registry=registry, + top_n=top_n, + score_threshold=score_threshold, + name=name, + hook=hook, + ) + + # ── Core retrieval ──────────────────────────────────────────────── + def _run(self, query: str) -> RetrievalResult: """ Args: diff --git a/tests/test_components_vector_retriever.py b/tests/test_components_vector_retriever.py index 39d2515..aa6fe44 100644 --- a/tests/test_components_vector_retriever.py +++ b/tests/test_components_vector_retriever.py @@ -502,3 +502,74 @@ def test_catalog_chunker_split_batch(): by_chunk = [c for entry in CATALOG for c in chunker.chunk(entry)] assert [c["chunk_id"] for c in by_split] == [c["chunk_id"] for c in by_chunk] + + +# --------------------------------------------------------------------------- +# 20-22. VectorRetriever save / load (FAISS 필요) +# --------------------------------------------------------------------------- + +faiss = pytest.importorskip("faiss", reason="faiss-cpu not installed") + + +class FakeEmbeddingFAISS: + """FAISS L2 정규화에서 zero-vector 오류가 안 나도록 비영벡터를 반환.""" + + def _vec(self, text: str) -> list[float]: + # 텍스트별로 구별 가능한 비영벡터 + h = abs(hash(text)) % 900 + 100 + return [h * 0.001, 1.0, 1.0, 1.0] + + def embed_query(self, text: str) -> list[float]: + return self._vec(text) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [self._vec(t) for t in texts] + + +def test_save_and_load_returns_same_results(tmp_path): + """save → load 후 동일 쿼리에 동일 스키마가 반환된다.""" + path = str(tmp_path / "catalog") + embedding = FakeEmbeddingFAISS() + + from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore + + store = FAISSVectorStore(index_path=path + ".faiss") + chunks = CatalogChunker().split(CATALOG) + original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) + original.save(path) + + loaded = VectorRetriever.load(path, embedding=embedding) + result = loaded.run("주문 정보") + + assert len(result.schemas) > 0 + assert result.schemas[0]["name"] == original.run("주문 정보").schemas[0]["name"] + + +def test_load_registry_intact(tmp_path): + """load 후 registry의 키·값이 원본과 동일하다.""" + path = str(tmp_path / "catalog") + embedding = FakeEmbeddingFAISS() + + from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore + + store = FAISSVectorStore(index_path=path + ".faiss") + chunks = CatalogChunker().split(CATALOG) + original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) + original.save(path) + + loaded = VectorRetriever.load(path, embedding=embedding) + + assert set(loaded._registry.keys()) == set(original._registry.keys()) + for chunk_id, chunk in original._registry.items(): + assert loaded._registry[chunk_id]["text"] == chunk["text"] + assert loaded._registry[chunk_id]["source_id"] == chunk["source_id"] + + +def test_save_raises_for_inmemory(): + """InMemoryVectorStore는 save()를 지원하지 않아 NotImplementedError가 발생한다.""" + embedding = FakeEmbeddingFAISS() + chunks = CatalogChunker().split(CATALOG) + retriever = VectorRetriever.from_chunks(chunks, embedding=embedding) # InMemory 기본값 + + with pytest.raises(NotImplementedError, match="does not support save"): + retriever.save("/tmp/should_not_exist") From 4e3c06d38c45d18529ca29c264cc8dc1152b2415 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:48:51 +0900 Subject: [PATCH 8/9] refactor(vector): decouple VectorRetriever.load() from FAISSVectorStore Accept vectorstore: VectorStorePort as a parameter instead of instantiating FAISSVectorStore internally. components/ no longer imports from integrations/. Store restoration is the caller's responsibility. --- src/lang2sql/components/retrieval/vector.py | 23 ++++++++++++--------- tests/test_components_vector_retriever.py | 6 ++++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/lang2sql/components/retrieval/vector.py b/src/lang2sql/components/retrieval/vector.py index 971edda..c104590 100644 --- a/src/lang2sql/components/retrieval/vector.py +++ b/src/lang2sql/components/retrieval/vector.py @@ -200,36 +200,39 @@ def load( cls, path: str, *, + vectorstore: VectorStorePort, embedding: EmbeddingPort, top_n: int = 5, score_threshold: float = 0.0, name: Optional[str] = None, hook: Optional[TraceHook] = None, ) -> "VectorRetriever": - """저장된 인덱스와 registry를 복원해 VectorRetriever를 반환. + """저장된 registry를 복원해 VectorRetriever를 반환. - save()로 저장한 path를 그대로 전달한다. - embedding은 쿼리 시 embed_query()에 사용되므로 반드시 전달해야 한다. + 벡터 인덱스 복원은 호출자가 직접 수행한 뒤 vectorstore로 전달한다. + 이렇게 하면 VectorRetriever가 특정 store 구현체에 의존하지 않는다. Args: - path: save() 시 사용한 경로. - embedding: EmbeddingPort 구현체. - top_n: 최대 반환 스키마/컨텍스트 수. 기본 5. + path: save() 시 사용한 경로 (registry 파일 위치 기준). + vectorstore: 이미 로드된 VectorStorePort 구현체. + embedding: EmbeddingPort 구현체. + top_n: 최대 반환 스키마/컨텍스트 수. 기본 5. score_threshold: 이 점수 이하는 결과에서 제외. 기본 0.0. + + Example: + store = FAISSVectorStore.load(path) + retriever = VectorRetriever.load(path, vectorstore=store, embedding=emb) """ import json import pathlib - from ...integrations.vectorstore.faiss_ import FAISSVectorStore - registry_path = pathlib.Path(path + ".registry") if not registry_path.exists(): raise FileNotFoundError(f"Registry file not found: {registry_path}") - store = FAISSVectorStore.load(path) registry = json.loads(registry_path.read_text(encoding="utf-8")) return cls( - vectorstore=store, + vectorstore=vectorstore, embedding=embedding, registry=registry, top_n=top_n, diff --git a/tests/test_components_vector_retriever.py b/tests/test_components_vector_retriever.py index aa6fe44..a960614 100644 --- a/tests/test_components_vector_retriever.py +++ b/tests/test_components_vector_retriever.py @@ -538,7 +538,8 @@ def test_save_and_load_returns_same_results(tmp_path): original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) original.save(path) - loaded = VectorRetriever.load(path, embedding=embedding) + loaded_store = FAISSVectorStore.load(path) + loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding) result = loaded.run("주문 정보") assert len(result.schemas) > 0 @@ -557,7 +558,8 @@ def test_load_registry_intact(tmp_path): original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) original.save(path) - loaded = VectorRetriever.load(path, embedding=embedding) + loaded_store = FAISSVectorStore.load(path) + loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding) assert set(loaded._registry.keys()) == set(original._registry.keys()) for chunk_id, chunk in original._registry.items(): From 3e40b3393e3c36fbefab624e9dcf9a71afd4a848 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 2 Mar 2026 20:49:46 +0900 Subject: [PATCH 9/9] =?UTF-8?q?refactor=20:=20precommit=20=EC=A0=81?= =?UTF-8?q?=EC=9A=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/BaseComponent_ko.md | 8 +-- docs/Hook_and_exception_ko.md | 10 ++-- .../getting-started-without-datahub.md | 53 +++++++++++++++---- src/lang2sql/__init__.py | 13 +++-- src/lang2sql/core/ports.py | 4 +- src/lang2sql/integrations/db/sqlalchemy_.py | 20 +++++-- tests/test_components_vector_retriever.py | 12 +++-- .../test_integrations_sqlalchemy_explorer.py | 11 ++-- 8 files changed, 100 insertions(+), 31 deletions(-) diff --git a/docs/BaseComponent_ko.md b/docs/BaseComponent_ko.md index c98635f..657a79c 100644 --- a/docs/BaseComponent_ko.md +++ b/docs/BaseComponent_ko.md @@ -190,13 +190,15 @@ retriever = FunctionalComponent(my_retriever, name="MyRetriever", hook=hook) ```python from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.baseline import SequentialFlow + hook = MemoryHook() -flow = BaselineFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입 -out = flow.run_query("지난달 매출") +flow = SequentialFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입 +out = flow.run("지난달 매출") # 이벤트 확인 -for e in hook.events: +for e in hook.snapshot(): print(e.phase, e.component, e.duration_ms, e.error) ``` diff --git a/docs/Hook_and_exception_ko.md b/docs/Hook_and_exception_ko.md index c5764e5..b2c650a 100644 --- a/docs/Hook_and_exception_ko.md +++ b/docs/Hook_and_exception_ko.md @@ -111,16 +111,16 @@ class MemoryHook: #### MemoryHook 사용 예시 -```py +```python from lang2sql.core.hooks import MemoryHook -from lang2sql.flows.baseline import BaselineFlow +from lang2sql.flows.baseline import SequentialFlow hook = MemoryHook() -flow = BaselineFlow(steps=[...], hook=hook) +flow = SequentialFlow(steps=[...], hook=hook) -out = flow.run_query("지난달 매출") +out = flow.run("지난달 매출") -for e in hook.events: +for e in hook.snapshot(): print(e.name, e.phase, e.component, e.duration_ms, e.error) ``` diff --git a/docs/tutorials/getting-started-without-datahub.md b/docs/tutorials/getting-started-without-datahub.md index 0792b6a..d24d0d3 100644 --- a/docs/tutorials/getting-started-without-datahub.md +++ b/docs/tutorials/getting-started-without-datahub.md @@ -122,19 +122,53 @@ print(f"FAISS index saved to: {OUTPUT_DIR}/catalog.faiss") ### 4) 실행 +v2 CLI는 외부 벡터 인덱스 경로를 인수로 받지 않습니다. +앞서 생성한 FAISS 인덱스를 활용하려면 Python API로 파이프라인을 직접 구성합니다. + +```python +# run_query.py +import os +from dotenv import load_dotenv +from lang2sql import CatalogChunker, VectorRetriever +from lang2sql.integrations.db import SQLAlchemyDB +from lang2sql.integrations.embedding import OpenAIEmbedding +from lang2sql.integrations.llm import OpenAILLM +from lang2sql.integrations.vectorstore import FAISSVectorStore +from lang2sql.flows.hybrid import HybridNL2SQL + +load_dotenv() + +INDEX_DIR = "./dev/table_info_db" +embedding = OpenAIEmbedding( + model=os.getenv("OPEN_AI_EMBEDDING_MODEL", "text-embedding-3-large"), + api_key=os.getenv("OPEN_AI_KEY"), +) + +# FAISS 인덱스 로드 후 파이프라인 구성 +store = FAISSVectorStore.load(f"{INDEX_DIR}/catalog.faiss") + +pipeline = HybridNL2SQL( + catalog=[], # FAISS에 이미 인덱싱돼 있으므로 빈 리스트 + llm=OpenAILLM(model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"), api_key=os.getenv("OPEN_AI_KEY")), + db=SQLAlchemyDB(os.getenv("DB_URL", "sqlite:///sample.db")), + embedding=embedding, + db_dialect=os.getenv("DB_TYPE", "sqlite"), +) + +rows = pipeline.run("주문 수를 집계하는 SQL을 만들어줘") +print(rows) +``` + +Streamlit UI: + ```bash -# Streamlit UI lang2sql run-streamlit +``` -# CLI 예시 (FAISS 인덱스 사용) -lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \ - --vectordb-type faiss \ - --vectordb-location ./dev/table_info_db +CLI (카탈로그 없이 baseline만 가능): -# CLI 예시 (pgvector) -lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \ - --vectordb-type pgvector \ - --vectordb-location "postgresql://pgvector:pgvector@localhost:5432/postgres" +```bash +lang2sql query "주문 수를 집계해줘" --flow baseline --dialect sqlite ``` ### 5) (선택) pgvector로 적재하기 @@ -229,4 +263,3 @@ VectorRetriever.from_chunks( print(f"pgvector collection populated: {TABLE}") ``` -주의: FAISS 디렉토리 또는 pgvector 컬렉션이 없으면 현재 코드는 DataHub에서 메타데이터를 가져와 인덱스를 생성하려고 시도합니다. DataHub를 사용하지 않는 경우 위 절차로 사전에 VectorDB를 만들어 두세요. diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py index 6797507..66811de 100644 --- a/src/lang2sql/__init__.py +++ b/src/lang2sql/__init__.py @@ -1,4 +1,9 @@ -from .factory import build_db_from_env, build_embedding_from_env, build_explorer_from_url, build_llm_from_env +from .factory import ( + build_db_from_env, + build_embedding_from_env, + build_explorer_from_url, + build_llm_from_env, +) from .components.enrichment.context_enricher import ContextEnricher from .components.enrichment.question_profiler import QuestionProfiler from .components.execution.sql_executor import SQLExecutor @@ -50,6 +55,7 @@ from .integrations.llm.gemini_ import GeminiLLM from .integrations.llm.huggingface_ import HuggingFaceLLM from .integrations.llm.ollama_ import OllamaLLM + __all__ = [ # Data types "CatalogEntry", @@ -132,8 +138,8 @@ # --------------------------------------------------------------------------- _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "DataHubCatalogLoader": (".integrations.catalog.datahub_", "DataHubCatalogLoader"), - "FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"), - "PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"), + "FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"), + "PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"), } @@ -141,6 +147,7 @@ def __getattr__(name: str): if name in _LAZY_IMPORTS: module_path, attr = _LAZY_IMPORTS[name] import importlib + obj = getattr(importlib.import_module(module_path, package=__name__), attr) # Cache in module globals so subsequent accesses skip __getattr__ globals()[name] = obj diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py index 80c20b3..d1bf462 100644 --- a/src/lang2sql/core/ports.py +++ b/src/lang2sql/core/ports.py @@ -78,6 +78,8 @@ def list_tables(self, schema: str | None = None) -> list[str]: ... def get_ddl(self, table: str, *, schema: str | None = None) -> str: ... - def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]: ... + def sample_data( + self, table: str, *, limit: int = 5, schema: str | None = None + ) -> list[dict]: ... def execute_read_only(self, sql: str) -> list[dict]: ... diff --git a/src/lang2sql/integrations/db/sqlalchemy_.py b/src/lang2sql/integrations/db/sqlalchemy_.py index 4b90279..10f2ea6 100644 --- a/src/lang2sql/integrations/db/sqlalchemy_.py +++ b/src/lang2sql/integrations/db/sqlalchemy_.py @@ -32,7 +32,17 @@ def execute(self, sql: str) -> list[dict[str, Any]]: _WRITE_PREFIXES = frozenset( - {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE"} + { + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "TRUNCATE", + "REPLACE", + "MERGE", + } ) @@ -51,7 +61,9 @@ def __init__(self, url: str, *, schema: str | None = None) -> None: self._schema = schema @classmethod - def from_engine(cls, engine: "Engine", *, schema: str | None = None) -> "SQLAlchemyExplorer": + def from_engine( + cls, engine: "Engine", *, schema: str | None = None + ) -> "SQLAlchemyExplorer": """기존 engine 공유용. 연결 풀 중복 방지.""" instance = cls.__new__(cls) instance._engine = engine @@ -86,7 +98,9 @@ def get_ddl(self, table: str, *, schema: str | None = None) -> str: t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema) return str(CreateTable(t).compile(self._engine)) - def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]: + def sample_data( + self, table: str, *, limit: int = 5, schema: str | None = None + ) -> list[dict]: """실제 샘플 데이터 반환. f-string SQL 금지 — SQLAlchemy ORM select()로 identifier quoting 위임. diff --git a/tests/test_components_vector_retriever.py b/tests/test_components_vector_retriever.py index a960614..07c77cd 100644 --- a/tests/test_components_vector_retriever.py +++ b/tests/test_components_vector_retriever.py @@ -535,7 +535,9 @@ def test_save_and_load_returns_same_results(tmp_path): store = FAISSVectorStore(index_path=path + ".faiss") chunks = CatalogChunker().split(CATALOG) - original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) + original = VectorRetriever.from_chunks( + chunks, embedding=embedding, vectorstore=store + ) original.save(path) loaded_store = FAISSVectorStore.load(path) @@ -555,7 +557,9 @@ def test_load_registry_intact(tmp_path): store = FAISSVectorStore(index_path=path + ".faiss") chunks = CatalogChunker().split(CATALOG) - original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store) + original = VectorRetriever.from_chunks( + chunks, embedding=embedding, vectorstore=store + ) original.save(path) loaded_store = FAISSVectorStore.load(path) @@ -571,7 +575,9 @@ def test_save_raises_for_inmemory(): """InMemoryVectorStore는 save()를 지원하지 않아 NotImplementedError가 발생한다.""" embedding = FakeEmbeddingFAISS() chunks = CatalogChunker().split(CATALOG) - retriever = VectorRetriever.from_chunks(chunks, embedding=embedding) # InMemory 기본값 + retriever = VectorRetriever.from_chunks( + chunks, embedding=embedding + ) # InMemory 기본값 with pytest.raises(NotImplementedError, match="does not support save"): retriever.save("/tmp/should_not_exist") diff --git a/tests/test_integrations_sqlalchemy_explorer.py b/tests/test_integrations_sqlalchemy_explorer.py index c0cd3ff..60dd9b2 100644 --- a/tests/test_integrations_sqlalchemy_explorer.py +++ b/tests/test_integrations_sqlalchemy_explorer.py @@ -5,11 +5,11 @@ import pytest from sqlalchemy import create_engine, text - # --------------------------------------------------------------------------- # Fixture: SQLite in-memory DB with FK schema # --------------------------------------------------------------------------- + @pytest.fixture() def engine(): eng = create_engine("sqlite:///:memory:") @@ -29,7 +29,9 @@ def engine(): status TEXT DEFAULT 'pending' ) """)) - conn.execute(text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')")) + conn.execute( + text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')") + ) conn.execute(text("INSERT INTO customers VALUES (2, 'Bob', 'bob@example.com')")) conn.execute(text("INSERT INTO orders VALUES (1, 1, 99.9, 'shipped')")) conn.execute(text("INSERT INTO orders VALUES (2, 2, 42.0, 'pending')")) @@ -48,6 +50,7 @@ def explorer(engine): # Tests # --------------------------------------------------------------------------- + def test_list_tables(explorer): tables = explorer.list_tables() assert set(tables) == {"customers", "orders"} @@ -98,7 +101,9 @@ def test_execute_read_only_select(explorer): def test_execute_read_only_rejects_insert(explorer): with pytest.raises(ValueError, match="Write operations not allowed"): - explorer.execute_read_only("INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')") + explorer.execute_read_only( + "INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')" + ) def test_execute_read_only_rejects_drop(explorer):