 #include "llamaCPP.h"
 
-
-#include <iostream>
 #include <fstream>
+#include <iostream>
 #include "log.h"
-#include "utils/nitro_utils.h"
 
 // External
 #include "common.h"
@@ -175,19 +173,18 @@ void llamaCPP::WarmupModel() {
 }
 
 void llamaCPP::ChatCompletion(
-    const HttpRequestPtr& req,
+    inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  const auto& jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(jsonBody, callback);
+    InferenceImpl(std::move(completion), callback);
   }
 }
 
 void llamaCPP::InferenceImpl(
-    std::shared_ptr<Json::Value> jsonBody,
+    inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>& callback) {
   std::string formatted_output = pre_prompt;
 
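Note: the handlers now receive a typed inferences::ChatCompletionRequest by rvalue reference instead of re-reading the JSON body from the HttpRequestPtr inside the controller. The request type itself is defined outside this hunk; purely for orientation, a struct along the following lines would cover the fields this file reads. Field names are taken from the usage below and the defaults from the removed .get() calls; this is a hypothetical sketch, not the verbatim header from this PR.

// Hypothetical sketch only; the real definition lives elsewhere in the PR.
#include <json/value.h>

namespace inferences {
struct ChatCompletionRequest {
  bool stream = false;             // was (*jsonBody).get("stream", false)
  int max_tokens = 500;            // was .get("max_tokens", 500)
  float top_p = 0.95f;             // was .get("top_p", 0.95)
  float temperature = 0.8f;        // was .get("temperature", 0.8)
  float frequency_penalty = 0.0f;  // was .get("frequency_penalty", 0)
  float presence_penalty = 0.0f;   // was .get("presence_penalty", 0)
  Json::Value messages;            // OpenAI-style "messages" array
  Json::Value stop;                // optional "stop" word list
};
}  // namespace inferences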
@@ -196,131 +193,131 @@ void llamaCPP::InferenceImpl(
   int no_images = 0;
   // To set default value
 
-  if (jsonBody) {
-    // Increase number of chats received and clean the prompt
-    no_of_chats++;
-    if (no_of_chats % clean_cache_threshold == 0) {
-      LOG_INFO << "Clean cache threshold reached!";
-      llama.kv_cache_clear();
-      LOG_INFO << "Cache cleaned";
-    }
-
-    // Default values to enable auto caching
-    data["cache_prompt"] = caching_enabled;
-    data["n_keep"] = -1;
-
-    // Passing load value
-    data["repeat_last_n"] = this->repeat_last_n;
-
-    data["stream"] = (*jsonBody).get("stream", false).asBool();
-    data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
-    data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
-    data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
-    data["frequency_penalty"] =
-        (*jsonBody).get("frequency_penalty", 0).asFloat();
-    data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat();
-    const Json::Value& messages = (*jsonBody)["messages"];
+  // Increase number of chats received and clean the prompt
+  no_of_chats++;
+  if (no_of_chats % clean_cache_threshold == 0) {
+    LOG_INFO << "Clean cache threshold reached!";
+    llama.kv_cache_clear();
+    LOG_INFO << "Cache cleaned";
+  }
 
-    if (!grammar_file_content.empty()) {
-      data["grammar"] = grammar_file_content;
-    };
+  // Default values to enable auto caching
+  data["cache_prompt"] = caching_enabled;
+  data["n_keep"] = -1;
+
+  // Passing load value
+  data["repeat_last_n"] = this->repeat_last_n;
+
+  LOG_INFO << "Messages:" << completion.messages.toStyledString();
+  LOG_INFO << "Stop:" << completion.stop.toStyledString();
+
+  data["stream"] = completion.stream;
+  data["n_predict"] = completion.max_tokens;
+  data["top_p"] = completion.top_p;
+  data["temperature"] = completion.temperature;
+  data["frequency_penalty"] = completion.frequency_penalty;
+  data["presence_penalty"] = completion.presence_penalty;
+  const Json::Value& messages = completion.messages;
+
+  if (!grammar_file_content.empty()) {
+    data["grammar"] = grammar_file_content;
+  };
+
+  if (!llama.multimodal) {
+    for (const auto& message : messages) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        role = user_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
 
-    if (!llama.multimodal) {
-      for (const auto& message : messages) {
-        std::string input_role = message["role"].asString();
-        std::string role;
-        if (input_role == "user") {
-          role = user_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "assistant") {
-          role = ai_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "system") {
-          role = system_prompt;
-          std::string content = message["content"].asString();
-          formatted_output = role + content + formatted_output;
-
-        } else {
-          role = input_role;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        }
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
       }
-      formatted_output += ai_prompt;
-    } else {
-      data["image_data"] = json::array();
-      for (const auto& message : messages) {
-        std::string input_role = message["role"].asString();
-        std::string role;
-        if (input_role == "user") {
-          formatted_output += role;
-          for (auto content_piece : message["content"]) {
-            role = user_prompt;
-
-            json content_piece_image_data;
-            content_piece_image_data["data"] = "";
-
-            auto content_piece_type = content_piece["type"].asString();
-            if (content_piece_type == "text") {
-              auto text = content_piece["text"].asString();
-              formatted_output += text;
-            } else if (content_piece_type == "image_url") {
-              auto image_url = content_piece["image_url"]["url"].asString();
-              std::string base64_image_data;
-              if (image_url.find("http") != std::string::npos) {
-                LOG_INFO << "Remote image detected but not supported yet";
-              } else if (image_url.find("data:image") != std::string::npos) {
-                LOG_INFO << "Base64 image detected";
-                base64_image_data = nitro_utils::extractBase64(image_url);
-                LOG_INFO << base64_image_data;
-              } else {
-                LOG_INFO << "Local image detected";
-                nitro_utils::processLocalImage(
-                    image_url, [&](const std::string& base64Image) {
-                      base64_image_data = base64Image;
-                    });
-                LOG_INFO << base64_image_data;
-              }
-              content_piece_image_data["data"] = base64_image_data;
-
-              formatted_output += "[img-" + std::to_string(no_images) + "]";
-              content_piece_image_data["id"] = no_images;
-              data["image_data"].push_back(content_piece_image_data);
-              no_images++;
+    }
+    formatted_output += ai_prompt;
+  } else {
+    data["image_data"] = json::array();
+    for (const auto& message : messages) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        formatted_output += role;
+        for (auto content_piece : message["content"]) {
+          role = user_prompt;
+
+          json content_piece_image_data;
+          content_piece_image_data["data"] = "";
+
+          auto content_piece_type = content_piece["type"].asString();
+          if (content_piece_type == "text") {
+            auto text = content_piece["text"].asString();
+            formatted_output += text;
+          } else if (content_piece_type == "image_url") {
+            auto image_url = content_piece["image_url"]["url"].asString();
+            std::string base64_image_data;
+            if (image_url.find("http") != std::string::npos) {
+              LOG_INFO << "Remote image detected but not supported yet";
+            } else if (image_url.find("data:image") != std::string::npos) {
+              LOG_INFO << "Base64 image detected";
+              base64_image_data = nitro_utils::extractBase64(image_url);
+              LOG_INFO << base64_image_data;
+            } else {
+              LOG_INFO << "Local image detected";
+              nitro_utils::processLocalImage(
+                  image_url, [&](const std::string& base64Image) {
+                    base64_image_data = base64Image;
+                  });
+              LOG_INFO << base64_image_data;
             }
-          }
+            content_piece_image_data["data"] = base64_image_data;
 
-        } else if (input_role == "assistant") {
-          role = ai_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "system") {
-          role = system_prompt;
-          std::string content = message["content"].asString();
-          formatted_output = role + content + formatted_output;
-
-        } else {
-          role = input_role;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
+            formatted_output += "[img-" + std::to_string(no_images) + "]";
+            content_piece_image_data["id"] = no_images;
+            data["image_data"].push_back(content_piece_image_data);
+            no_images++;
+          }
         }
+
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
+
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
       }
-      formatted_output += ai_prompt;
-      LOG_INFO << formatted_output;
     }
+    formatted_output += ai_prompt;
+    LOG_INFO << formatted_output;
+  }
 
-    data["prompt"] = formatted_output;
-    for (const auto& stop_word : (*jsonBody)["stop"]) {
-      stopWords.push_back(stop_word.asString());
-    }
-    // specify default stop words
-    // Ensure success case for chatML
-    stopWords.push_back("<|im_end|>");
-    stopWords.push_back(nitro_utils::rtrim(user_prompt));
-    data["stop"] = stopWords;
+  data["prompt"] = formatted_output;
+  for (const auto& stop_word : completion.stop) {
+    stopWords.push_back(stop_word.asString());
   }
+  // specify default stop words
+  // Ensure success case for chatML
+  stopWords.push_back("<|im_end|>");
+  stopWords.push_back(nitro_utils::rtrim(user_prompt));
+  data["stop"] = stopWords;
 
   bool is_streamed = data["stream"];
   // Enable full message debugging
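For orientation, the new call shape from the caller's side: the request is built once and handed over by move, so InferenceImpl can consume it without copying and without touching the HTTP layer. The controller variable and the parsing step below are placeholders; only the ChatCompletion signature shown in the diff above is taken from this PR.

// Hypothetical usage sketch; 'controller' and the body-parsing step are assumptions.
inferences::ChatCompletionRequest completion;  // fill from the incoming JSON body elsewhere
controller->ChatCompletion(
    std::move(completion),
    [](const HttpResponsePtr& resp) {
      // inspect or forward the completion response here
    });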