|
4 | 4 | `query` CLI 명령어를 제공합니다. |
5 | 5 | """ |
6 | 6 |
|
7 | | -import os |
8 | | - |
9 | 7 | import click |
10 | 8 |
|
11 | 9 | from cli.utils.logger import configure_logging |
|
16 | 14 | @click.command(name="query") |
17 | 15 | @click.argument("question", type=str) |
18 | 16 | @click.option( |
19 | | - "--database-env", |
20 | | - default="clickhouse", |
21 | | - help="사용할 데이터베이스 환경 (기본값: clickhouse)", |
22 | | -) |
23 | | -@click.option( |
24 | | - "--retriever-name", |
25 | | - default="기본", |
26 | | - help="테이블 검색기 이름 (기본값: 기본)", |
| 17 | + "--flow", |
| 18 | + type=click.Choice(["baseline", "enriched"]), |
| 19 | + default="baseline", |
| 20 | + help="사용할 플로우 (기본값: baseline)", |
27 | 21 | ) |
28 | 22 | @click.option( |
29 | 23 | "--top-n", |
|
32 | 26 | help="검색된 상위 테이블 수 제한 (기본값: 5)", |
33 | 27 | ) |
34 | 28 | @click.option( |
35 | | - "--device", |
36 | | - default="cpu", |
37 | | - help="LLM 실행에 사용할 디바이스 (기본값: cpu)", |
| 29 | + "--dialect", |
| 30 | + default=None, |
| 31 | + help="SQL 방언 (예: sqlite, postgresql, mysql, bigquery, duckdb)", |
38 | 32 | ) |
39 | 33 | @click.option( |
40 | | - "--use-enriched-graph", |
| 34 | + "--no-gate", |
41 | 35 | is_flag=True, |
42 | | - help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부", |
43 | | -) |
44 | | -@click.option( |
45 | | - "--vectordb-type", |
46 | | - type=click.Choice(["faiss", "pgvector"]), |
47 | | - default="faiss", |
48 | | - help="사용할 벡터 데이터베이스 타입 (기본값: faiss)", |
49 | | -) |
50 | | -@click.option( |
51 | | - "--vectordb-location", |
52 | | - help=( |
53 | | - "VectorDB 위치 설정\n" |
54 | | - "- FAISS: 디렉토리 경로 (예: ./my_vectordb)\n" |
55 | | - "- pgvector: 연결 문자열 (예: postgresql://user:pass@host:port/db)\n" |
56 | | - "기본값: FAISS는 './dev/table_info_db', pgvector는 환경변수 사용" |
57 | | - ), |
| 36 | + help="QuestionGate 비활성화 (enriched 플로우 전용)", |
58 | 37 | ) |
59 | 38 | def query_command( |
60 | 39 | question: str, |
61 | | - database_env: str, |
62 | | - retriever_name: str, |
| 40 | + flow: str, |
63 | 41 | top_n: int, |
64 | | - device: str, |
65 | | - use_enriched_graph: bool, |
66 | | - vectordb_type: str = "faiss", |
67 | | - vectordb_location: str = None, |
| 42 | + dialect: str, |
| 43 | + no_gate: bool, |
68 | 44 | ) -> None: |
69 | | - """자연어 질문을 SQL 쿼리로 변환하여 출력합니다. |
| 45 | + """자연어 질문을 SQL 쿼리로 변환하여 실행 결과를 출력합니다. |
70 | 46 |
|
71 | | - Args: |
72 | | - question (str): SQL로 변환할 자연어 질문 |
73 | | - database_env (str): 사용할 데이터베이스 환경 |
74 | | - retriever_name (str): 테이블 검색기 이름 |
75 | | - top_n (int): 검색된 상위 테이블 수 제한 |
76 | | - device (str): LLM 실행 디바이스 |
77 | | - use_enriched_graph (bool): 확장된 그래프 사용 여부 |
78 | | - vectordb_type (str): 벡터 데이터베이스 타입 ("faiss" 또는 "pgvector") |
79 | | - vectordb_location (Optional[str]): 벡터DB 경로 또는 연결 URL |
| 47 | + 환경변수(LLM_PROVIDER, EMBEDDING_PROVIDER, DB_TYPE 등)로 설정을 제어합니다. |
80 | 48 | """ |
81 | 49 | try: |
82 | | - from engine.query_executor import execute_query, extract_sql_from_result |
| 50 | + from lang2sql.factory import ( |
| 51 | + build_db_from_env, |
| 52 | + build_embedding_from_env, |
| 53 | + build_llm_from_env, |
| 54 | + ) |
| 55 | + from lang2sql.flows import BaselineNL2SQL, EnrichedNL2SQL |
83 | 56 |
|
84 | | - os.environ["VECTORDB_TYPE"] = vectordb_type |
| 57 | + llm = build_llm_from_env() |
| 58 | + db = build_db_from_env() |
85 | 59 |
|
86 | | - if vectordb_location: |
87 | | - os.environ["VECTORDB_LOCATION"] = vectordb_location |
| 60 | + if flow == "baseline": |
| 61 | + pipeline = BaselineNL2SQL( |
| 62 | + catalog=[], |
| 63 | + llm=llm, |
| 64 | + db=db, |
| 65 | + db_dialect=dialect, |
| 66 | + ) |
| 67 | + else: |
| 68 | + embedding = build_embedding_from_env() |
| 69 | + pipeline = EnrichedNL2SQL( |
| 70 | + catalog=[], |
| 71 | + llm=llm, |
| 72 | + db=db, |
| 73 | + embedding=embedding, |
| 74 | + db_dialect=dialect, |
| 75 | + gate_enabled=not no_gate, |
| 76 | + top_n=top_n, |
| 77 | + ) |
88 | 78 |
|
89 | | - res = execute_query( |
90 | | - query=question, |
91 | | - database_env=database_env, |
92 | | - retriever_name=retriever_name, |
93 | | - top_n=top_n, |
94 | | - device=device, |
95 | | - use_enriched_graph=use_enriched_graph, |
96 | | - ) |
| 79 | + rows = pipeline.run(question) |
| 80 | + if rows: |
| 81 | + import json |
97 | 82 |
|
98 | | - sql = extract_sql_from_result(res) |
99 | | - if sql: |
100 | | - print(sql) |
| 83 | + print(json.dumps(rows, ensure_ascii=False, indent=2)) |
101 | 84 | else: |
102 | | - generated_query = res.get("generated_query") |
103 | | - if generated_query: |
104 | | - query_text = ( |
105 | | - generated_query.content |
106 | | - if hasattr(generated_query, "content") |
107 | | - else str(generated_query) |
108 | | - ) |
109 | | - print(query_text) |
| 85 | + print("(결과 없음)") |
110 | 86 |
|
111 | 87 | except Exception as e: |
112 | 88 | logger.error("쿼리 처리 중 오류 발생: %s", e) |
|
0 commit comments