
Commit a043b9f

chore: refactor inference service (#1536)
* fix: error for models start
* fix: add unit tests
* chore: refactor inference service
* fix: build

Co-authored-by: vansangpfiev <sang@jan.ai>
1 parent c434ac2 commit a043b9f
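In outline, the refactor turns InferenceService into a shared dependency: main.cc builds one std::make_shared<services::InferenceService>() and hands it to both ModelService and the inferences::server controller, so starting and stopping models no longer round-trips over HTTP to the local server. A minimal wiring sketch, assuming the headers and constructors introduced in this commit (the helper name WireInferenceService and the elided DownloadService setup are illustrative only; see engine/main.cc below for the real version):

#include <memory>

#include <drogon/HttpAppFramework.h>

#include "controllers/server.h"
#include "services/inference_service.h"
#include "services/model_service.h"

// Hypothetical helper sketching the new wiring.
void WireInferenceService(std::shared_ptr<DownloadService> download_service) {
  // One InferenceService instance is shared by both consumers.
  auto inference_svc = std::make_shared<services::InferenceService>();

  // ModelService can now call the engine in-process through the service...
  auto model_service =
      std::make_shared<ModelService>(download_service, inference_svc);

  // ...and the Drogon controller receives the same instance, so HTTP handlers
  // and CLI-driven model starts operate on one shared engine state.
  auto server_ctl = std::make_shared<inferences::server>(inference_svc);
  drogon::app().registerController(server_ctl);
}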

File tree

10 files changed: +77, -65 lines


engine/cli/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ add_executable(${TARGET_NAME} main.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc
 )
 
 target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)

engine/cli/commands/model_start_cmd.cc

Lines changed: 8 additions & 4 deletions
@@ -11,7 +11,8 @@
 
 namespace commands {
 bool ModelStartCmd::Exec(const std::string& host, int port,
-                         const std::string& model_handle) {
+                         const std::string& model_handle,
+                         bool print_success_log) {
   std::optional<std::string> model_id =
       SelectLocalModel(model_service_, model_handle);
 
@@ -37,9 +38,12 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
                       data_str.size(), "application/json");
   if (res) {
     if (res->status == httplib::StatusCode::OK_200) {
-      CLI_LOG(model_id.value() << " model started successfully. Use `"
-              << commands::GetCortexBinary() << " run "
-              << *model_id << "` for interactive chat shell");
+      if (print_success_log) {
+        CLI_LOG(model_id.value()
+                << " model started successfully. Use `"
+                << commands::GetCortexBinary() << " run " << *model_id
+                << "` for interactive chat shell");
+      }
       return true;
     } else {
       auto root = json_helper::ParseJsonString(res->body);

engine/cli/commands/model_start_cmd.h

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ class ModelStartCmd {
   explicit ModelStartCmd(const ModelService& model_service)
       : model_service_{model_service} {};
 
-  bool Exec(const std::string& host, int port, const std::string& model_handle);
+  bool Exec(const std::string& host, int port, const std::string& model_handle,
+            bool print_success_log = true);
 
  private:
   ModelService model_service_;

engine/cli/commands/run_cmd.cc

Lines changed: 6 additions & 7 deletions
@@ -3,6 +3,7 @@
 #include "config/yaml_config.h"
 #include "cortex_upd_cmd.h"
 #include "database/models.h"
+#include "model_start_cmd.h"
 #include "model_status_cmd.h"
 #include "server_start_cmd.h"
 #include "utils/cli_selection_utils.h"
@@ -82,7 +83,7 @@ void RunCmd::Exec(bool run_detach) {
   if (!model_id.has_value()) {
     return;
   }
-
+
   cortex::db::Models modellist_handler;
   config::YamlHandler yaml_handler;
   auto address = host_ + ":" + std::to_string(port_);
@@ -139,12 +140,10 @@ void RunCmd::Exec(bool run_detach) {
         !commands::ModelStatusCmd(model_service_)
              .IsLoaded(host_, port_, *model_id)) {
 
-      auto result = model_service_.StartModel(host_, port_, *model_id);
-      if (result.has_error()) {
-        CLI_LOG("Error: " + result.error());
-        return;
-      }
-      if (!result.value()) {
+      auto res =
+          commands::ModelStartCmd(model_service_)
+              .Exec(host_, port_, *model_id, false /*print_success_log*/);
+      if (!res) {
         CLI_LOG("Error: Failed to start model");
         return;
       }

engine/controllers/server.cc

Lines changed: 11 additions & 10 deletions
@@ -10,7 +10,8 @@ using namespace inferences;
 using json = nlohmann::json;
 namespace inferences {
 
-server::server() {
+server::server(std::shared_ptr<services::InferenceService> inference_service)
+    : inference_svc_(inference_service) {
 #if defined(_WIN32)
   SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
 #endif
@@ -25,7 +26,7 @@ void server::ChatCompletion(
   auto json_body = req->getJsonObject();
   bool is_stream = (*json_body).get("stream", false).asBool();
   auto q = std::make_shared<services::SyncQueue>();
-  auto ir = inference_svc_.HandleChatCompletion(q, json_body);
+  auto ir = inference_svc_->HandleChatCompletion(q, json_body);
   if (ir.has_error()) {
     auto err = ir.error();
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -47,7 +48,7 @@ void server::Embedding(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
   LOG_TRACE << "Start embedding";
   auto q = std::make_shared<services::SyncQueue>();
-  auto ir = inference_svc_.HandleEmbedding(q, req->getJsonObject());
+  auto ir = inference_svc_->HandleEmbedding(q, req->getJsonObject());
   if (ir.has_error()) {
     auto err = ir.error();
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -64,7 +65,7 @@ void server::Embedding(const HttpRequestPtr& req,
 void server::UnloadModel(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.UnloadModel(req->getJsonObject());
+  auto ir = inference_svc_->UnloadModel(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -74,7 +75,7 @@ void server::UnloadModel(
 void server::ModelStatus(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.GetModelStatus(req->getJsonObject());
+  auto ir = inference_svc_->GetModelStatus(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -84,7 +85,7 @@ void server::ModelStatus(
 void server::GetModels(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
   LOG_TRACE << "Start to get models";
-  auto ir = inference_svc_.GetModels(req->getJsonObject());
+  auto ir = inference_svc_->GetModels(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -95,15 +96,15 @@ void server::GetModels(const HttpRequestPtr& req,
 void server::GetEngines(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.GetEngines(req->getJsonObject());
+  auto ir = inference_svc_->GetEngines(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(ir);
   callback(resp);
 }
 
 void server::FineTuning(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.FineTuning(req->getJsonObject());
+  auto ir = inference_svc_->FineTuning(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -113,7 +114,7 @@ void server::FineTuning(
 
 void server::LoadModel(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.LoadModel(req->getJsonObject());
+  auto ir = inference_svc_->LoadModel(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -124,7 +125,7 @@ void server::LoadModel(const HttpRequestPtr& req,
 void server::UnloadEngine(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.UnloadEngine(req->getJsonObject());
+  auto ir = inference_svc_->UnloadEngine(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));

engine/controllers/server.h

Lines changed: 6 additions & 4 deletions
@@ -1,10 +1,12 @@
+
+#pragma once
 #include <nlohmann/json.hpp>
 #include <string>
+#include <memory>
 
 #if defined(_WIN32)
 #define NOMINMAX
 #endif
-#pragma once
 
 #include <drogon/HttpController.h>
 
@@ -31,12 +33,12 @@ using namespace drogon;
 
 namespace inferences {
 
-class server : public drogon::HttpController<server>,
+class server : public drogon::HttpController<server, false>,
                public BaseModel,
                public BaseChatCompletion,
                public BaseEmbedding {
  public:
-  server();
+  server(std::shared_ptr<services::InferenceService> inference_service);
   ~server();
   METHOD_LIST_BEGIN
   // list path definitions here;
@@ -100,6 +102,6 @@ class server : public drogon::HttpController<server,
                       services::SyncQueue& q);
 
  private:
-  services::InferenceService inference_svc_;
+  std::shared_ptr<services::InferenceService> inference_svc_;
 };
 };  // namespace inferences
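The second template argument in HttpController<server, false> is Drogon's AutoCreation flag: with it set to false, the framework does not instantiate and register the controller on its own, which is what lets main.cc construct the server with its InferenceService dependency and register it explicitly. A sketch of the pattern under that assumption (handlers and the BaseModel/BaseChatCompletion/BaseEmbedding bases shown above are abbreviated here):

// Sketch: controller with auto-creation disabled, built by hand with its
// dependency and registered explicitly (see engine/main.cc below).
class server : public drogon::HttpController<server, /*AutoCreation=*/false> {
 public:
  explicit server(std::shared_ptr<services::InferenceService> inference_service)
      : inference_svc_(inference_service) {}
  // METHOD_LIST_BEGIN / METHOD_LIST_END and handlers omitted in this sketch.

 private:
  std::shared_ptr<services::InferenceService> inference_svc_;
};

// Registration is manual because auto-creation is off:
//   drogon::app().registerController(std::make_shared<server>(inference_svc));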

engine/main.cc

Lines changed: 6 additions & 1 deletion
@@ -5,6 +5,7 @@
 #include "controllers/events.h"
 #include "controllers/models.h"
 #include "controllers/process_manager.h"
+#include "controllers/server.h"
 #include "cortex-common/cortexpythoni.h"
 #include "services/model_service.h"
 #include "utils/archive_utils.h"
@@ -88,21 +89,25 @@ void RunServer(std::optional<int> port) {
 
   auto event_queue_ptr = std::make_shared<EventQueue>();
   cortex::event::EventProcessor event_processor(event_queue_ptr);
+  auto inference_svc = std::make_shared<services::InferenceService>();
 
   auto download_service = std::make_shared<DownloadService>(event_queue_ptr);
   auto engine_service = std::make_shared<EngineService>(download_service);
-  auto model_service = std::make_shared<ModelService>(download_service);
+  auto model_service =
+      std::make_shared<ModelService>(download_service, inference_svc);
 
   // initialize custom controllers
   auto engine_ctl = std::make_shared<Engines>(engine_service);
   auto model_ctl = std::make_shared<Models>(model_service, engine_service);
   auto event_ctl = std::make_shared<Events>(event_queue_ptr);
   auto pm_ctl = std::make_shared<ProcessManager>();
+  auto server_ctl = std::make_shared<inferences::server>(inference_svc);
 
   drogon::app().registerController(engine_ctl);
   drogon::app().registerController(model_ctl);
   drogon::app().registerController(event_ctl);
   drogon::app().registerController(pm_ctl);
+  drogon::app().registerController(server_ctl);
 
   LOG_INFO << "Server started, listening at: " << config.apiServerHost << ":"
            << config.apiServerPort;

engine/services/inference_service.h

Lines changed: 4 additions & 1 deletion
@@ -1,8 +1,11 @@
 #pragma once
 
+#include <condition_variable>
+#include <mutex>
 #include <optional>
+#include <queue>
+#include <unordered_map>
 #include <variant>
-#include "common/base.h"
 #include "cortex-common/EngineI.h"
 #include "cortex-common/cortexpythoni.h"
 #include "utils/dylib.h"

engine/services/model_service.cc

Lines changed: 24 additions & 37 deletions
@@ -11,6 +11,7 @@
 #include "utils/engine_constants.h"
 #include "utils/file_manager_utils.h"
 #include "utils/huggingface_utils.h"
+#include "utils/json_helper.h"
 #include "utils/logging_utils.h"
 #include "utils/result.hpp"
 #include "utils/string_utils.h"
@@ -611,28 +612,21 @@ cpp::result<bool, std::string> ModelService::StartModel(
       json_data["ai_prompt"] = mc.ai_template;
     }
 
-    auto data_str = json_data.toStyledString();
-    CTL_INF(data_str);
-    cli.set_read_timeout(std::chrono::seconds(60));
-    auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
-                        data_str.data(), data_str.size(), "application/json");
-    if (res) {
-      if (res->status == httplib::StatusCode::OK_200) {
-        return true;
-      } else if (res->status == httplib::StatusCode::Conflict_409) {
-        CTL_INF("Model '" + model_handle + "' is already loaded");
-        return true;
-      } else {
-        auto root = json_helper::ParseJsonString(res->body);
-        CTL_ERR("Model failed to load with status code: " << res->status);
-        return cpp::fail("Model failed to start: " + root["message"].asString());
-      }
+    CTL_INF(json_data.toStyledString());
+    assert(!!inference_svc_);
+    auto ir =
+        inference_svc_->LoadModel(std::make_shared<Json::Value>(json_data));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    auto data = std::get<1>(ir);
+    if (status == httplib::StatusCode::OK_200) {
+      return true;
+    } else if (status == httplib::StatusCode::Conflict_409) {
+      CTL_INF("Model '" + model_handle + "' is already loaded");
+      return true;
     } else {
-      auto err = res.error();
-      CTL_ERR("HTTP error: " << httplib::to_string(err));
-      return cpp::fail("HTTP error: " + httplib::to_string(err));
+      CTL_ERR("Model failed to start with status code: " << status);
+      return cpp::fail("Model failed to start: " + data["message"].asString());
     }
-
   } catch (const std::exception& e) {
     return cpp::fail("Fail to load model with ID '" + model_handle +
                      "': " + e.what());
@@ -663,25 +657,18 @@ cpp::result<bool, std::string> ModelService::StopModel(
     Json::Value json_data;
     json_data["model"] = model_handle;
     json_data["engine"] = mc.engine;
-    auto data_str = json_data.toStyledString();
-    CTL_INF(data_str);
-    cli.set_read_timeout(std::chrono::seconds(60));
-    auto res = cli.Post("/inferences/server/unloadmodel", httplib::Headers(),
-                        data_str.data(), data_str.size(), "application/json");
-    if (res) {
-      if (res->status == httplib::StatusCode::OK_200) {
-        return true;
-      } else {
-        CTL_ERR("Model failed to unload with status code: " << res->status);
-        return cpp::fail("Model failed to unload with status code: " +
-                         std::to_string(res->status));
-      }
+    CTL_INF(json_data.toStyledString());
+    assert(inference_svc_);
+    auto ir =
+        inference_svc_->UnloadModel(std::make_shared<Json::Value>(json_data));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    auto data = std::get<1>(ir);
+    if (status == httplib::StatusCode::OK_200) {
+      return true;
    } else {
-      auto err = res.error();
-      CTL_ERR("HTTP error: " << httplib::to_string(err));
-      return cpp::fail("HTTP error: " + httplib::to_string(err));
+      CTL_ERR("Model failed to stop with status code: " << status);
+      return cpp::fail("Model failed to stop: " + data["message"].asString());
     }
-
   } catch (const std::exception& e) {
     return cpp::fail("Fail to unload model with ID '" + model_handle +
                      "': " + e.what());

engine/services/model_service.h

Lines changed: 9 additions & 0 deletions
@@ -5,13 +5,21 @@
 #include <string>
 #include "config/model_config.h"
 #include "services/download_service.h"
+#include "services/inference_service.h"
+
 class ModelService {
  public:
   constexpr auto static kHuggingFaceHost = "huggingface.co";
 
   explicit ModelService(std::shared_ptr<DownloadService> download_service)
       : download_service_{download_service} {};
 
+  explicit ModelService(
+      std::shared_ptr<DownloadService> download_service,
+      std::shared_ptr<services::InferenceService> inference_service)
+      : download_service_{download_service},
+        inference_svc_(inference_service) {};
+
   /**
    * Return model id if download successfully
    */
@@ -67,4 +75,5 @@ class ModelService {
                          const std::string& modelName);
 
   std::shared_ptr<DownloadService> download_service_;
+  std::shared_ptr<services::InferenceService> inference_svc_;
 };
