@@ -51,8 +51,9 @@ inline std::string ReplaceCustomFunctions(const std::string& original,
5151}
5252
5353inline bool HasTools (const std::shared_ptr<Json::Value>& request) {
54- return request->isMember (" tools" ) && (*request)[" tools" ].isArray () &&
55- (*request)[" tools" ].size () > 0 ;
54+ return (request->isMember (" tools" ) && (*request)[" tools" ].isArray () &&
55+ (*request)[" tools" ].size () > 0 ) ||
56+ request->get (" tools_call_in_user_message" , false ).asBool ();
5657}
5758
5859inline std::string ProcessTools (const std::shared_ptr<Json::Value>& request) {
@@ -149,7 +150,7 @@ inline void UpdateMessages(std::string& system_prompt,
149150 Json::Value tool_choice = request->get (" tool_choice" , " auto" );
150151 if (tool_choice.isString () && tool_choice.asString () == " required" ) {
151152 system_prompt +=
152- " \n\n You must use a function to answer the user's question." ;
153+ " \n\n You must call a function to answer the user's question." ;
153154 } else if (!tool_choice.isString ()) {
154155
155156 system_prompt +=
@@ -158,10 +159,14 @@ inline void UpdateMessages(std::string& system_prompt,
158159 " ' to answer the user's question." ;
159160 }
160161
162+ bool tools_call_in_user_message =
163+ request->get (" tools_call_in_user_message" , false ).asBool ();
164+
161165 bool original_stream_config = (*request).get (" stream" , false ).asBool ();
162166 // (*request)["grammar"] = function_calling_utils::gamma_json;
163167 (*request)[" stream" ] =
164168 false ; // when using function calling, disable stream automatically because we need to parse the response to get function name and params
169+
165170 if (!request->isMember (" messages" ) || !(*request)[" messages" ].isArray () ||
166171 (*request)[" messages" ].empty ()) {
167172 // If no messages, add the system prompt as the first message
@@ -170,21 +175,34 @@ inline void UpdateMessages(std::string& system_prompt,
170175 systemMessage[" content" ] = system_prompt;
171176 (*request)[" messages" ].append (systemMessage);
172177 } else {
173- Json::Value& firstMessage = (*request)[" messages" ][0 ];
174- if (firstMessage[" role" ] == " system" ) {
175- bool addCustomPrompt =
176- request->get (" add_custom_system_prompt" , true ).asBool ();
177- if (addCustomPrompt) {
178- firstMessage[" content" ] =
179- system_prompt + " \n " + firstMessage[" content" ].asString ();
178+
179+ if (tools_call_in_user_message) {
180+ for (Json::Value& message : (*request)[" messages" ]) {
181+ if (message[" role" ] == " user" && message.isMember (" tools" ) &&
182+ message[" tools" ].isArray () && message[" tools" ].size () > 0 ) {
183+ message[" content" ] = system_prompt + " \n User question: " +
184+ message[" content" ].asString ();
185+ }
180186 }
181187 } else {
182- // If the first message is not a system message, prepend the system prompt
183- Json::Value systemMessage;
184- systemMessage[" role" ] = " system" ;
185- systemMessage[" content" ] = system_prompt;
186- (*request)[" messages" ].insert (0 , systemMessage);
188+ Json::Value& firstMessage = (*request)[" messages" ][0 ];
189+ if (firstMessage[" role" ] == " system" ) {
190+ bool addCustomPrompt =
191+ request->get (" add_custom_system_prompt" , true ).asBool ();
192+ if (addCustomPrompt) {
193+ firstMessage[" content" ] =
194+ system_prompt + " \n " + firstMessage[" content" ].asString ();
195+ }
196+ } else {
197+ // If the first message is not a system message, prepend the system prompt
198+ Json::Value systemMessage;
199+ systemMessage[" role" ] = " system" ;
200+ systemMessage[" content" ] = system_prompt;
201+ (*request)[" messages" ].insert (0 , systemMessage);
202+ }
187203 }
204+
205+ // transform last message role to tool if it is a function call
188206 Json::Value& lastMessage =
189207 (*request)[" messages" ][(*request)[" messages" ].size () - 1 ];
190208 if (lastMessage.get (" role" , " " ) == " tool" ) {
0 commit comments