@@ -1,8 +1,8 @@
 #include "chat_completion_cmd.h"
+#include <curl/curl.h>
 #include "config/yaml_config.h"
 #include "cortex_upd_cmd.h"
 #include "database/models.h"
-#include "httplib.h"
 #include "model_status_cmd.h"
 #include "server_start_cmd.h"
 #include "utils/engine_constants.h"
@@ -16,29 +16,42 @@ constexpr const auto kMinDataChunkSize = 6u;
 constexpr const char* kUser = "user";
 constexpr const char* kAssistant = "assistant";
 
-}  // namespace
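+// Holds streaming state shared with the C-style libcurl write callback below.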
+struct StreamingCallback {
+  std::string* ai_chat;
+  bool is_done;
 
-struct ChunkParser {
-  std::string content;
-  bool is_done = false;
+  StreamingCallback() : ai_chat(nullptr), is_done(false) {}
+};
 
-  ChunkParser(const char* data, size_t data_length) {
-    if (data && data_length > kMinDataChunkSize) {
-      std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize);
-      if (s.find("[DONE]") != std::string::npos) {
-        is_done = true;
-      } else {
-        try {
-          content =
-              json_helper::ParseJsonString(s)["choices"][0]["delta"]["content"]
-                  .asString();
-        } catch (const std::exception& e) {
-          CTL_WRN("JSON parse error: " << e.what());
-        }
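+// Invoked by libcurl for each chunk of the streamed response body; returning
+// anything other than size * nmemb makes libcurl abort the transfer.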
+size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {
+  auto* callback = static_cast<StreamingCallback*>(userdata);
+  size_t data_length = size * nmemb;
+
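+  // Each streamed line carries a 6-byte "data: " prefix (kMinDataChunkSize);
+  // skip it before parsing the JSON payload.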
+  if (ptr && data_length > kMinDataChunkSize) {
+    std::string chunk(ptr + kMinDataChunkSize, data_length - kMinDataChunkSize);
+    if (chunk.find("[DONE]") != std::string::npos) {
+      callback->is_done = true;
+      std::cout << std::endl;
+      return data_length;
+    }
+
+    try {
+      std::string content =
+          json_helper::ParseJsonString(chunk)["choices"][0]["delta"]["content"]
+              .asString();
+      std::cout << content << std::flush;
+      if (callback->ai_chat) {
+        *callback->ai_chat += content;
       }
+    } catch (const std::exception& e) {
+      CTL_WRN("JSON parse error: " << e.what());
     }
   }
-};
+
+  return data_length;
+}
+
+}  // namespace
 
 void ChatCompletionCmd::Exec(const std::string& host, int port,
                              const std::string& model_handle, std::string msg) {
@@ -68,95 +81,101 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
                              const std::string& model_handle,
                              const config::ModelConfig& mc, std::string msg) {
   auto address = host + ":" + std::to_string(port);
+
   // Check if server is started
-  {
-    if (!commands::IsServerAlive(host, port)) {
-      CLI_LOG("Server is not started yet, please run `"
-              << commands::GetCortexBinary() << " start` to start server!");
-      return;
-    }
+  if (!commands::IsServerAlive(host, port)) {
+    CLI_LOG("Server is not started yet, please run `"
+            << commands::GetCortexBinary() << " start` to start server!");
+    return;
   }
 
   // Only check if llamacpp engine
   if ((mc.engine.find(kLlamaEngine) != std::string::npos ||
        mc.engine.find(kLlamaRepo) != std::string::npos) &&
-      !commands::ModelStatusCmd(model_service_)
-           .IsLoaded(host, port, model_handle)) {
+      !commands::ModelStatusCmd().IsLoaded(host, port, model_handle)) {
     CLI_LOG("Model is not loaded yet!");
     return;
   }
 
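+  // A single curl easy handle is initialized once and reused for every
+  // request in the interactive loop below.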
+  auto curl = curl_easy_init();
+  if (!curl) {
+    CLI_LOG("Failed to initialize CURL");
+    return;
+  }
+
+  std::string url = "http://" + address + "/v1/chat/completions";
+  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+  curl_easy_setopt(curl, CURLOPT_POST, 1L);
+
+  struct curl_slist* headers = nullptr;
+  headers = curl_slist_append(headers, "Content-Type: application/json");
+  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+
   // Interactive mode or not
   bool interactive = msg.empty();
 
-  // Some instruction for user here
   if (interactive) {
-    std::cout << "Inorder to exit, type `exit()`" << std::endl;
+    std::cout << "In order to exit, type `exit()`" << std::endl;
   }
-  // Model is loaded, start to chat
-  {
-    do {
-      std::string user_input = std::move(msg);
-      if (user_input.empty()) {
-        std::cout << "> ";
-        if (!std::getline(std::cin, user_input)) {
-          break;
-        }
-      }
 
-      string_utils::Trim(user_input);
-      if (user_input == kExitChat) {
+  do {
+    std::string user_input = std::move(msg);
+    if (user_input.empty()) {
+      std::cout << "> ";
+      if (!std::getline(std::cin, user_input)) {
         break;
       }
+    }
+
+    string_utils::Trim(user_input);
+    if (user_input == kExitChat) {
+      break;
+    }
+
+    if (!user_input.empty()) {
+      // Prepare JSON payload
+      Json::Value new_data;
+      new_data["role"] = kUser;
+      new_data["content"] = user_input;
+      histories_.push_back(std::move(new_data));
+
+      Json::Value json_data = mc.ToJson();
+      json_data["engine"] = mc.engine;
+
+      Json::Value msgs_array(Json::arrayValue);
+      for (const auto& m : histories_) {
+        msgs_array.append(m);
+      }
+
+      json_data["messages"] = msgs_array;
+      json_data["model"] = model_handle;
+      json_data["stream"] = true;
 
-      if (!user_input.empty()) {
-        httplib::Client cli(address);
-        Json::Value json_data = mc.ToJson();
-        Json::Value new_data;
-        new_data["role"] = kUser;
-        new_data["content"] = user_input;
-        histories_.push_back(std::move(new_data));
-        json_data["engine"] = mc.engine;
-        Json::Value msgs_array(Json::arrayValue);
-        for (const auto& m : histories_) {
-          msgs_array.append(m);
-        }
-        json_data["messages"] = msgs_array;
-        json_data["model"] = model_handle;
-        // TODO: support non-stream
-        json_data["stream"] = true;
-        auto data_str = json_data.toStyledString();
-        // std::cout << data_str << std::endl;
-        cli.set_read_timeout(std::chrono::seconds(60));
-        // std::cout << "> ";
-        httplib::Request req;
-        req.headers = httplib::Headers();
-        req.set_header("Content-Type", "application/json");
-        req.method = "POST";
-        req.path = "/v1/chat/completions";
-        req.body = data_str;
-        std::string ai_chat;
-        req.content_receiver = [&](const char* data, size_t data_length,
-                                   uint64_t offset, uint64_t total_length) {
-          ChunkParser cp(data, data_length);
-          if (cp.is_done) {
-            std::cout << std::endl;
-            return false;
-          }
-          std::cout << cp.content << std::flush;
-          ai_chat += cp.content;
-          return true;
-        };
-        cli.send(req);
+      std::string json_payload = json_data.toStyledString();
 
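+      // CURLOPT_POSTFIELDS does not copy the buffer, so json_payload must
+      // stay alive until curl_easy_perform() returns.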
+      curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
+
+      std::string ai_chat;
+      StreamingCallback callback;
+      callback.ai_chat = &ai_chat;
+
+      curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
+      curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback);
+
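+      // Blocks until the stream ends; WriteCallback prints each token as it
+      // arrives and accumulates the full reply into ai_chat.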
+      CURLcode res = curl_easy_perform(curl);
+
+      if (res != CURLE_OK) {
+        CLI_LOG("CURL request failed: " << curl_easy_strerror(res));
+      } else {
         Json::Value ai_res;
         ai_res["role"] = kAssistant;
         ai_res["content"] = ai_chat;
         histories_.push_back(std::move(ai_res));
       }
-      // std::cout << "ok Done" << std::endl;
-    } while (interactive);
-  }
-}
+    }
+  } while (interactive);
 
-};  // namespace commands
+  curl_slist_free_all(headers);
+  curl_easy_cleanup(curl);
+}
+}  // namespace commands