
Commit c229a52

feat: make non-stream completion possible, fully compatible with the OpenAI API

1 parent: e0cef1e

File tree

1 file changed: +109 −43 lines

controllers/llamaCPP.cc

Lines changed: 109 additions & 43 deletions
@@ -5,6 +5,7 @@
 #include <cstring>
 #include <drogon/HttpResponse.h>
 #include <drogon/HttpTypes.h>
+#include <json/value.h>
 #include <regex>
 #include <string>
 #include <thread>
@@ -28,6 +29,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+std::string create_full_return_json(const std::string &id,
+                                    const std::string &model,
+                                    const std::string &content,
+                                    const std::string &system_fingerprint,
+                                    int prompt_tokens, int completion_tokens,
+                                    Json::Value finish_reason = Json::Value()) {
+
+  Json::Value root;
+
+  root["id"] = id;
+  root["model"] = model;
+  root["created"] = static_cast<int>(std::time(nullptr));
+  root["object"] = "chat.completion";
+  root["system_fingerprint"] = system_fingerprint;
+
+  Json::Value choicesArray(Json::arrayValue);
+  Json::Value choice;
+
+  choice["index"] = 0;
+  Json::Value message;
+  message["role"] = "assistant";
+  message["content"] = content;
+  choice["message"] = message;
+  choice["finish_reason"] = finish_reason;
+
+  choicesArray.append(choice);
+  root["choices"] = choicesArray;
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["completion_tokens"] = completion_tokens;
+  usage["total_tokens"] = prompt_tokens + completion_tokens;
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_return_json(const std::string &id, const std::string &model,
                               const std::string &content,
                               Json::Value finish_reason = Json::Value()) {
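
The helper above serializes the full, non-streamed OpenAI-style completion body. As a rough illustration (not part of the commit), the sketch below calls it with made-up sample values and shows the compact shape it returns; it assumes jsoncpp and the create_full_return_json definition above are available at link time.

#include <iostream>
#include <json/value.h>
#include <string>

// Declaration matching the helper added above (defined in llamaCPP.cc).
std::string create_full_return_json(const std::string &id,
                                    const std::string &model,
                                    const std::string &content,
                                    const std::string &system_fingerprint,
                                    int prompt_tokens, int completion_tokens,
                                    Json::Value finish_reason = Json::Value());

int main() {
  // Sample values only; at runtime the id is random and the token counts
  // come from the llama.cpp result_json.
  std::cout << create_full_return_json("chatcmpl-sample", "_", "Hello there!",
                                       "_", 9, 3, "stop")
            << "\n";
  // Roughly (one compact line, keys in jsoncpp's alphabetical order):
  // {"choices":[{"finish_reason":"stop","index":0,"message":
  //   {"content":"Hello there!","role":"assistant"}}],"created":<unix time>,
  //  "id":"chatcmpl-sample","model":"_","object":"chat.completion",
  //  "system_fingerprint":"_","usage":{"completion_tokens":3,
  //  "prompt_tokens":9,"total_tokens":12}}
}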
@@ -82,9 +122,9 @@ void llamaCPP::chatCompletion(
   json data;
   json stopWords;
   // To set default value
-  data["stream"] = true;
 
   if (jsonBody) {
+    data["stream"] = (*jsonBody).get("stream", false).asBool();
     data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
     data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
     data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +159,87 @@ void llamaCPP::chatCompletion(
     data["stop"] = stopWords;
   }
 
+  bool is_streamed = data["stream"];
+
   const int task_id = llama.request_completion(data, false, false);
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
-  auto state = createState(task_id, this);
+  if (is_streamed) {
+    auto state = createState(task_id, this);
 
-  auto chunked_content_provider =
-      [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (!pBuffer) {
-      LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->instance->llama.request_cancel(state->task_id);
-      return 0;
-    }
-    if (state->isStopped) {
-      return 0;
-    }
-
-    task_result result = state->instance->llama.next_result(state->task_id);
-    if (!result.error) {
-      const std::string to_send = result.result_json["content"];
-      const std::string str =
-          "data: " +
-          create_return_json(nitro_utils::generate_random_string(20), "_",
-                             to_send) +
-          "\n\n";
-
-      std::size_t nRead = std::min(str.size(), nBuffSize);
-      memcpy(pBuffer, str.data(), nRead);
+    auto chunked_content_provider =
+        [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!pBuffer) {
+        LOG_INFO << "Connection closed or buffer is null. Reset context";
+        state->instance->llama.request_cancel(state->task_id);
+        return 0;
+      }
+      if (state->isStopped) {
+        return 0;
+      }
 
-      if (result.stop) {
+      task_result result = state->instance->llama.next_result(state->task_id);
+      if (!result.error) {
+        const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_", "",
-                               "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
+            create_return_json(nitro_utils::generate_random_string(20), "_",
+                               to_send) +
+            "\n\n";
 
-        LOG_VERBOSE("data stream", {{"to_send", str}});
         std::size_t nRead = std::min(str.size(), nBuffSize);
         memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "reached result stop";
-        state->isStopped = true;
-        state->instance->llama.request_cancel(state->task_id);
+
+        if (result.stop) {
+          const std::string str =
+              "data: " +
+              create_return_json(nitro_utils::generate_random_string(20), "_",
+                                 "", "stop") +
+              "\n\n" + "data: [DONE]" + "\n\n";
+
+          LOG_VERBOSE("data stream", {{"to_send", str}});
+          std::size_t nRead = std::min(str.size(), nBuffSize);
+          memcpy(pBuffer, str.data(), nRead);
+          LOG_INFO << "reached result stop";
+          state->isStopped = true;
+          state->instance->llama.request_cancel(state->task_id);
+          return nRead;
+        }
         return nRead;
+      } else {
+        return 0;
       }
-      return nRead;
-    } else {
       return 0;
-    }
-    return 0;
-  };
-  auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                               "chat_completions.txt");
-  callback(resp);
+    };
+    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                 "chat_completions.txt");
+    callback(resp);
 
-  return;
+    return;
+  } else {
+    Json::Value respData;
+    auto resp = nitro_utils::nitroHttpResponse();
+    respData["testing"] = "thunghiem value moi";
+    if (!json_value(data, "stream", false)) {
+      std::string completion_text;
+      task_result result = llama.next_result(task_id);
+      if (!result.error && result.stop) {
+        int prompt_tokens = result.result_json["tokens_evaluated"];
+        int predicted_tokens = result.result_json["tokens_predicted"];
+        std::string full_return =
+            create_full_return_json(nitro_utils::generate_random_string(20),
+                                    "_", result.result_json["content"], "_",
+                                    prompt_tokens, predicted_tokens);
+        resp->setBody(full_return);
+      } else {
+        resp->setBody("internal error during inference");
+        return;
+      }
+      callback(resp);
+      return;
+    }
+  }
 }
-
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
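
With the branch above, "stream": false (or an absent flag) now blocks on llama.next_result, builds one chat.completion object via create_full_return_json, and derives the usage block from tokens_evaluated and tokens_predicted. For illustration only, a hypothetical request body such a non-streaming client might send, assembled with jsoncpp; the messages field is assumed from the OpenAI chat format, while stream, max_tokens, top_p, and temperature mirror the fields read in chatCompletion above.

#include <iostream>
#include <json/value.h>
#include <json/writer.h>

int main() {
  // Hypothetical non-streaming request body (endpoint path omitted).
  Json::Value req;
  req["stream"] = false;     // may also be omitted; false is now the default
  req["max_tokens"] = 32;    // forwarded as n_predict
  req["top_p"] = 0.95;
  req["temperature"] = 0.8;

  Json::Value msg;
  msg["role"] = "user";
  msg["content"] = "Hello";
  req["messages"].append(msg);  // assumed OpenAI-style chat messages

  Json::StreamWriterBuilder writer;
  writer["indentation"] = "";
  std::cout << Json::writeString(writer, req) << "\n";
  // The server replies with a single chat.completion JSON object
  // (see create_full_return_json) instead of an SSE stream.
}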
@@ -262,7 +327,8 @@ void llamaCPP::loadModel(
     this->pre_prompt =
         (*jsonBody)
             .get("pre_prompt",
-                 "A chat between a curious user and an artificial intelligence "
+                 "A chat between a curious user and an artificial "
+                 "intelligence "
                  "assistant. The assistant follows the given rules no matter "
                  "what.\\n")
            .asString();
