Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 5562284

Browse files
authored
Merge pull request #1522 from janhq/j/update-download-api
fix: update download api
2 parents fcee871 + 5050bc3 commit 5562284

File tree

17 files changed

+347
-153
lines changed

17 files changed

+347
-153
lines changed

docs/static/openapi/jan.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,15 @@
16511651
"value": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf"
16521652
}
16531653
]
1654+
},
1655+
"id": {
1656+
"type": "string",
1657+
"description": "The id which will be used to register the model.",
1658+
"examples": [
1659+
{
1660+
"value": "my-custom-model-id"
1661+
}
1662+
]
16541663
}
16551664
}
16561665
},

engine/cli/commands/engine_get_cmd.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#include "engine_get_cmd.h"
2+
#include <json/reader.h>
3+
#include <json/value.h>
24
#include <iostream>
35

46
#include "httplib.h"
5-
#include "json/json.h"
67
#include "server_start_cmd.h"
78
#include "utils/logging_utils.h"
89

@@ -29,7 +30,6 @@ void EngineGetCmd::Exec(const std::string& host, int port,
2930
auto res = cli.Get("/v1/engines/" + engine_name);
3031
if (res) {
3132
if (res->status == httplib::StatusCode::OK_200) {
32-
// CLI_LOG(res->body);
3333
Json::Value v;
3434
Json::Reader reader;
3535
reader.parse(res->body, v);
@@ -39,7 +39,8 @@ void EngineGetCmd::Exec(const std::string& host, int port,
3939
v["status"].asString()});
4040

4141
} else {
42-
CLI_LOG_ERROR("Failed to get engine list with status code: " << res->status);
42+
CLI_LOG_ERROR(
43+
"Failed to get engine list with status code: " << res->status);
4344
return;
4445
}
4546
} else {

engine/cli/commands/model_list_cmd.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ using namespace tabulate;
1717
using Row_t =
1818
std::vector<variant<std::string, const char*, string_view, Table>>;
1919

20-
void ModelListCmd::Exec(const std::string& host, int port, std::string filter,
21-
bool display_engine, bool display_version) {
20+
void ModelListCmd::Exec(const std::string& host, int port,
21+
const std::string& filter, bool display_engine,
22+
bool display_version) {
2223
// Start server if server is not started yet
2324
if (!commands::IsServerAlive(host, port)) {
2425
CLI_LOG("Starting server ...");

engine/cli/commands/model_list_cmd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace commands {
66

77
class ModelListCmd {
88
public:
9-
void Exec(const std::string& host, int port, std::string filter,
9+
void Exec(const std::string& host, int port, const std::string& filter,
1010
bool display_engine = false, bool display_version = false);
1111
};
1212
} // namespace commands

engine/common/download_task.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#include <string>
88

99
enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex };
10+
1011
using namespace nlohmann;
1112

1213
struct DownloadItem {
14+
1315
std::string id;
1416

1517
std::string downloadUrl;
@@ -54,8 +56,12 @@ inline std::string DownloadTypeToString(DownloadType type) {
5456
}
5557

5658
struct DownloadTask {
59+
enum class Status { Pending, InProgress, Completed, Cancelled, Error };
60+
5761
std::string id;
5862

63+
Status status;
64+
5965
DownloadType type;
6066

6167
std::vector<DownloadItem> items;
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include <condition_variable>
2+
#include <deque>
3+
#include <mutex>
4+
#include <optional>
5+
#include <shared_mutex>
6+
#include <string>
7+
#include <unordered_map>
8+
#include "common/download_task.h"
9+
10+
class DownloadTaskQueue {
11+
private:
12+
std::deque<DownloadTask> taskQueue;
13+
std::unordered_map<std::string, typename std::deque<DownloadTask>::iterator>
14+
taskMap;
15+
mutable std::shared_mutex mutex;
16+
std::condition_variable_any cv;
17+
18+
public:
19+
void push(DownloadTask task) {
20+
std::unique_lock lock(mutex);
21+
taskQueue.push_back(std::move(task));
22+
taskMap[taskQueue.back().id] = std::prev(taskQueue.end());
23+
cv.notify_one();
24+
}
25+
26+
std::optional<DownloadTask> pop() {
27+
std::unique_lock lock(mutex);
28+
if (taskQueue.empty()) {
29+
return std::nullopt;
30+
}
31+
DownloadTask task = std::move(taskQueue.front());
32+
taskQueue.pop_front();
33+
taskMap.erase(task.id);
34+
return task;
35+
}
36+
37+
bool cancelTask(const std::string& taskId) {
38+
std::unique_lock lock(mutex);
39+
auto it = taskMap.find(taskId);
40+
if (it != taskMap.end()) {
41+
it->second->status = DownloadTask::Status::Cancelled;
42+
taskQueue.erase(it->second);
43+
taskMap.erase(it);
44+
return true;
45+
}
46+
return false;
47+
}
48+
49+
bool updateTaskStatus(const std::string& taskId,
50+
DownloadTask::Status newStatus) {
51+
std::unique_lock lock(mutex);
52+
auto it = taskMap.find(taskId);
53+
if (it != taskMap.end()) {
54+
it->second->status = newStatus;
55+
if (newStatus == DownloadTask::Status::Cancelled ||
56+
newStatus == DownloadTask::Status::Error) {
57+
taskQueue.erase(it->second);
58+
taskMap.erase(it);
59+
}
60+
return true;
61+
}
62+
return false;
63+
}
64+
65+
std::optional<DownloadTask> getNextPendingTask() {
66+
std::shared_lock lock(mutex);
67+
auto it = std::find_if(
68+
taskQueue.begin(), taskQueue.end(), [](const DownloadTask& task) {
69+
return task.status == DownloadTask::Status::Pending;
70+
});
71+
72+
if (it != taskQueue.end()) {
73+
return *it;
74+
}
75+
return std::nullopt;
76+
}
77+
};

engine/controllers/models.cc

Lines changed: 38 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "database/models.h"
22
#include <drogon/HttpTypes.h>
3+
#include <optional>
34
#include "config/gguf_parser.h"
45
#include "config/yaml_config.h"
56
#include "models.h"
@@ -26,15 +27,22 @@ void Models::PullModel(const HttpRequestPtr& req,
2627
return;
2728
}
2829

30+
std::optional<std::string> desired_model_id = std::nullopt;
31+
auto id = (*(req->getJsonObject())).get("id", "").asString();
32+
if (!id.empty()) {
33+
desired_model_id = id;
34+
}
35+
2936
auto handle_model_input =
3037
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
3138
CTL_INF("Handle model input, model handle: " + model_handle);
3239
if (string_utils::StartsWith(model_handle, "https")) {
33-
return model_service_->HandleDownloadUrlAsync(model_handle);
40+
return model_service_->HandleDownloadUrlAsync(model_handle,
41+
desired_model_id);
3442
} else if (model_handle.find(":") != std::string::npos) {
3543
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
3644
return model_service_->DownloadModelFromCortexsoAsync(
37-
model_and_branch[0], model_and_branch[1]);
45+
model_and_branch[0], model_and_branch[1], desired_model_id);
3846
}
3947

4048
return cpp::fail("Invalid model handle or not supported!");
@@ -107,7 +115,6 @@ void Models::ListModel(
107115
auto list_entry = modellist_handler.LoadModelList();
108116
if (list_entry) {
109117
for (const auto& model_entry : list_entry.value()) {
110-
// auto model_entry = modellist_handler.GetModelInfo(model_handle);
111118
try {
112119
yaml_handler.ModelConfigFromFile(
113120
fmu::ToAbsoluteCortexDataPath(
@@ -116,7 +123,6 @@ void Models::ListModel(
116123
auto model_config = yaml_handler.GetModelConfig();
117124
Json::Value obj = model_config.ToJson();
118125
obj["id"] = model_entry.model;
119-
obj["model_alias"] = model_entry.model_alias;
120126
obj["model"] = model_entry.model;
121127
data.append(std::move(obj));
122128
yaml_handler.Reset();
@@ -156,7 +162,6 @@ void Models::GetModel(const HttpRequestPtr& req,
156162
config::YamlHandler yaml_handler;
157163
auto model_entry = modellist_handler.GetModelInfo(model_id);
158164
if (model_entry.has_error()) {
159-
// CLI_LOG("Error: " + model_entry.error());
160165
ret["id"] = model_id;
161166
ret["object"] = "model";
162167
ret["result"] = "Fail to get model information";
@@ -333,71 +338,6 @@ void Models::ImportModel(
333338
}
334339
}
335340

336-
void Models::SetModelAlias(
337-
const HttpRequestPtr& req,
338-
std::function<void(const HttpResponsePtr&)>&& callback) const {
339-
if (!http_util::HasFieldInReq(req, callback, "model") ||
340-
!http_util::HasFieldInReq(req, callback, "modelAlias")) {
341-
return;
342-
}
343-
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
344-
auto model_alias = (*(req->getJsonObject())).get("modelAlias", "").asString();
345-
LOG_DEBUG << "GetModel, Model handle: " << model_handle
346-
<< ", Model alias: " << model_alias;
347-
348-
cortex::db::Models modellist_handler;
349-
try {
350-
auto result = modellist_handler.UpdateModelAlias(model_handle, model_alias);
351-
if (result.has_error()) {
352-
std::string message = result.error();
353-
LOG_ERROR << message;
354-
Json::Value ret;
355-
ret["result"] = "Set alias failed!";
356-
ret["modelHandle"] = model_handle;
357-
ret["message"] = message;
358-
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
359-
resp->setStatusCode(k400BadRequest);
360-
callback(resp);
361-
} else {
362-
if (result.value()) {
363-
std::string message = "Successfully set model alias '" + model_alias +
364-
"' for modeID '" + model_handle + "'.";
365-
LOG_INFO << message;
366-
Json::Value ret;
367-
ret["result"] = "OK";
368-
ret["modelHandle"] = model_handle;
369-
ret["message"] = message;
370-
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
371-
resp->setStatusCode(k200OK);
372-
callback(resp);
373-
} else {
374-
std::string message = "Unable to set model alias for modelID '" +
375-
model_handle + "': model alias '" + model_alias +
376-
"' is not unique!";
377-
LOG_ERROR << message;
378-
Json::Value ret;
379-
ret["result"] = "Set alias failed!";
380-
ret["modelHandle"] = model_handle;
381-
ret["message"] = message;
382-
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
383-
resp->setStatusCode(k400BadRequest);
384-
callback(resp);
385-
}
386-
}
387-
} catch (const std::exception& e) {
388-
std::string message = "Error when setting model alias ('" + model_alias +
389-
"') for modelID '" + model_handle + "':" + e.what();
390-
LOG_ERROR << message;
391-
Json::Value ret;
392-
ret["result"] = "Set alias failed!";
393-
ret["modelHandle"] = model_handle;
394-
ret["message"] = message;
395-
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
396-
resp->setStatusCode(k400BadRequest);
397-
callback(resp);
398-
}
399-
}
400-
401341
void Models::StartModel(
402342
const HttpRequestPtr& req,
403343
std::function<void(const HttpResponsePtr&)>&& callback) {
@@ -407,6 +347,34 @@ void Models::StartModel(
407347
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
408348
auto custom_prompt_template =
409349
(*(req->getJsonObject())).get("prompt_template", "").asString();
350+
auto model_entry = model_service_->GetDownloadedModel(model_handle);
351+
if (!model_entry.has_value()) {
352+
Json::Value ret;
353+
ret["message"] = "Cannot find model: " + model_handle;
354+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
355+
resp->setStatusCode(drogon::k400BadRequest);
356+
callback(resp);
357+
return;
358+
}
359+
auto engine_name = model_entry.value().engine;
360+
auto engine_entry = engine_service_->GetEngineInfo(engine_name);
361+
if (engine_entry.has_error()) {
362+
Json::Value ret;
363+
ret["message"] = "Cannot find engine: " + engine_name;
364+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
365+
resp->setStatusCode(drogon::k400BadRequest);
366+
callback(resp);
367+
return;
368+
}
369+
if (engine_entry->status != "Ready") {
370+
Json::Value ret;
371+
ret["message"] = "Engine is not ready! Please install first!";
372+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
373+
resp->setStatusCode(drogon::k400BadRequest);
374+
callback(resp);
375+
return;
376+
}
377+
410378
auto result = model_service_->StartModel(
411379
config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
412380
custom_prompt_template);

engine/controllers/models.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <drogon/HttpController.h>
44
#include <trantor/utils/Logger.h>
5+
#include "services/engine_service.h"
56
#include "services/model_service.h"
67

78
using namespace drogon;
@@ -16,7 +17,6 @@ class Models : public drogon::HttpController<Models, false> {
1617
METHOD_ADD(Models::UpdateModel, "/{1}", Patch);
1718
METHOD_ADD(Models::ImportModel, "/import", Post);
1819
METHOD_ADD(Models::DeleteModel, "/{1}", Delete);
19-
METHOD_ADD(Models::SetModelAlias, "/alias", Post);
2020
METHOD_ADD(Models::StartModel, "/start", Post);
2121
METHOD_ADD(Models::StopModel, "/stop", Post);
2222
METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get);
@@ -28,14 +28,14 @@ class Models : public drogon::HttpController<Models, false> {
2828
ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch);
2929
ADD_METHOD_TO(Models::ImportModel, "/v1/models/import", Post);
3030
ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete);
31-
ADD_METHOD_TO(Models::SetModelAlias, "/v1/models/alias", Post);
3231
ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Post);
3332
ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Post);
3433
ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get);
3534
METHOD_LIST_END
3635

37-
explicit Models(std::shared_ptr<ModelService> model_service)
38-
: model_service_{model_service} {}
36+
explicit Models(std::shared_ptr<ModelService> model_service,
37+
std::shared_ptr<EngineService> engine_service)
38+
: model_service_{model_service}, engine_service_{engine_service} {}
3939

4040
void PullModel(const HttpRequestPtr& req,
4141
std::function<void(const HttpResponsePtr&)>&& callback);
@@ -71,4 +71,5 @@ class Models : public drogon::HttpController<Models, false> {
7171

7272
private:
7373
std::shared_ptr<ModelService> model_service_;
74+
std::shared_ptr<EngineService> engine_service_;
7475
};

engine/database/models.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#pragma once
22

3+
#include <SQLiteCpp/Database.h>
34
#include <trantor/utils/Logger.h>
45
#include <string>
56
#include <vector>
6-
#include "SQLiteCpp/SQLiteCpp.h"
77
#include "utils/result.hpp"
88

99
namespace cortex::db {

engine/e2e-test/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from test_api_model_start import TestApiModelStart
1212
from test_api_model_stop import TestApiModelStop
1313
from test_api_model_get import TestApiModelGet
14-
from test_api_model_alias import TestApiModelAlias
1514
from test_api_model_list import TestApiModelList
1615
from test_api_model_update import TestApiModelUpdate
1716
from test_api_model_delete import TestApiModelDelete

0 commit comments

Comments
 (0)