-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb_manager.py
More file actions
137 lines (111 loc) · 4.93 KB
/
db_manager.py
File metadata and controls
137 lines (111 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import sqlite3
import psycopg2
import pymssql
from config import Config
class DBManager:
    """Thin wrapper around one DB-API connection (SQLite, PostgreSQL or SQL Server).

    The backend is selected by ``Config.DB_TYPE`` ('sqlite', 'postgres' or
    'mssql'). A connection is opened on construction; the query methods
    reconnect lazily if the connection has been released with :meth:`close`.
    """

    def __init__(self):
        """Read the backend type from ``Config`` and connect immediately."""
        self.db_type = Config.DB_TYPE
        self.conn = None  # DB-API connection; None before connecting and after close()
        self.connect()

    def connect(self):
        """Open a connection for ``self.db_type`` using settings from ``Config``.

        'sqlite' uses ``Config.SQLITE_DB_PATH``; 'postgres' and 'mssql' use
        ``Config.DB_HOST/DB_PORT/DB_USER/DB_PASSWORD/DB_NAME`` with the port
        defaulting to 5432 and 1433 respectively when ``DB_PORT`` is falsy.

        Raises:
            ValueError: if ``self.db_type`` is not a supported backend.
            Exception: whatever the driver raises if the connection fails.
            In both cases a diagnostic line is printed before re-raising.
        """
        try:
            if self.db_type == 'sqlite':
                self.conn = sqlite3.connect(Config.SQLITE_DB_PATH)
            elif self.db_type == 'postgres':
                self.conn = psycopg2.connect(
                    host=Config.DB_HOST,
                    port=Config.DB_PORT or 5432,
                    user=Config.DB_USER,
                    password=Config.DB_PASSWORD,
                    dbname=Config.DB_NAME,
                )
                # Autocommit: every statement is committed as it runs, so a
                # failed statement cannot leave the session in an aborted transaction.
                self.conn.autocommit = True
            elif self.db_type == 'mssql':
                self.conn = pymssql.connect(
                    server=Config.DB_HOST,
                    port=Config.DB_PORT or 1433,
                    user=Config.DB_USER,
                    password=Config.DB_PASSWORD,
                    database=Config.DB_NAME,
                )
            else:
                # Fail fast instead of leaving self.conn = None, which would only
                # surface later as an AttributeError inside execute_query.
                raise ValueError(f"Unsupported DB_TYPE: {self.db_type!r}")
        except Exception as e:
            print(f"Error connecting to database ({self.db_type}): {e}")
            raise

    def execute_query(self, sql):
        """Execute one SQL statement and report the outcome as a dict.

        ``sql`` is executed verbatim.
        NOTE(review): if it comes from an LLM or a user, consider connecting
        with a least-privilege / read-only account.

        Returns:
            ``{"columns": [...], "data": [...]}`` when the statement produces a
            result set; ``{"message": "Query executed successfully."}`` when it
            does not; ``{"error": "<driver message>"}`` if the driver raised
            (the open transaction is rolled back first).
        """
        if self.conn is None:
            self.connect()
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql)
            if cursor.description:
                columns = [desc[0] for desc in cursor.description]
                results = cursor.fetchall()
                # Commit after fetching: row-returning writes (INSERT ... RETURNING,
                # OUTPUT clauses) must persist on non-autocommit drivers; for a
                # plain SELECT this just ends the read transaction.
                self.conn.commit()
                return {"columns": columns, "data": results}
            # INSERT/UPDATE/DELETE/DDL
            self.conn.commit()
            return {"message": "Query executed successfully."}
        except Exception as e:
            self._rollback_quietly()
            return {"error": str(e)}
        finally:
            cursor.close()

    def get_schema_context(self):
        """Return a plain-text description of the database schema for the LLM.

        Each table is rendered as ``"Table: <name>\\nColumns: <col> (<type>), ...\\n\\n"``.
        Returns ``""`` for an unrecognised ``db_type`` and
        ``"Error retrieving schema: <exc>"`` if introspection fails.
        """
        if self.conn is None:
            self.connect()
        try:
            cursor = self.conn.cursor()
            try:
                if self.db_type == 'sqlite':
                    return self._sqlite_schema(cursor)
                if self.db_type in ('postgres', 'mssql'):
                    return self._information_schema_context(cursor)
                return ""
            finally:
                cursor.close()  # previously leaked on every call
        except Exception as e:
            return f"Error retrieving schema: {e}"

    @staticmethod
    def _sqlite_schema(cursor):
        """Describe every SQLite table using sqlite_master + PRAGMA table_info."""
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]
        parts = []
        for table_name in table_names:
            # Quote the identifier (doubling embedded quotes) so names containing
            # spaces, keywords or punctuation don't break the PRAGMA statement.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted});")
            # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
            cols = ", ".join(f"{col[1]} ({col[2]})" for col in cursor.fetchall())
            parts.append(f"Table: {table_name}\nColumns: {cols}\n\n")
        return "".join(parts)

    @staticmethod
    def _information_schema_context(cursor):
        """Describe tables via information_schema.columns (Postgres / SQL Server).

        Only the 'public' (usual Postgres default) and 'dbo' (usual SQL Server
        default) schemas are included; same-named tables in both are merged.
        """
        query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public' OR table_schema = 'dbo'
        ORDER BY table_name, ordinal_position;
        """
        cursor.execute(query)
        # dict keeps first-seen (i.e. ORDER BY) table order and column order.
        tables = {}
        for t_name, c_name, d_type in cursor.fetchall():
            tables.setdefault(t_name, []).append(f"{c_name} ({d_type})")
        return "".join(
            f"Table: {t_name}\nColumns: {', '.join(cols)}\n\n"
            for t_name, cols in tables.items()
        )

    def _rollback_quietly(self):
        """Best-effort rollback after a failed statement.

        Only the postgres connection is put in autocommit mode by connect(); on
        the others a failed statement would otherwise leave a transaction open.
        """
        try:
            self.conn.rollback()
        except Exception:
            # Deliberately ignored: the caller reports the original error, which
            # is more useful than a secondary rollback failure.
            pass

    def close(self):
        """Close the connection (if any); a later query will reconnect lazily."""
        if self.conn is not None:
            try:
                self.conn.close()
            finally:
                # Reset so execute_query/get_schema_context reconnect instead of
                # using a closed connection object.
                self.conn = None