@@ -1,11 +1,29 @@
 #include "llamaCPP.h"
+
+#include <trantor/utils/SerialTaskQueue.h>
+
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
 
 using namespace inferences;
 using json = nlohmann::json;
 
+/**
+ * Queue to handle the inference task, this is to ensure that the inference
+ * task is handled in a sequential manner
+ */
+static trantor::SerialTaskQueue queue("worker");
+
+/**
+ * The state of the inference task
+ */
+enum InferenceStatus {
+  PENDING,
+  RUNNING,
+  FINISHED
+};
+
 /**
  * There is a need to save state of current ongoing inference status of a
  * handler, this struct is to solve that issue
@@ -15,8 +33,8 @@ using json = nlohmann::json;
  */
 struct inferenceState {
   bool is_stopped = false;
-  bool is_streaming = false;
   int task_id;
+  InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -35,7 +53,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
  * Check if model already loaded if not return message to user
  * @param callback the function to return message to user
  */
-void llamaCPP::checkModelLoaded(
+bool llamaCPP::checkModelLoaded(
     std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
@@ -44,8 +62,9 @@ void llamaCPP::checkModelLoaded(
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     resp->setStatusCode(drogon::k409Conflict);
     callback(resp);
-    return;
+    return false;
   }
+  return true;
 }
 
 Json::Value create_embedding_payload(const std::vector<float> &embedding,
@@ -70,7 +89,6 @@ std::string create_full_return_json(const std::string &id,
                                     const std::string &system_fingerprint,
                                     int prompt_tokens, int completion_tokens,
                                     Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
 
   root["id"] = id;
@@ -163,9 +181,11 @@ void llamaCPP::inference(
 
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
-  checkModelLoaded(callback);
-
-  inferenceImpl(jsonBody, callback);
+  if (checkModelLoaded(callback)) {
+    // Model is loaded
+    // Do Inference
+    inferenceImpl(jsonBody, callback);
+  }
 }
 
 void llamaCPP::inferenceImpl(
@@ -318,28 +338,24 @@ void llamaCPP::inferenceImpl(
     auto state = create_inference_state(this);
     auto chunked_content_provider =
         [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-      if (!state->is_streaming) {
-        state->task_id =
-            state->instance->llama.request_completion(data, false, false, -1);
-        state->instance->single_queue_is_busy = true;
+
+      if (state->inferenceStatus == PENDING) {
+        state->inferenceStatus = RUNNING;
       }
+
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
       if (state->is_stopped) {
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
-        // Update streaming state to being streamed
-        state->is_streaming = true;
         const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
@@ -363,35 +379,48 @@ void llamaCPP::inferenceImpl(
           LOG_INFO << "reached result stop";
           state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
-          state->is_streaming = false;
           state->instance->single_queue_is_busy = false;
-
-          return nRead;
         }
-        return nRead;
-      } else {
-        if (state->instance->llama.params.n_parallel == 1) {
-          while (state->instance->single_queue_is_busy) {
-            LOG_INFO << "Waiting for task to be released status:"
-                     << state->instance->single_queue_is_busy;
-            std::this_thread::sleep_for(std::chrono::milliseconds(
-                500)); // Waiting in 500 miliseconds step
-          }
+
+        // Make sure nBufferSize is not zero
+        // Otherwise it stop streaming
+        if (!nRead) {
+          state->instance->single_queue_is_busy = false;
         }
-        std::string str = "\n\n";
-        std::size_t nRead = str.size();
-        memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "Failing retrying now";
+
         return nRead;
       }
-      state->is_streaming = false;
       state->instance->single_queue_is_busy = false;
       return 0;
     };
-    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                                 "chat_completions.txt");
-    callback(resp);
+
+    // Run task in serial queue
+    queue.runTaskInQueue([callback, state, data,
+                          chunked_content_provider]() {
+      state->task_id =
+          state->instance->llama.request_completion(data, false, false, -1);
+
+      state->instance->single_queue_is_busy = true;
+
+      // Start streaming response
+      auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                   "chat_completions.txt");
+      callback(resp);
 
+      int retries = 0;
+
+      // Since this is an async task, we will wait for the task to be completed
+      while (state->instance->single_queue_is_busy && retries < 10) {
+        // Should wait chunked_content_provider lambda to be called within 3s
+        if (state->inferenceStatus == PENDING) {
+          retries += 1;
+        }
+        LOG_INFO << "Wait for task to be released:" << state->task_id;
+        std::this_thread::sleep_for(std::chrono::milliseconds(300));
+      }
+
+      state->inferenceStatus = FINISHED;
+    });
     return;
   } else {
     Json::Value respData;
@@ -423,11 +452,14 @@ void llamaCPP::inferenceImpl(
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  checkModelLoaded(callback);
-  const auto &jsonBody = req->getJsonObject();
-
-  embeddingImpl(jsonBody, callback);
-  return;
+  // Check if model is loaded
+  if (checkModelLoaded(callback)) {
+    // Model is loaded
+    const auto &jsonBody = req->getJsonObject();
+    // Run embedding
+    embeddingImpl(jsonBody, callback);
+    return;
+  }
 }
 
 void llamaCPP::embeddingImpl(
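
Note on the pattern introduced above (this note is not part of the diff itself): completion requests are now funneled through a single `trantor::SerialTaskQueue`, and the queued task holds the queue by polling `single_queue_is_busy` until the streaming response finishes, so only one llama.cpp completion runs at a time. A minimal, self-contained sketch of that serialization idea follows; the `busy` flag and `submit_request` names are illustrative stand-ins, not identifiers from this codebase.

```cpp
// Sketch only: serializes work through trantor::SerialTaskQueue the same way
// the commit does, with a detached thread standing in for the streamed
// chunked_content_provider. `busy` and `submit_request` are made-up names.
#include <trantor/utils/SerialTaskQueue.h>

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

static trantor::SerialTaskQueue queue("worker");
static std::atomic<bool> busy{false};

void submit_request(int id) {
  queue.runTaskInQueue([id]() {
    busy = true;
    // Stand-in for request_completion + the async streaming response.
    std::thread([id]() {
      std::this_thread::sleep_for(std::chrono::milliseconds(200));
      std::cout << "request " << id << " finished\n";
      busy = false;
    }).detach();

    // Hold the serial queue until the async work releases the flag, so the
    // next queued request cannot start while this one is still streaming.
    while (busy) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
  });
}

int main() {
  submit_request(1);
  submit_request(2); // runs only after request 1 clears the busy flag
  std::this_thread::sleep_for(std::chrono::seconds(2));
  return 0;
}
```

The commit additionally caps the wait with a retry counter (`retries < 10` at 300 ms steps, roughly 3 s) so a provider that never gets invoked does not hold the queue forever; the sketch omits that guard for brevity.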