Skip to content

Commit b5a1850

Browse files
committed
Allow multiple profiles
1 parent a39d020 commit b5a1850

File tree

9 files changed

+257
-199
lines changed

9 files changed

+257
-199
lines changed

src/client/content/config/databases.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616
from streamlit import session_state as state
1717

1818
import client.utils.api_call as api_call
19+
import client.utils.st_common as st_common
1920
import common.logging_config as logging_config
21+
from common.schema import SelectAIProfileType
2022

2123
logger = logging_config.logging.getLogger("client.content.config.database")
2224

2325

2426
#####################################################
2527
# Functions
2628
#####################################################
27-
def get_databases(force: bool = False) -> dict[str, dict]:
29+
def get_databases(force: bool = False) -> None:
2830
"""Get a dictionary of all Databases and Store Vector Store Tables"""
2931
if "database_config" not in state or state["database_config"] == {} or force:
3032
try:
31-
endpoint = "v1/databases"
32-
response = api_call.get(endpoint=endpoint)
33+
response = api_call.get(endpoint="v1/databases")
3334
state["database_config"] = {
3435
item["name"]: {k: v for k, v in item.items() if k != "name"} for item in response
3536
}
@@ -51,9 +52,8 @@ def patch_database(name: str, user: str, password: str, dsn: str, wallet_passwor
5152
or not state.database_config[name]["connected"]
5253
):
5354
try:
54-
endpoint = f"v1/databases/{name}"
5555
_ = api_call.patch(
56-
endpoint=endpoint,
56+
endpoint=f"v1/databases/{name}",
5757
payload={
5858
"json": {
5959
"user": user,
@@ -66,28 +66,6 @@ def patch_database(name: str, user: str, password: str, dsn: str, wallet_passwor
6666
logger.info("Database updated: %s", name)
6767
state.database_config[name]["connected"] = True
6868
get_databases(force=True)
69-
endpoint = "v1/selectai/enabled"
70-
selectai = api_call.get(
71-
endpoint=endpoint,
72-
)
73-
logger.info("SelectAI enabled: %s", selectai["enabled"])
74-
state.database_config[name]["selectai"] = selectai["enabled"]
75-
76-
# Check if SelectAI is enabled and get objects if so
77-
if selectai["enabled"]:
78-
try:
79-
endpoint = "v1/selectai/objects"
80-
selectai_objects = api_call.get(
81-
endpoint=endpoint,
82-
)
83-
logger.info("SelectAI objects retrieved: %d objects", len(selectai_objects))
84-
state.database_config[name]["selectai_objects"] = selectai_objects
85-
except api_call.ApiError as ex:
86-
logger.error("Failed to retrieve SelectAI objects: %s", ex)
87-
state.database_config[name]["selectai_objects"] = []
88-
else:
89-
state.database_config[name]["selectai_objects"] = None
90-
9169
except api_call.ApiError as ex:
9270
logger.error("Database not updated: %s (%s)", name, ex)
9371
state.database_config[name]["connected"] = False
@@ -102,21 +80,39 @@ def drop_vs(vs: dict) -> None:
10280
api_call.delete(endpoint=f"v1/embed/{vs['vector_store']}")
10381
get_databases(force=True)
10482

83+
def select_ai_profile() -> None:
84+
"""Update the chosen SelectAI Profile"""
85+
st_common.update_user_settings("selectai")
86+
st_common.patch_settings()
87+
selectai_df.clear()
10588

106-
def update_selectai(sai_df: pd.DataFrame ) -> None:
89+
@st.cache_data
90+
def selectai_df(profile):
91+
"""Get SelectAI Object List and produce Dataframe"""
92+
logger.info("Retrieving objects from SelectAI Profile: %s", profile)
93+
st_common.patch_settings()
94+
selectai_objects = api_call.get(endpoint="v1/selectai/objects")
95+
df = pd.DataFrame(selectai_objects, columns=["owner", "name", "enabled"])
96+
df.columns = ["Owner", "Name", "Enabled"]
97+
return df
98+
99+
100+
def update_selectai(sai_new_df: pd.DataFrame, sai_old_df: pd.DataFrame) -> None:
107101
"""Update SelectAI Object List"""
108-
enabled_objects = sai_df[sai_df["Enabled"]].drop(columns=["Enabled"])
109-
enabled_objects.columns = enabled_objects.columns.str.lower()
110-
try:
111-
endpoint = "v1/selectai/objects"
112-
result = api_call.patch(
113-
endpoint=endpoint,
114-
payload={"json": json.loads(enabled_objects.to_json(orient="records"))}
115-
)
116-
logger.info("SelectAI Updated.")
117-
state.database_config["DEFAULT"]["selectai_objects"] = result
118-
except api_call.ApiError as ex:
119-
logger.error("SelectAI not updated: %s", ex)
102+
changes = sai_new_df[sai_new_df["Enabled"] != sai_old_df["Enabled"]]
103+
if changes.empty:
104+
st.toast("No changes detected.", icon="ℹ️")
105+
else:
106+
enabled_objects = sai_new_df[sai_new_df["Enabled"]].drop(columns=["Enabled"])
107+
enabled_objects.columns = enabled_objects.columns.str.lower()
108+
try:
109+
_ = api_call.patch(
110+
endpoint="v1/selectai/objects", payload={"json": json.loads(enabled_objects.to_json(orient="records"))}
111+
)
112+
logger.info("SelectAI Updated. Clearing Cache.")
113+
selectai_df.clear()
114+
except api_call.ApiError as ex:
115+
logger.error("SelectAI not updated: %s", ex)
120116

121117

122118
#####################################################
@@ -209,29 +205,33 @@ def main() -> None:
209205
# Select AI
210206
#############################################
211207
st.subheader("SelectAI", divider="red")
212-
if state.database_config[name]["selectai"]:
213-
st.write("Tables eligible and enabled/disabled for SelectAI.")
214-
if state.database_config[name]["selectai_objects"]:
215-
df = pd.DataFrame(
216-
state.database_config[name]["selectai_objects"], columns=["owner", "name", "enabled"]
217-
)
218-
df.columns = ["Owner", "Name", "Enabled"]
208+
selectai_profiles = state.database_config[name]["selectai_profiles"]
209+
if state.database_config[name]["selectai"] and len(selectai_profiles) > 0:
210+
if not state.user_settings["selectai"]["profile"]:
211+
state.user_settings["selectai"]["profile"] = selectai_profiles[0]
212+
# Select Profile
213+
st.selectbox(
214+
"Profile:",
215+
options=selectai_profiles,
216+
index=selectai_profiles.index(state.user_settings["selectai"]["profile"]),
217+
key="selected_selectai_profile",
218+
on_change=select_ai_profile,
219+
)
220+
selectai_objects = selectai_df(state.user_settings["selectai"]["profile"])
221+
if not selectai_objects.empty:
219222
sai_df = st.data_editor(
220-
df,
223+
selectai_objects,
221224
column_config={
222225
"enabled": st.column_config.CheckboxColumn(label="Enabled", help="Toggle to enable or disable")
223226
},
224227
use_container_width=True,
225228
hide_index=True,
226229
)
227-
if st.button("Apply SelectAI Changes", type='primary'):
228-
changes = sai_df[sai_df["Enabled"] != df["Enabled"]]
229-
if not changes.empty:
230-
update_selectai(sai_df)
231-
else:
232-
st.write("No changes detected.")
230+
if st.button("Apply SelectAI Changes", type="secondary"):
231+
update_selectai(sai_df, selectai_objects)
232+
st.rerun()
233233
else:
234-
st.write("No tables found for SelectAI.")
234+
st.write("No objects found for SelectAI.")
235235
else:
236236
st.write("Unable to use SelectAI with Database.")
237237

src/client/utils/api_call.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def patch(
142142
timeout: int = 60,
143143
retries: int = 5,
144144
backoff_factor: float = 1.5,
145+
toast = True
145146
) -> None:
146147
"""PATCH Requests"""
147148
response = send_request(
@@ -153,7 +154,8 @@ def patch(
153154
retries=retries,
154155
backoff_factor=backoff_factor,
155156
)
156-
st.toast("Update Successful.", icon="✅")
157+
if toast:
158+
st.toast("Update Successful.", icon="✅")
157159
return response.json()
158160

159161

@@ -162,8 +164,10 @@ def delete(
162164
timeout: int = 60,
163165
retries: int = 5,
164166
backoff_factor: float = 1.5,
167+
toast = True
165168
) -> None:
166169
"""DELETE Requests"""
167170
response = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
168171
success = response.json()["message"]
169-
st.toast(success, icon="✅")
172+
if toast:
173+
st.toast(success, icon="✅")

src/client/utils/st_common.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def patch_settings() -> None:
6565
endpoint="v1/settings",
6666
payload={"json": state.user_settings},
6767
params={"client": state.user_settings["client"]},
68+
toast=False,
6869
)
6970
except api_call.ApiError as ex:
7071
logger.error("%s Settings Update failed: %s", state.user_settings["client"], ex)
@@ -96,7 +97,7 @@ def update_user_settings(
9697
def is_db_configured() -> bool:
9798
"""Verify that a database is configured"""
9899
get_databases()
99-
return state.database_config[state.user_settings["vector_search"]["database"]].get("connected")
100+
return state.database_config[state.user_settings["database"]["alias"]].get("connected")
100101

101102

102103
def set_server_state() -> None:
@@ -154,7 +155,7 @@ def ll_sidebar() -> None:
154155
index=ll_idx,
155156
key="selected_ll_model_model",
156157
on_change=update_user_settings("ll_model"),
157-
disabled=state.user_settings["selectai"]["enabled"]
158+
disabled=state.user_settings["selectai"]["enabled"],
158159
)
159160

160161
# Temperature
@@ -234,7 +235,7 @@ def ll_sidebar() -> None:
234235
def tools_sidebar() -> None:
235236
"""SelectAI Sidebar Settings, conditional if all sorts of bs setup"""
236237

237-
def update_settings():
238+
def update_set_tool():
238239
state.user_settings["vector_search"]["enabled"] = state.selected_tool == "VectorSearch"
239240
state.user_settings["selectai"]["enabled"] = state.selected_tool == "SelectAI"
240241

@@ -248,6 +249,7 @@ def update_settings():
248249
state.user_settings["vector_search"]["enabled"] = disable_vector_search
249250
switch_prompt("sys", "Basic Example")
250251
else:
252+
db_alias = state.user_settings["database"]["alias"]
251253
tools = [
252254
("LLM Only", "Do not use tools", False),
253255
("SelectAI", "Use AI with Structured Data", disable_selectai),
@@ -262,14 +264,23 @@ def update_settings():
262264
logger.debug("SelectAI Disabled (OCI not configured.)")
263265
st.warning("OCI is not fully configured. Disabling SelectAI.", icon="⚠️")
264266
tools = [t for t in tools if t[0] != "SelectAI"]
267+
elif not state.database_config[db_alias]["selectai"]:
268+
logger.debug("SelectAI Disabled (Database not Compatible.)")
269+
st.warning("Database not Compatible. Disabling SelectAI.", icon="⚠️")
270+
tools = [t for t in tools if t[0] != "SelectAI"]
271+
elif len(state.database_config[db_alias]["selectai_profiles"]) == 0:
272+
logger.debug("SelectAI Disabled (No profiles found.)")
273+
st.warning("No profiles found. Disabling SelectAI.", icon="⚠️")
274+
tools = [t for t in tools if t[0] != "SelectAI"]
275+
265276
# Vector Search Requirements
266277
get_models(model_type="embed")
267278
available_embed_models = list(state.embed_model_enabled.keys())
268279
if not available_embed_models:
269280
logger.debug("Vector Search Disabled (no Embedding Models)")
270281
st.warning("No embedding models are configured and/or enabled. Disabling Vector Search.", icon="⚠️")
271282
tools = [t for t in tools if t[0] != "Vector Search"]
272-
elif not state.database_config[state.user_settings["vector_search"]["database"]].get("vector_stores"):
283+
elif not state.database_config[db_alias].get("vector_stores"):
273284
logger.debug("Vector Search Disabled (Database has no vector stores.)")
274285
st.warning("Database has no Vector Stores. Disabling Vector Search.", icon="⚠️")
275286
tools = [t for t in tools if t[0] != "Vector Search"]
@@ -293,19 +304,30 @@ def update_settings():
293304
# captions=tool_cap,
294305
index=tool_index,
295306
label_visibility="collapsed",
296-
on_change=update_settings,
307+
on_change=update_set_tool,
297308
key="selected_tool",
298309
)
299310
if state.selected_tool == "None":
300311
switch_prompt("sys", "Basic Example")
301312

313+
302314
#####################################################
303315
# SelectAI Options
304316
#####################################################
305317
def selectai_sidebar() -> None:
306318
"""SelectAI Sidebar Settings, conditional if Database/SelectAI are configured"""
307319
if state.user_settings["selectai"]["enabled"]:
308320
st.sidebar.subheader("SelectAI", divider="red")
321+
selectai_profiles = state.database_config[state.user_settings["database"]["alias"]]["selectai_profiles"]
322+
if not state.user_settings["selectai"]["profile"]:
323+
state.user_settings["selectai"]["profile"] = selectai_profiles[0]
324+
st.sidebar.selectbox(
325+
"Profile:",
326+
options=selectai_profiles,
327+
index=selectai_profiles.index(state.user_settings["selectai"]["profile"]),
328+
key="selected_selectai_profile",
329+
on_change=update_user_settings("selectai"),
330+
)
309331
st.sidebar.selectbox(
310332
"Action:",
311333
get_args(SelectAISettings.__annotations__["action"]),

src/common/schema.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
class DatabaseVectorStorage(BaseModel):
4444
"""Database Vector Storage Tables"""
4545

46-
database: Optional[str] = Field(default="DEFAULT", description="Name of Database (Alias)")
4746
vector_store: Optional[str] = Field(
4847
default=None, description="Vector Store Table Name (auto-generated, do not set)", readOnly=True
4948
)
@@ -58,7 +57,6 @@ class DatabaseVectorStorage(BaseModel):
5857
class DatabaseSelectAIObjects(BaseModel):
5958
"""Database SelectAI Objects"""
6059

61-
database: Optional[str] = Field(default="DEFAULT", description="Name of Database (Alias)")
6260
owner: Optional[str] = Field(default=None, description="Object Owner", readOnly=True)
6361
name: Optional[str] = Field(default=None, description="Object Name", readOnly=True)
6462
enabled: bool = Field(default=False, description="SelectAI Enabled")
@@ -80,14 +78,12 @@ class Database(DatabaseAuth):
8078
"""Database Object"""
8179

8280
name: str = Field(default="DEFAULT", description="Name of Database (Alias)")
83-
selectai: bool = Field(default=False, description="SelectAI Possible")
8481
connected: bool = Field(default=False, description="Connection Established")
8582
vector_stores: Optional[list[DatabaseVectorStorage]] = Field(
8683
default=None, description="Vector Storage (read-only)", readOnly=True
8784
)
88-
selectai_objects: Optional[list[DatabaseSelectAIObjects]] = Field(
89-
default=None, description="SelectAI Eligible Objects (read-only)", readOnly=True
90-
)
85+
selectai: bool = Field(default=False, description="SelectAI Possible")
86+
selectai_profiles: Optional[list] = Field(default=[], description="SelectAI Profiles (read-only)", readOnly=True)
9187
# Do not expose the connection to the endpoint
9288
_connection: oracledb.Connection = PrivateAttr(default=None)
9389

@@ -242,7 +238,7 @@ class SelectAISettings(BaseModel):
242238
"""Store SelectAI Settings"""
243239

244240
enabled: bool = Field(default=False, description="SelectAI Enabled")
245-
profile: str = Field(default="OPTIMIZER_PROFILE", description="SelectAI Profile", readOnly=True)
241+
profile: Optional[str] = Field(default=None, description="SelectAI Profile")
246242
action: Literal["runsql", "showsql", "explainsql", "narrate"] = Field(
247243
default="narrate", description="SelectAI Action"
248244
)
@@ -254,6 +250,12 @@ class OciSettings(BaseModel):
254250
auth_profile: Optional[str] = Field(default="DEFAULT", description="Oracle Cloud Settings Profile")
255251

256252

253+
class DatabaseSettings(BaseModel):
254+
"""Database Settings"""
255+
256+
alias: str = Field(default="DEFAULT", description="Name of Database (Alias)")
257+
258+
257259
class Settings(BaseModel):
258260
"""Server Settings"""
259261

@@ -269,6 +271,7 @@ class Settings(BaseModel):
269271
default_factory=PromptSettings, description="Prompt Engineering Settings"
270272
)
271273
oci: Optional[OciSettings] = Field(default_factory=OciSettings, description="OCI Settings")
274+
database: Optional[DatabaseSettings] = Field(default_factory=DatabaseSettings, description="Database Settings")
272275
vector_search: Optional[VectorSearchSettings] = Field(
273276
default_factory=VectorSearchSettings, description="Vector Search Settings"
274277
)
@@ -401,6 +404,7 @@ class EvaluationReport(Evaluation):
401404
PromptNameType = Prompt.__annotations__["name"]
402405
PromptCategoryType = Prompt.__annotations__["category"]
403406
PromptPromptType = PromptText.__annotations__["prompt"]
407+
SelectAIProfileType = Database.__annotations__["selectai_profiles"]
404408
TestSetsIdType = TestSets.__annotations__["tid"]
405409
TestSetsNameType = TestSets.__annotations__["name"]
406410
TestSetDateType = TestSets.__annotations__["created"]

src/server/bootstrap/database_def.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import server.utils.databases as databases
99
import server.utils.embedding as embedding
10+
import server.utils.selectai as selectai
1011
from common.schema import Database
1112
import common.logging_config as logging_config
1213

@@ -48,9 +49,9 @@ def main() -> list[Database]:
4849
db.connected = False
4950
continue
5051
db.vector_stores = embedding.get_vs(conn)
51-
db.selectai = databases.selectai_enabled(conn)
52+
db.selectai = selectai.enabled(conn)
5253
if db.selectai:
53-
db.selectai_objects = databases.get_selectai_objects(conn)
54+
db.selectai_profiles = selectai.get_profiles(conn)
5455
if not db.connection and len(database_objects) > 1:
5556
db.set_connection = databases.disconnect(conn)
5657
else:

0 commit comments

Comments
 (0)