This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 7d0f380

feat: nitro multi modal
1 parent 1df7af9 commit 7d0f380

3 files changed: +97 −24 lines changed

controllers/llamaCPP.cc

Lines changed: 76 additions & 20 deletions
@@ -8,6 +8,7 @@
 #include <regex>
 #include <string>
 #include <thread>
+#include <trantor/utils/Logger.h>
 
 using namespace inferences;
 using json = nlohmann::json;
@@ -174,6 +175,7 @@ void llamaCPP::chatCompletion(
 
   json data;
   json stopWords;
+  int no_images = 0;
   // To set default value
 
   if (jsonBody) {
@@ -200,29 +202,78 @@ void llamaCPP::chatCompletion(
         (*jsonBody).get("frequency_penalty", 0).asFloat();
     data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat();
     const Json::Value &messages = (*jsonBody)["messages"];
-    for (const auto &message : messages) {
-      std::string input_role = message["role"].asString();
-      std::string role;
-      if (input_role == "user") {
-        role = user_prompt;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
-      } else if (input_role == "assistant") {
-        role = ai_prompt;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
-      } else if (input_role == "system") {
-        role = system_prompt;
-        std::string content = message["content"].asString();
-        formatted_output = role + content + formatted_output;
 
-      } else {
-        role = input_role;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
+    if (!multi_modal) {
+
+      for (const auto &message : messages) {
+        std::string input_role = message["role"].asString();
+        std::string role;
+        if (input_role == "user") {
+          role = user_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "assistant") {
+          role = ai_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "system") {
+          role = system_prompt;
+          std::string content = message["content"].asString();
+          formatted_output = role + content + formatted_output;
+
+        } else {
+          role = input_role;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        }
       }
+      formatted_output += ai_prompt;
+    } else {
+
+      data["image_data"] = json::array();
+      for (const auto &message : messages) {
+        std::string input_role = message["role"].asString();
+        std::string role;
+        if (input_role == "user") {
+          formatted_output += role;
+          for (auto content_piece : message["content"]) {
+            role = user_prompt;
+
+            auto content_piece_type = content_piece["type"].asString();
+            if (content_piece_type == "text") {
+              auto text = content_piece["text"].asString();
+              formatted_output += text;
+            } else if (content_piece_type == "image_url") {
+              auto image_url = content_piece["image_url"]["url"].asString();
+              auto base64_image_data = nitro_utils::extractBase64(image_url);
+              formatted_output += "[img-" + std::to_string(no_images) + "]";
+
+              json content_piece_image_data;
+              content_piece_image_data["data"] = base64_image_data;
+              content_piece_image_data["id"] = no_images;
+              data["image_data"].push_back(content_piece_image_data);
+              no_images++;
+            }
+          }
+
+        } else if (input_role == "assistant") {
+          role = ai_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "system") {
+          role = system_prompt;
+          std::string content = message["content"].asString();
+          formatted_output = role + content + formatted_output;
+
+        } else {
+          role = input_role;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        }
+      }
+      formatted_output += ai_prompt;
+      LOG_INFO << formatted_output;
     }
-    formatted_output += ai_prompt;
 
     data["prompt"] = formatted_output;
     for (const auto &stop_word : (*jsonBody)["stop"]) {
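As a reading aid, here is a minimal standalone sketch (not part of the commit) of what the new multi-modal branch produces for a single OpenAI-style user message. The prompt markers and the image payload are placeholder assumptions, the base64 extraction is a simplified stand-in for nitro_utils::extractBase64, and nlohmann::json is used throughout, whereas the controller itself reads the incoming request through jsoncpp's Json::Value:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  // One OpenAI-style user message with a text piece and an image piece
  // (placeholder base64 payload).
  json text_piece = {{"type", "text"}, {"text", "What is in this picture?"}};
  json image_piece = {
      {"type", "image_url"},
      {"image_url", {{"url", "data:image/jpeg;base64,iVBORw0KGgo"}}}};
  json message = {{"role", "user"},
                  {"content", json::array({text_piece, image_piece})}};

  json data;
  data["image_data"] = json::array();
  std::string formatted_output = "USER: "; // placeholder for user_prompt
  int no_images = 0;

  for (const auto &piece : message["content"]) {
    if (piece["type"] == "text") {
      formatted_output += piece["text"].get<std::string>();
    } else if (piece["type"] == "image_url") {
      // Simplified stand-in for nitro_utils::extractBase64: keep what
      // follows "base64," in the data URI.
      std::string url = piece["image_url"]["url"].get<std::string>();
      std::string b64 = url.substr(url.find("base64,") + 7);
      formatted_output += "[img-" + std::to_string(no_images) + "]";
      json img_entry;
      img_entry["data"] = b64;
      img_entry["id"] = no_images;
      data["image_data"].push_back(img_entry);
      no_images++;
    }
  }
  formatted_output += "\nASSISTANT: "; // placeholder for ai_prompt
  data["prompt"] = formatted_output;

  std::cout << data.dump(2) << std::endl;
  // "prompt" ends up as: "USER: What is in this picture?[img-0]\nASSISTANT: "
  // "image_data" ends up as: [{"data": "iVBORw0KGgo", "id": 0}]
}

The resulting data object carries both the text prompt with its [img-N] placeholders and the image_data array that llama_server_context later decodes.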
@@ -386,6 +437,11 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
   int drogon_thread = drogon::app().getThreadNum() - 1;
   LOG_INFO << "Drogon thread is:" << drogon_thread;
   if (jsonBody) {
+    if (!jsonBody["mmproj"].isNull()) {
+      LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
+      params.mmproj = jsonBody["mmproj"].asString();
+      multi_modal = true;
+    }
     params.model = jsonBody["llama_model_path"].asString();
     params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
     params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
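For reference, the new branch is triggered purely by the presence of a "mmproj" key in the load-model request body. A body that would enable multi-modal mode could look like the following (file paths are placeholders; only keys that appear in this diff are shown):

{
  "llama_model_path": "/models/llava-v1.5-7b.Q4_K_M.gguf",
  "mmproj": "/models/mmproj-model-f16.gguf",
  "ngl": 100,
  "ctx_len": 2048
}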

controllers/llamaCPP.h

Lines changed: 6 additions & 3 deletions
@@ -732,6 +732,7 @@ struct llama_server_context {
     if (images_data != data.end() && images_data->is_array()) {
       for (const auto &img : *images_data) {
         std::string data_b64 = img["data"].get<std::string>();
+        LOG_INFO << data_b64;
         slot_image img_sl;
         img_sl.id =
             img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
@@ -1834,7 +1835,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
 public:
   llamaCPP() {
     // Some default values for now below
-    log_disable(); // Disable the log to file feature, reduce bloat for
+    // log_disable(); // Disable the log to file feature, reduce bloat for
     // target
     // system ()
     std::vector<std::string> llama_models =
@@ -1877,8 +1878,9 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   METHOD_LIST_END
   void chatCompletion(const HttpRequestPtr &req,
                       std::function<void(const HttpResponsePtr &)> &&callback);
-  void chatCompletionPrelight(const HttpRequestPtr &req,
-                              std::function<void(const HttpResponsePtr &)> &&callback);
+  void chatCompletionPrelight(
+      const HttpRequestPtr &req,
+      std::function<void(const HttpResponsePtr &)> &&callback);
   void embedding(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
   void loadModel(const HttpRequestPtr &req,
@@ -1911,5 +1913,6 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   bool caching_enabled;
   std::atomic<int> no_of_chats = 0;
   int clean_cache_threshold;
+  bool multi_modal = false;
 };
 }; // namespace inferences

utils/nitro_utils.h

Lines changed: 15 additions & 1 deletion
@@ -6,10 +6,11 @@
 #include <drogon/HttpResponse.h>
 #include <iostream>
 #include <ostream>
+#include <regex>
 // Include platform-specific headers
 #ifdef _WIN32
-#include <winsock2.h>
 #include <windows.h>
+#include <winsock2.h>
 #else
 #include <dirent.h>
 #endif
@@ -18,6 +19,19 @@ namespace nitro_utils {
 
 inline std::string models_folder = "./models";
 
+inline std::string extractBase64(const std::string &input) {
+  std::regex pattern("base64,(.*)");
+  std::smatch match;
+
+  if (std::regex_search(input, match, pattern)) {
+    std::string base64_data = match[1];
+    base64_data = base64_data.substr(0, base64_data.length() - 1);
+    return base64_data;
+  }
+
+  return "";
+}
+
 inline std::vector<std::string> listFilesInDir(const std::string &path) {
   std::vector<std::string> files;
 
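A small standalone sketch of the new helper in action (the function body is copied verbatim from this commit; the data URI is a placeholder):

#include <iostream>
#include <regex>
#include <string>

// Copy of nitro_utils::extractBase64 as added in this commit: it returns
// whatever follows "base64," in the input, minus the match's final character.
inline std::string extractBase64(const std::string &input) {
  std::regex pattern("base64,(.*)");
  std::smatch match;

  if (std::regex_search(input, match, pattern)) {
    std::string base64_data = match[1];
    base64_data = base64_data.substr(0, base64_data.length() - 1);
    return base64_data;
  }

  return "";
}

int main() {
  // Placeholder data URI, as carried in an OpenAI-style "image_url" field.
  std::string url = "data:image/png;base64,iVBORw0KGgoAAAANSU==";
  // Prints: iVBORw0KGgoAAAANSU=
  // (the last character of the match, here one '=' padding char, is trimmed)
  std::cout << extractBase64(url) << std::endl;
}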
