This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 4cc64f3

Merge pull request #457 from janhq/refactor/parameters-mapping-entity
refactor: parameter mapping - model entity
2 parents 71824b9 + 73b3a75 commit 4cc64f3

File tree

5 files changed (+173, -145 lines)


CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -78,8 +78,9 @@ endif()
 aux_source_directory(controllers CTL_SRC)
 aux_source_directory(common COMMON_SRC)
 aux_source_directory(context CONTEXT_SRC)
+aux_source_directory(models MODEL_SRC)
 # aux_source_directory(filters FILTER_SRC) aux_source_directory(plugins
-# PLUGIN_SRC) aux_source_directory(models MODEL_SRC)
+# PLUGIN_SRC)
 
 # drogon_create_views(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/views
 # ${CMAKE_CURRENT_BINARY_DIR}) use the following line to create views with
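Note that aux_source_directory(models MODEL_SRC) only collects the new models/ sources into the MODEL_SRC variable; the executable target still has to reference that variable for the files to be compiled. The target definition sits outside this hunk, so the snippet below is only a minimal sketch of the usual wiring, assuming the project already passes the other *_SRC variables to add_executable (the entry-point file name here is a stand-in, not taken from the diff).

# Sketch only - the real add_executable call is not shown in this diff.
# MODEL_SRC must appear in the target's source list, for example:
add_executable(${PROJECT_NAME} main.cc
               ${CTL_SRC} ${COMMON_SRC} ${CONTEXT_SRC} ${MODEL_SRC})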

common/base.h

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include <drogon/HttpController.h>
+#include <models/chat_completion_request.h>
 
 using namespace drogon;
 
@@ -8,9 +9,8 @@ class BaseModel {
   virtual ~BaseModel() {}
 
   // Model management
-  virtual void LoadModel(
-      const HttpRequestPtr& req,
-      std::function<void(const HttpResponsePtr&)>&& callback) = 0;
+  virtual void LoadModel(const HttpRequestPtr& req,
+                         std::function<void(const HttpResponsePtr&)>&& callback) = 0;
   virtual void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) = 0;
@@ -25,7 +25,7 @@ class BaseChatCompletion {
 
   // General chat method
   virtual void ChatCompletion(
-      const HttpRequestPtr& req,
+      inferences::ChatCompletionRequest &&completion,
       std::function<void(const HttpResponsePtr&)>&& callback) = 0;
 };
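The inferences::ChatCompletionRequest type referenced above is defined in the new models/chat_completion_request.h, which is the fifth changed file in this commit but is not rendered in this view. Judging from the member accesses in the llamaCPP.cc diff below, the entity presumably looks roughly like the sketch here; the member names follow those accesses and the defaults mirror the values the old code passed to jsonBody->get(...), but the actual header may differ.

// Hypothetical reconstruction of models/chat_completion_request.h
// (the real file is not shown in this diff view).
#pragma once
#include <json/value.h>

namespace inferences {
struct ChatCompletionRequest {
  bool stream = false;          // previously jsonBody->get("stream", false)
  int max_tokens = 500;         // previously get("max_tokens", 500)
  float top_p = 0.95f;          // previously get("top_p", 0.95)
  float temperature = 0.8f;     // previously get("temperature", 0.8)
  float frequency_penalty = 0;  // previously get("frequency_penalty", 0)
  float presence_penalty = 0;   // previously get("presence_penalty", 0)
  Json::Value messages;         // OpenAI-style "messages" array
  Json::Value stop;             // optional "stop" word array
};
}  // namespace inferences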

controllers/llamaCPP.cc

Lines changed: 119 additions & 122 deletions
@@ -1,10 +1,8 @@
 #include "llamaCPP.h"
 
-
-#include <iostream>
 #include <fstream>
+#include <iostream>
 #include "log.h"
-#include "utils/nitro_utils.h"
 
 // External
 #include "common.h"
@@ -175,19 +173,18 @@ void llamaCPP::WarmupModel() {
 }
 
 void llamaCPP::ChatCompletion(
-    const HttpRequestPtr& req,
+    inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  const auto& jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(jsonBody, callback);
+    InferenceImpl(std::move(completion), callback);
   }
 }
 
 void llamaCPP::InferenceImpl(
-    std::shared_ptr<Json::Value> jsonBody,
+    inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>& callback) {
   std::string formatted_output = pre_prompt;
 
@@ -196,131 +193,131 @@ void llamaCPP::InferenceImpl(
   int no_images = 0;
   // To set default value
 
-  if (jsonBody) {
-    // Increase number of chats received and clean the prompt
-    no_of_chats++;
-    if (no_of_chats % clean_cache_threshold == 0) {
-      LOG_INFO << "Clean cache threshold reached!";
-      llama.kv_cache_clear();
-      LOG_INFO << "Cache cleaned";
-    }
-
-    // Default values to enable auto caching
-    data["cache_prompt"] = caching_enabled;
-    data["n_keep"] = -1;
-
-    // Passing load value
-    data["repeat_last_n"] = this->repeat_last_n;
-
-    data["stream"] = (*jsonBody).get("stream", false).asBool();
-    data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
-    data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
-    data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
-    data["frequency_penalty"] =
-        (*jsonBody).get("frequency_penalty", 0).asFloat();
-    data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat();
-    const Json::Value& messages = (*jsonBody)["messages"];
+  // Increase number of chats received and clean the prompt
+  no_of_chats++;
+  if (no_of_chats % clean_cache_threshold == 0) {
+    LOG_INFO << "Clean cache threshold reached!";
+    llama.kv_cache_clear();
+    LOG_INFO << "Cache cleaned";
+  }
 
-    if (!grammar_file_content.empty()) {
-      data["grammar"] = grammar_file_content;
-    };
+  // Default values to enable auto caching
+  data["cache_prompt"] = caching_enabled;
+  data["n_keep"] = -1;
+
+  // Passing load value
+  data["repeat_last_n"] = this->repeat_last_n;
+
+  LOG_INFO << "Messages:" << completion.messages.toStyledString();
+  LOG_INFO << "Stop:" << completion.stop.toStyledString();
+
+  data["stream"] = completion.stream;
+  data["n_predict"] = completion.max_tokens;
+  data["top_p"] = completion.top_p;
+  data["temperature"] = completion.temperature;
+  data["frequency_penalty"] = completion.frequency_penalty;
+  data["presence_penalty"] = completion.presence_penalty;
+  const Json::Value& messages = completion.messages;
+
+  if (!grammar_file_content.empty()) {
+    data["grammar"] = grammar_file_content;
+  };
+
+  if (!llama.multimodal) {
+    for (const auto& message : messages) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        role = user_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
 
-    if (!llama.multimodal) {
-      for (const auto& message : messages) {
-        std::string input_role = message["role"].asString();
-        std::string role;
-        if (input_role == "user") {
-          role = user_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "assistant") {
-          role = ai_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "system") {
-          role = system_prompt;
-          std::string content = message["content"].asString();
-          formatted_output = role + content + formatted_output;
-
-        } else {
-          role = input_role;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        }
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
       }
-      formatted_output += ai_prompt;
-    } else {
-      data["image_data"] = json::array();
-      for (const auto& message : messages) {
-        std::string input_role = message["role"].asString();
-        std::string role;
-        if (input_role == "user") {
-          formatted_output += role;
-          for (auto content_piece : message["content"]) {
-            role = user_prompt;
-
-            json content_piece_image_data;
-            content_piece_image_data["data"] = "";
-
-            auto content_piece_type = content_piece["type"].asString();
-            if (content_piece_type == "text") {
-              auto text = content_piece["text"].asString();
-              formatted_output += text;
-            } else if (content_piece_type == "image_url") {
-              auto image_url = content_piece["image_url"]["url"].asString();
-              std::string base64_image_data;
-              if (image_url.find("http") != std::string::npos) {
-                LOG_INFO << "Remote image detected but not supported yet";
-              } else if (image_url.find("data:image") != std::string::npos) {
-                LOG_INFO << "Base64 image detected";
-                base64_image_data = nitro_utils::extractBase64(image_url);
-                LOG_INFO << base64_image_data;
-              } else {
-                LOG_INFO << "Local image detected";
-                nitro_utils::processLocalImage(
-                    image_url, [&](const std::string& base64Image) {
-                      base64_image_data = base64Image;
-                    });
-                LOG_INFO << base64_image_data;
-              }
-              content_piece_image_data["data"] = base64_image_data;
-
-              formatted_output += "[img-" + std::to_string(no_images) + "]";
-              content_piece_image_data["id"] = no_images;
-              data["image_data"].push_back(content_piece_image_data);
-              no_images++;
+      }
+      formatted_output += ai_prompt;
+    } else {
+      data["image_data"] = json::array();
+      for (const auto& message : messages) {
+        std::string input_role = message["role"].asString();
+        std::string role;
+        if (input_role == "user") {
+          formatted_output += role;
+          for (auto content_piece : message["content"]) {
+            role = user_prompt;
+
+            json content_piece_image_data;
+            content_piece_image_data["data"] = "";
+
+            auto content_piece_type = content_piece["type"].asString();
+            if (content_piece_type == "text") {
+              auto text = content_piece["text"].asString();
+              formatted_output += text;
+            } else if (content_piece_type == "image_url") {
+              auto image_url = content_piece["image_url"]["url"].asString();
+              std::string base64_image_data;
+              if (image_url.find("http") != std::string::npos) {
+                LOG_INFO << "Remote image detected but not supported yet";
+              } else if (image_url.find("data:image") != std::string::npos) {
+                LOG_INFO << "Base64 image detected";
+                base64_image_data = nitro_utils::extractBase64(image_url);
+                LOG_INFO << base64_image_data;
+              } else {
+                LOG_INFO << "Local image detected";
+                nitro_utils::processLocalImage(
+                    image_url, [&](const std::string& base64Image) {
+                      base64_image_data = base64Image;
+                    });
+                LOG_INFO << base64_image_data;
             }
-          }
+            content_piece_image_data["data"] = base64_image_data;
 
-        } else if (input_role == "assistant") {
-          role = ai_prompt;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
-        } else if (input_role == "system") {
-          role = system_prompt;
-          std::string content = message["content"].asString();
-          formatted_output = role + content + formatted_output;
-
-        } else {
-          role = input_role;
-          std::string content = message["content"].asString();
-          formatted_output += role + content;
+            formatted_output += "[img-" + std::to_string(no_images) + "]";
+            content_piece_image_data["id"] = no_images;
+            data["image_data"].push_back(content_piece_image_data);
+            no_images++;
+          }
         }
+
+      } else if (input_role == "assistant") {
+        role = ai_prompt;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
+      } else if (input_role == "system") {
+        role = system_prompt;
+        std::string content = message["content"].asString();
+        formatted_output = role + content + formatted_output;
+
+      } else {
+        role = input_role;
+        std::string content = message["content"].asString();
+        formatted_output += role + content;
       }
-      formatted_output += ai_prompt;
-      LOG_INFO << formatted_output;
     }
+    formatted_output += ai_prompt;
+    LOG_INFO << formatted_output;
+  }
 
-    data["prompt"] = formatted_output;
-    for (const auto& stop_word : (*jsonBody)["stop"]) {
-      stopWords.push_back(stop_word.asString());
-    }
-    // specify default stop words
-    // Ensure success case for chatML
-    stopWords.push_back("<|im_end|>");
-    stopWords.push_back(nitro_utils::rtrim(user_prompt));
-    data["stop"] = stopWords;
+  data["prompt"] = formatted_output;
+  for (const auto& stop_word : completion.stop) {
+    stopWords.push_back(stop_word.asString());
   }
+  // specify default stop words
+  // Ensure success case for chatML
+  stopWords.push_back("<|im_end|>");
+  stopWords.push_back(nitro_utils::rtrim(user_prompt));
+  data["stop"] = stopWords;
 
   bool is_streamed = data["stream"];
   // Enable full message debugging
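With this change, ChatCompletion no longer receives an HttpRequestPtr at all, so the JSON body has to be converted into the entity before the handler is invoked. Drogon supports this by letting a handler parameter be any type for which a fromRequest specialization exists. The specialization below is only a sketch of how the new models file plausibly performs that mapping while preserving the defaults the old inline code used (max_tokens 500, temperature 0.8, top_p 0.95, zero penalties); the shipped implementation lives in models/chat_completion_request.h and may differ.

// Sketch only: a drogon::fromRequest specialization that would populate the
// entity from the JSON body, keeping the defaults formerly applied inline in
// InferenceImpl. The real implementation is not shown in this view.
namespace drogon {
template <>
inline inferences::ChatCompletionRequest fromRequest(const HttpRequest& req) {
  auto jsonBody = req.getJsonObject();
  inferences::ChatCompletionRequest completion;
  if (jsonBody) {
    completion.stream = (*jsonBody).get("stream", false).asBool();
    completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt();
    completion.top_p = (*jsonBody).get("top_p", 0.95).asFloat();
    completion.temperature = (*jsonBody).get("temperature", 0.8).asFloat();
    completion.frequency_penalty =
        (*jsonBody).get("frequency_penalty", 0).asFloat();
    completion.presence_penalty =
        (*jsonBody).get("presence_penalty", 0).asFloat();
    completion.messages = (*jsonBody)["messages"];
    completion.stop = (*jsonBody)["stop"];
  }
  return completion;
}
}  // namespace drogon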

controllers/llamaCPP.h

Lines changed: 12 additions & 18 deletions
@@ -2,42 +2,37 @@
 #if defined(_WIN32)
 #define NOMINMAX
 #endif
-
 #pragma once
-#define LOG_TARGET stdout
 
 #include <drogon/HttpController.h>
 
-#include "stb_image.h"
-#include "context/llama_server_context.h"
-
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
 #endif
 
 #include <trantor/utils/ConcurrentTaskQueue.h>
-#include "common/base.h"
-#include "utils/json.hpp"
-
-// auto generated files (update with ./deps.sh)
-
 #include <cstddef>
+#include <string>
 #include <thread>
 
-#include <cstddef>
-#include <thread>
+#include "common/base.h"
+#include "context/llama_server_context.h"
+#include "stb_image.h"
+#include "utils/json.hpp"
+
+#include "models/chat_completion_request.h"
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
 #endif
 
-
 using json = nlohmann::json;
 
 using namespace drogon;
 
 namespace inferences {
+
 class llamaCPP : public drogon::HttpController<llamaCPP>,
                  public BaseModel,
                  public BaseChatCompletion,
@@ -64,14 +59,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void ChatCompletion(
-      const HttpRequestPtr& req,
+      inferences::ChatCompletionRequest &&completion,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
   void Embedding(
       const HttpRequestPtr& req,
      std::function<void(const HttpResponsePtr&)>&& callback) override;
-  void LoadModel(
-      const HttpRequestPtr& req,
-      std::function<void(const HttpResponsePtr&)>&& callback) override;
+  void LoadModel(const HttpRequestPtr& req,
+                 std::function<void(const HttpResponsePtr&)>&& callback) override;
   void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -101,7 +95,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   trantor::ConcurrentTaskQueue* queue;
 
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
-  void InferenceImpl(std::shared_ptr<Json::Value> jsonBody,
+  void InferenceImpl(inferences::ChatCompletionRequest&& completion,
                      std::function<void(const HttpResponsePtr&)>& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr&)>& callback);
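Two things follow from the header changes: the entity is accepted by rvalue reference and forwarded with std::move into InferenceImpl, so the parsed messages and stop values are handed along without copying the request object again, and the controller no longer needs req->getJsonObject() or the manual default handling on its inference path; that responsibility moves into the model entity introduced by this commit.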

0 commit comments
