Skip to content

Commit 87819ab

Browse files
committed
adjust selectai parameters
1 parent b2b56f4 commit 87819ab

File tree

3 files changed

+61
-44
lines changed

3 files changed

+61
-44
lines changed

src/client/utils/st_common.py

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def history_sidebar() -> None:
133133
def ll_sidebar() -> None:
134134
"""Language Model Sidebar"""
135135
st.sidebar.subheader("Language Model Parameters", divider="red")
136-
# If no user_settings defined for , set to the first available_ll_model
136+
# If no user_settings defined for model, set to the first available_ll_model
137137
if state.user_settings["ll_model"].get("model") is None:
138138
default_ll_model = list(state.ll_model_enabled.keys())[0]
139139
defaults = {
@@ -144,24 +144,31 @@ def ll_sidebar() -> None:
144144
}
145145
state.user_settings["ll_model"].update(defaults)
146146

147-
ll_idx = list(state.ll_model_enabled.keys()).index(state.user_settings["ll_model"]["model"])
148-
selected_model = st.sidebar.selectbox(
149-
"Chat model:",
150-
options=list(state.ll_model_enabled.keys()),
151-
index=ll_idx,
152-
key="selected_ll_model_model",
153-
on_change=update_user_settings("ll_model"),
154-
)
147+
selected_model = state.user_settings["ll_model"]["model"]
148+
ll_idx = list(state.ll_model_enabled.keys()).index(selected_model)
149+
if not state.user_settings["selectai"]["enabled"]:
150+
selected_model = st.sidebar.selectbox(
151+
"Chat model:",
152+
options=list(state.ll_model_enabled.keys()),
153+
index=ll_idx,
154+
key="selected_ll_model_model",
155+
on_change=update_user_settings("ll_model"),
156+
disabled=state.user_settings["selectai"]["enabled"]
157+
)
155158

156159
# Temperature
157160
temperature = state.ll_model_enabled[selected_model]["temperature"]
158161
user_temperature = state.user_settings["ll_model"]["temperature"]
162+
max_value = 2.0
163+
if state.user_settings["selectai"]["enabled"]:
164+
user_temperature = 1.0
165+
max_value = 1.0
159166
st.sidebar.slider(
160167
f"Temperature (Default: {temperature}):",
161168
help=help_text.help_dict["temperature"],
162169
value=user_temperature if user_temperature is not None else temperature,
163170
min_value=0.0,
164-
max_value=2.0,
171+
max_value=max_value,
165172
key="selected_ll_model_temperature",
166173
on_change=update_user_settings("ll_model"),
167174
)
@@ -184,39 +191,40 @@ def ll_sidebar() -> None:
184191
)
185192

186193
# Top P
187-
st.sidebar.slider(
188-
"Top P (Default: 1.0):",
189-
help=help_text.help_dict["top_p"],
190-
value=state.user_settings["ll_model"]["top_p"],
191-
min_value=0.0,
192-
max_value=1.0,
193-
key="selected_ll_model_top_p",
194-
on_change=update_user_settings("ll_model"),
195-
)
194+
if not state.user_settings["selectai"]["enabled"]:
195+
st.sidebar.slider(
196+
"Top P (Default: 1.0):",
197+
help=help_text.help_dict["top_p"],
198+
value=state.user_settings["ll_model"]["top_p"],
199+
min_value=0.0,
200+
max_value=1.0,
201+
key="selected_ll_model_top_p",
202+
on_change=update_user_settings("ll_model"),
203+
)
196204

