@@ -10,7 +10,7 @@ using json = nlohmann::json;
 /**
  * The state of the inference task
  */
-enum InferenceStatus { PENDING, RUNNING, FINISHED };
+enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };
 
 /**
  * There is a need to save state of current ongoing inference status of a
@@ -21,7 +21,7 @@ enum InferenceStatus { PENDING, RUNNING, FINISHED };
  */
 struct inferenceState {
   int task_id;
-  InferenceStatus inferenceStatus = PENDING;
+  InferenceStatus inference_status = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -104,14 +104,13 @@ std::string create_full_return_json(const std::string &id,
   root["usage"] = usage;
 
   Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // Compact output
+  writer["indentation"] = "";  // Compact output
   return Json::writeString(writer, root);
 }
 
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
 
   root["id"] = id;
@@ -132,16 +131,16 @@ std::string create_return_json(const std::string &id, const std::string &model,
   root["choices"] = choicesArray;
 
   Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // This sets the indentation to an empty string,
-                              // producing compact output.
+  writer["indentation"] = "";  // This sets the indentation to an empty string,
+                               // producing compact output.
   return Json::writeString(writer, root);
 }
 
 llamaCPP::llamaCPP()
     : queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel,
                                              "llamaCPP")) {
   // Some default values for now below
-  log_disable(); // Disable the log to file feature, reduce bloat for
+  log_disable();  // Disable the log to file feature, reduce bloat for
                  // target
                  // system ()
 };
@@ -167,7 +166,6 @@ void llamaCPP::warmupModel() {
 void llamaCPP::inference(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (checkModelLoaded(callback)) {
@@ -180,7 +178,6 @@ void llamaCPP::inference(
 void llamaCPP::inferenceImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   std::string formatted_output = pre_prompt;
 
   json data;
@@ -218,7 +215,6 @@ void llamaCPP::inferenceImpl(
   };
 
   if (!llama.multimodal) {
-
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
       std::string role;
@@ -243,7 +239,6 @@ void llamaCPP::inferenceImpl(
     }
     formatted_output += ai_prompt;
   } else {
-
     data["image_data"] = json::array();
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
@@ -327,18 +322,33 @@ void llamaCPP::inferenceImpl(
   auto state = create_inference_state(this);
   auto chunked_content_provider =
       [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (state->inferenceStatus == PENDING) {
-      state->inferenceStatus = RUNNING;
-    } else if (state->inferenceStatus == FINISHED) {
+    if (state->inference_status == PENDING) {
+      state->inference_status = RUNNING;
+    } else if (state->inference_status == FINISHED) {
       return 0;
     }
 
     if (!pBuffer) {
       LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->inferenceStatus = FINISHED;
+      state->inference_status = FINISHED;
       return 0;
     }
 
+    if (state->inference_status == EOS) {
+      LOG_INFO << "End of result";
+      const std::string str =
+          "data: " +
+          create_return_json(nitro_utils::generate_random_string(20), "_", "",
+                             "stop") +
+          "\n\n" + "data: [DONE]" + "\n\n";
+
+      LOG_VERBOSE("data stream", {{"to_send", str}});
+      std::size_t nRead = std::min(str.size(), nBuffSize);
+      memcpy(pBuffer, str.data(), nRead);
+      state->inference_status = FINISHED;
+      return nRead;
+    }
+
     task_result result = state->instance->llama.next_result(state->task_id);
     if (!result.error) {
       const std::string to_send = result.result_json["content"];
@@ -352,28 +362,22 @@ void llamaCPP::inferenceImpl(
       memcpy(pBuffer, str.data(), nRead);
 
       if (result.stop) {
-        const std::string str =
-            "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_",
-                               "", "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
-
-        LOG_VERBOSE("data stream", {{"to_send", str}});
-        std::size_t nRead = std::min(str.size(), nBuffSize);
-        memcpy(pBuffer, str.data(), nRead);
         LOG_INFO << "reached result stop";
-        state->inferenceStatus = FINISHED;
+        state->inference_status = EOS;
+        return nRead;
       }
 
       // Make sure nBufferSize is not zero
       // Otherwise it stop streaming
       if (!nRead) {
-        state->inferenceStatus = FINISHED;
+        state->inference_status = FINISHED;
       }
 
       return nRead;
+    } else {
+      LOG_INFO << "Error during inference";
     }
-    state->inferenceStatus = FINISHED;
+    state->inference_status = FINISHED;
     return 0;
   };
   // Queued task
@@ -391,16 +395,17 @@ void llamaCPP::inferenceImpl(
 
     // Since this is an async task, we will wait for the task to be
     // completed
-    while (state->inferenceStatus != FINISHED && retries < 10) {
+    while (state->inference_status != FINISHED && retries < 10) {
       // Should wait chunked_content_provider lambda to be called within
       // 3s
-      if (state->inferenceStatus == PENDING) {
+      if (state->inference_status == PENDING) {
         retries += 1;
       }
-      if (state->inferenceStatus != RUNNING)
+      if (state->inference_status != RUNNING)
         LOG_INFO << "Wait for task to be released:" << state->task_id;
       std::this_thread::sleep_for(std::chrono::milliseconds(100));
     }
+    LOG_INFO << "Task completed, release it";
     // Request completed, release it
     state->instance->llama.request_cancel(state->task_id);
   });
@@ -445,7 +450,6 @@ void llamaCPP::embedding(
 void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   // Queue embedding task
   auto state = create_inference_state(this);
 
@@ -532,7 +536,6 @@ void llamaCPP::modelStatus(
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   if (llama.model_loaded_external) {
     LOG_INFO << "model loaded";
     Json::Value jsonResp;
@@ -561,7 +564,6 @@ void llamaCPP::loadModel(
 }
 
 bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
-
   gpt_params params;
   // By default will setting based on number of handlers
   if (jsonBody) {
@@ -570,11 +572,9 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
       params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
     if (!jsonBody->operator[]("grp_attn_n").isNull()) {
-
       params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
     if (!jsonBody->operator[]("grp_attn_w").isNull()) {
-
       params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
     if (!jsonBody->operator[]("mlock").isNull()) {
@@ -620,12 +620,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
       std::string llama_log_folder =
           jsonBody->operator[]("llama_log_folder").asString();
       log_set_target(llama_log_folder + "llama.log");
-    } // Set folder for llama log
+    }  // Set folder for llama log
   }
 #ifdef GGML_USE_CUBLAS
   LOG_INFO << "Setting up GGML CUBLAS PARAMS";
   params.mul_mat_q = false;
-#endif // GGML_USE_CUBLAS
+#endif  // GGML_USE_CUBLAS
   if (params.model_alias == "unknown") {
     params.model_alias = params.model;
   }
@@ -644,7 +644,7 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   // load the model
   if (!llama.load_model(params)) {
     LOG_ERROR << "Error loading the model";
-    return false; // Indicate failure
+    return false;  // Indicate failure
   }
   llama.initialize();
 
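
Note on the streaming change above: the provider now finishes in two steps. While RUNNING it streams content chunks; when the model reports a stop it switches to EOS instead of FINISHED, so the closing "stop"/"data: [DONE]" frame is written on the next callback invocation rather than overwriting the chunk already copied into the buffer, and only after that frame is sent does the state become FINISHED. The sketch below is illustrative only and is not part of this patch: the chunk vector, the buffer, and the driver loop are invented stand-ins for llama.cpp results and the drogon stream callback, but the PENDING -> RUNNING -> EOS -> FINISHED cycle mirrors the diff.

// Standalone illustration only; not code from llamaCPP. Builds with any C++11 compiler.
#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

int main() {
  // Invented stand-in for the tokens that next_result() would produce.
  std::vector<std::string> chunks = {"Hello", ", ", "world"};
  std::size_t next = 0;
  InferenceStatus status = PENDING;
  char buffer[256];

  // Drive the "chunked content provider" the way the HTTP framework would:
  // call it repeatedly until it stops producing bytes.
  while (status != FINISHED) {
    std::size_t nRead = 0;
    if (status == PENDING) {
      status = RUNNING;
    }

    if (status == EOS) {
      // The closing SSE frame gets its own invocation, so it can no longer
      // overwrite the last content chunk that was copied into the buffer.
      const std::string str = "data: [DONE]\n\n";
      nRead = std::min(str.size(), sizeof(buffer));
      std::memcpy(buffer, str.data(), nRead);
      status = FINISHED;  // a further call would now return 0 and end the stream
    } else if (next < chunks.size()) {
      const std::string str = "data: " + chunks[next++] + "\n\n";
      nRead = std::min(str.size(), sizeof(buffer));
      std::memcpy(buffer, str.data(), nRead);
      if (next == chunks.size()) {
        status = EOS;  // the model reported stop; defer the final frame
      }
    }

    std::cout << std::string(buffer, nRead);
  }
  return 0;
}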