Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 789d1b3

Browse files
Feat/function calling (#1503)
* chore: change update to patch * fix: swagger * fix: pull api * chore: refactor server controller * fix: update status * feat: mimic openai function calling api with llama3.1 * chore: remove unnecessary cout * feat: add tool choice option to api * feat: add unitest * chore: format code * feat: function calling in user message * Update inference_service.cc --------- Co-authored-by: vansangpfiev <vansangpfiev@gmail.com>
1 parent 2e63d38 commit 789d1b3

File tree

3 files changed

+45
-16
lines changed

3 files changed

+45
-16
lines changed

engine/controllers/swagger.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,21 @@ Json::Value SwaggerController::generateOpenAPISpec() {
630630
"#/components/schemas/ChatMessage";
631631
schemas["ChatCompletionRequest"]["properties"]["stream"]["type"] = "boolean";
632632
schemas["ChatCompletionRequest"]["properties"]["engine"]["type"] = "string";
633+
schemas["ChatCompletionRequest"]["properties"]["tools"]["type"] = "array";
634+
schemas["ChatCompletionRequest"]["properties"]["tools"]["items"]["$ref"] =
635+
"#/components/schemas/ToolsCall";
636+
schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"]
637+
["type"] = "boolean";
638+
schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"]
639+
["default"] = false;
640+
schemas["ToolsCall"]["type"] = "object";
633641

634642
schemas["ChatMessage"]["type"] = "object";
635643
schemas["ChatMessage"]["properties"]["role"]["type"] = "string";
636644
schemas["ChatMessage"]["properties"]["content"]["type"] = "string";
645+
schemas["ChatMessage"]["properties"]["tools"]["type"] = "array";
646+
schemas["ChatMessage"]["properties"]["tools"]["items"]["$ref"] =
647+
"#/components/schemas/ToolsCall";
637648

638649
schemas["ChatCompletionResponse"]["type"] = "object";
639650
// Add properties based on your implementation

engine/services/inference_service.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,4 +389,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
389389
}
390390
return true;
391391
}
392-
} // namespace services
392+
} // namespace services

engine/utils/function_calling/common.h

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ inline std::string ReplaceCustomFunctions(const std::string& original,
5151
}
5252

5353
inline bool HasTools(const std::shared_ptr<Json::Value>& request) {
54-
return request->isMember("tools") && (*request)["tools"].isArray() &&
55-
(*request)["tools"].size() > 0;
54+
return (request->isMember("tools") && (*request)["tools"].isArray() &&
55+
(*request)["tools"].size() > 0) ||
56+
request->get("tools_call_in_user_message", false).asBool();
5657
}
5758

5859
inline std::string ProcessTools(const std::shared_ptr<Json::Value>& request) {
@@ -149,7 +150,7 @@ inline void UpdateMessages(std::string& system_prompt,
149150
Json::Value tool_choice = request->get("tool_choice", "auto");
150151
if (tool_choice.isString() && tool_choice.asString() == "required") {
151152
system_prompt +=
152-
"\n\nYou must use a function to answer the user's question.";
153+
"\n\nYou must call a function to answer the user's question.";
153154
} else if (!tool_choice.isString()) {
154155

155156
system_prompt +=
@@ -158,10 +159,14 @@ inline void UpdateMessages(std::string& system_prompt,
158159
"' to answer the user's question.";
159160
}
160161

162+
bool tools_call_in_user_message =
163+
request->get("tools_call_in_user_message", false).asBool();
164+
161165
bool original_stream_config = (*request).get("stream", false).asBool();
162166
// (*request)["grammar"] = function_calling_utils::gamma_json;
163167
(*request)["stream"] =
164168
false; //when using function calling, disable stream automatically because we need to parse the response to get function name and params
169+
165170
if (!request->isMember("messages") || !(*request)["messages"].isArray() ||
166171
(*request)["messages"].empty()) {
167172
// If no messages, add the system prompt as the first message
@@ -170,21 +175,34 @@ inline void UpdateMessages(std::string& system_prompt,
170175
systemMessage["content"] = system_prompt;
171176
(*request)["messages"].append(systemMessage);
172177
} else {
173-
Json::Value& firstMessage = (*request)["messages"][0];
174-
if (firstMessage["role"] == "system") {
175-
bool addCustomPrompt =
176-
request->get("add_custom_system_prompt", true).asBool();
177-
if (addCustomPrompt) {
178-
firstMessage["content"] =
179-
system_prompt + "\n" + firstMessage["content"].asString();
178+
179+
if (tools_call_in_user_message) {
180+
for (Json::Value& message : (*request)["messages"]) {
181+
if (message["role"] == "user" && message.isMember("tools") &&
182+
message["tools"].isArray() && message["tools"].size() > 0) {
183+
message["content"] = system_prompt + "\n User question: " +
184+
message["content"].asString();
185+
}
180186
}
181187
} else {
182-
// If the first message is not a system message, prepend the system prompt
183-
Json::Value systemMessage;
184-
systemMessage["role"] = "system";
185-
systemMessage["content"] = system_prompt;
186-
(*request)["messages"].insert(0, systemMessage);
188+
Json::Value& firstMessage = (*request)["messages"][0];
189+
if (firstMessage["role"] == "system") {
190+
bool addCustomPrompt =
191+
request->get("add_custom_system_prompt", true).asBool();
192+
if (addCustomPrompt) {
193+
firstMessage["content"] =
194+
system_prompt + "\n" + firstMessage["content"].asString();
195+
}
196+
} else {
197+
// If the first message is not a system message, prepend the system prompt
198+
Json::Value systemMessage;
199+
systemMessage["role"] = "system";
200+
systemMessage["content"] = system_prompt;
201+
(*request)["messages"].insert(0, systemMessage);
202+
}
187203
}
204+
205+
// transform last message role to tool if it is a function call
188206
Json::Value& lastMessage =
189207
(*request)["messages"][(*request)["messages"].size() - 1];
190208
if (lastMessage.get("role", "") == "tool") {

0 commit comments

Comments
 (0)