197-
# Frequency Penalty
198-
frequency_penalty = state.ll_model_enabled[selected_model]["frequency_penalty"]
199-
user_frequency_penalty = state.user_settings["ll_model"]["frequency_penalty"]
200-
st.sidebar.slider(
201-
f"Frequency penalty (Default: {frequency_penalty}):",
202-
help=help_text.help_dict["frequency_penalty"],
203-
value=user_frequency_penalty if user_frequency_penalty is not None else frequency_penalty,
204-
min_value=-2.0,
205-
max_value=2.0,
206-
key="selected_ll_model_frequency_penalty",
207-
on_change=update_user_settings("ll_model"),
208-
)
205+
# Frequency Penalty
206+
frequency_penalty = state.ll_model_enabled[selected_model]["frequency_penalty"]
207+
user_frequency_penalty = state.user_settings["ll_model"]["frequency_penalty"]
208+
st.sidebar.slider(
209+
f"Frequency penalty (Default: {frequency_penalty}):",
210+
help=help_text.help_dict["frequency_penalty"],
211+
value=user_frequency_penalty if user_frequency_penalty is not None else frequency_penalty,
212+
min_value=-2.0,
213+
max_value=2.0,
214+
key="selected_ll_model_frequency_penalty",
215+
on_change=update_user_settings("ll_model"),
216+
)
209217

210-
# Presence Penalty
211-
st.sidebar.slider(
212-
"Presence penalty (Default: 0.0):",
213-
help=help_text.help_dict["presence_penalty"],
214-
value=state.user_settings["ll_model"]["presence_penalty"],
215-
min_value=-2.0,
216-
max_value=2.0,
217-
key="selected_ll_model_presence_penalty",
218-
on_change=update_user_settings("ll_model"),
219-
)
218+
# Presence Penalty
219+
st.sidebar.slider(
220+
"Presence penalty (Default: 0.0):",
221+
help=help_text.help_dict["presence_penalty"],
222+
value=state.user_settings["ll_model"]["presence_penalty"],
223+
min_value=-2.0,
224+
max_value=2.0,
225+
key="selected_ll_model_presence_penalty",
226+
on_change=update_user_settings("ll_model"),
227+
)
220228

221229

222230
#####################################################

src/server/endpoints.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,11 +608,11 @@ async def completion_generator(
608608
) -> AsyncGenerator[str, None]:
609609
"""Generate a completion from agent, stream the results"""
610610
client_settings = get_client_settings(client)
611+
model = request.model_dump()
611612
logger.debug("Settings: %s", client_settings)
612-
logger.debug("Request: %s", request.model_dump())
613+
logger.debug("Request: %s", model)
613614

614615
# Establish LL schema.Model Params (if the request specs a model, otherwise override from settings)
615-
model = request.model_dump()
616616
if not model["model"]:
617617
model = client_settings.ll_model.model_dump()
618618

@@ -651,9 +651,17 @@ async def completion_generator(
651651
logger.error("A settings exception occurred: %s", ex)
652652
raise HTTPException(status_code=500, detail="Unexpected Error.") from ex
653653

654+
db_conn = None
655+
# Setup selectai
656+
if client_settings.selectai.enabled:
657+
db_conn = get_client_db(client).connection
658+
databases.set_selectai_profile(db_conn, "temperature", model["temperature"])
659+
databases.set_selectai_profile(db_conn, "max_tokens", model["max_completion_tokens"])
660+
654661
# Setup vector_search
655662
embed_client, ctx_prompt = None, None
656663
if client_settings.vector_search.enabled:
664+
db_conn = get_client_db(client).connection
657665
embed_client = await models.get_client(
658666
MODEL_OBJECTS, client_settings.vector_search.model_dump(), oci_config
659667
)
@@ -671,7 +679,7 @@ async def completion_generator(
671679
"thread_id": client,
672680
"ll_client": ll_client,
673681
"embed_client": embed_client,
674-
"db_conn": get_client_db(client).connection,
682+
"db_conn": db_conn,
675683
},
676684
metadata={
677685
"model_name": model["model"],

src/server/utils/databases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def set_selectai_profile(conn: oracledb.Connection, attribute_name: str, attribu
174174
"""Update SelectAI Profile"""
175175
logger.info("Updating SelectAI Profile attribute: %s = %s", attribute_name, attribute_value)
176176
# Attribute Names: provider, credential_name, object_list, provider_endpoint, model
177+
# Attribute Names: temperature, max_tokens
177178
binds = {"attribute_name": attribute_name, "attribute_value": attribute_value}
178179
sql = """
179180
BEGIN

0 commit comments

Comments (0)