This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 9de123f

Merge pull request #431 from janhq/fix/lambda-race-condition-infinite-waiting
fix: race condition issue
2 parents c30e271 + c995887

File tree

2 files changed: +73 -41 lines


controllers/llamaCPP.cc

Lines changed: 72 additions & 40 deletions
@@ -1,11 +1,29 @@
 #include "llamaCPP.h"
+
+#include <trantor/utils/SerialTaskQueue.h>
+
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
 
 using namespace inferences;
 using json = nlohmann::json;
 
+/**
+ * Queue to handle the inference task, this is to ensure that the inference
+ * task is handled in a sequential manner
+ */
+static trantor::SerialTaskQueue queue("worker");
+
+/**
+ * The state of the inference task
+ */
+enum InferenceStatus {
+  PENDING,
+  RUNNING,
+  FINISHED
+};
+
 /**
  * There is a need to save state of current ongoing inference status of a
  * handler, this struct is to solve that issue
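Note on the new queue: trantor::SerialTaskQueue (from the trantor library that drogon is built on) runs queued tasks one at a time, in submission order, on its own thread, which is the property this commit relies on to serialize inference work. A minimal, self-contained sketch of that behaviour (illustrative only, not part of this commit):

#include <trantor/utils/SerialTaskQueue.h>

#include <future>
#include <iostream>

int main() {
  // Tasks pushed into a SerialTaskQueue run strictly one after another,
  // in FIFO order, so two queued jobs can never interleave.
  trantor::SerialTaskQueue queue("worker");

  for (int i = 0; i < 3; ++i) {
    queue.runTaskInQueue([i] { std::cout << "task " << i << " finished\n"; });
  }

  // Wait for everything above by queueing one last task that signals a promise.
  std::promise<void> done;
  queue.runTaskInQueue([&done] { done.set_value(); });
  done.get_future().wait();
  return 0;
}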
@@ -15,8 +33,8 @@ using json = nlohmann::json;
  */
 struct inferenceState {
   bool is_stopped = false;
-  bool is_streaming = false;
   int task_id;
+  InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -35,7 +53,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
  * Check if model already loaded if not return message to user
  * @param callback the function to return message to user
  */
-void llamaCPP::checkModelLoaded(
+bool llamaCPP::checkModelLoaded(
     std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
@@ -44,8 +62,9 @@ void llamaCPP::checkModelLoaded(
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     resp->setStatusCode(drogon::k409Conflict);
     callback(resp);
-    return;
+    return false;
   }
+  return true;
 }
 
 Json::Value create_embedding_payload(const std::vector<float> &embedding,
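Since checkModelLoaded now reports whether the model is loaded (and has already sent the 409 response when it is not), each caller can guard its body with the return value instead of falling through. A self-contained sketch of that guard pattern, with illustrative names rather than the project's real types:

#include <functional>
#include <iostream>
#include <string>

using Callback = std::function<void(const std::string &)>;

static bool model_loaded = false;

// Mirrors the new contract: reply with an error and return false when the
// model is missing, otherwise return true so the caller may proceed.
bool checkModelLoadedSketch(const Callback &callback) {
  if (!model_loaded) {
    callback("409 Conflict: model has not been loaded yet");
    return false;
  }
  return true;
}

void inferenceSketch(const Callback &callback) {
  if (checkModelLoadedSketch(callback)) {
    callback("200 OK: running inference");
  }
}

int main() {
  Callback print = [](const std::string &msg) { std::cout << msg << "\n"; };
  inferenceSketch(print);  // prints the 409 message
  model_loaded = true;
  inferenceSketch(print);  // prints the 200 message
  return 0;
}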
@@ -70,7 +89,6 @@ std::string create_full_return_json(const std::string &id,
                                     const std::string &system_fingerprint,
                                     int prompt_tokens, int completion_tokens,
                                     Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
 
   root["id"] = id;
@@ -163,9 +181,11 @@ void llamaCPP::inference(
 
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
-  checkModelLoaded(callback);
-
-  inferenceImpl(jsonBody, callback);
+  if(checkModelLoaded(callback)) {
+    // Model is loaded
+    // Do Inference
+    inferenceImpl(jsonBody, callback);
+  }
 }
 
 void llamaCPP::inferenceImpl(
@@ -318,28 +338,24 @@ void llamaCPP::inferenceImpl(
     auto state = create_inference_state(this);
     auto chunked_content_provider =
         [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-      if (!state->is_streaming) {
-        state->task_id =
-            state->instance->llama.request_completion(data, false, false, -1);
-        state->instance->single_queue_is_busy = true;
+
+      if(state->inferenceStatus == PENDING) {
+        state->inferenceStatus = RUNNING;
       }
+
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
       if (state->is_stopped) {
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
-        // Update streaming state to being streamed
-        state->is_streaming = true;
         const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
@@ -363,35 +379,48 @@
           LOG_INFO << "reached result stop";
           state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
-          state->is_streaming = false;
           state->instance->single_queue_is_busy = false;
-
-          return nRead;
         }
-        return nRead;
-      } else {
-        if (state->instance->llama.params.n_parallel == 1) {
-          while (state->instance->single_queue_is_busy) {
-            LOG_INFO << "Waiting for task to be released status:"
-                     << state->instance->single_queue_is_busy;
-            std::this_thread::sleep_for(std::chrono::milliseconds(
-                500)); // Waiting in 500 miliseconds step
-          }
+
+        // Make sure nBufferSize is not zero
+        // Otherwise it stop streaming
+        if(!nRead) {
+          state->instance->single_queue_is_busy = false;
         }
-        std::string str = "\n\n";
-        std::size_t nRead = str.size();
-        memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "Failing retrying now";
+
         return nRead;
       }
-      state->is_streaming = false;
       state->instance->single_queue_is_busy = false;
       return 0;
     };
-    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                                 "chat_completions.txt");
-    callback(resp);
+
+    // Run task in serial queue
+    queue.runTaskInQueue([callback, state, data,
+                          chunked_content_provider]() {
+      state->task_id =
+          state->instance->llama.request_completion(data, false, false, -1);
+
+      state->instance->single_queue_is_busy = true;
+
+      // Start streaming response
+      auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                   "chat_completions.txt");
+      callback(resp);
 
+      int retries = 0;
+
+      // Since this is an async task, we will wait for the task to be completed
+      while (state->instance->single_queue_is_busy && retries < 10) {
+        // Should wait chunked_content_provider lambda to be called within 3s
+        if(state->inferenceStatus == PENDING) {
+          retries += 1;
+        }
+        LOG_INFO << "Wait for task to be released:" << state->task_id;
+        std::this_thread::sleep_for(std::chrono::milliseconds(300));
+      }
+
+      state->inferenceStatus = FINISHED;
+    });
     return;
   } else {
     Json::Value respData;
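The queued task above now waits for single_queue_is_busy to clear, but only keeps counting retries (10 x 300 ms) while the stream provider has never started (status still PENDING), rather than looping on a 500 ms sleep with no bound; this is the infinite-waiting race named in the branch fix/lambda-race-condition-infinite-waiting. A minimal, self-contained sketch of that bounded-wait idea, using std::atomic and std::thread in place of the controller's members (illustrative only, not the project's code):

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

enum InferenceStatus { PENDING, RUNNING, FINISHED };

int main() {
  std::atomic<InferenceStatus> status{PENDING};
  std::atomic<bool> busy{true};

  // Stands in for chunked_content_provider: mark RUNNING on the first call,
  // then clear the busy flag once streaming is done.
  std::thread provider([&] {
    if (status.load() == PENDING) {
      status.store(RUNNING);
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    busy.store(false);
  });

  // Stands in for the queued task: wait for the busy flag, but only count
  // retries while the provider has never started, so a provider that never
  // runs cannot make us wait forever.
  int retries = 0;
  while (busy.load() && retries < 10) {
    if (status.load() == PENDING) {
      ++retries;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }
  status.store(FINISHED);

  provider.join();
  std::cout << "final status: " << status.load() << "\n";
  return 0;
}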
@@ -423,11 +452,14 @@
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  checkModelLoaded(callback);
-  const auto &jsonBody = req->getJsonObject();
-
-  embeddingImpl(jsonBody, callback);
-  return;
+  // Check if model is loaded
+  if(checkModelLoaded(callback)) {
+    // Model is loaded
+    const auto &jsonBody = req->getJsonObject();
+    // Run embedding
+    embeddingImpl(jsonBody, callback);
+    return;
+  }
 }
 
 void llamaCPP::embeddingImpl(

controllers/llamaCPP.h

Lines changed: 1 addition & 1 deletion
@@ -2571,7 +2571,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
                      std::function<void(const HttpResponsePtr &)> &callback);
   void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr &)> &callback);
-  void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
+  bool checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
   void warmupModel();
   void backgroundTask();
   void stopBackgroundTask();
