Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit f9b7567

Browse files
committed
update
1 parent b982159 commit f9b7567

File tree

12 files changed

+120
-123
lines changed

12 files changed

+120
-123
lines changed

engine/common/download_task.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
#include <filesystem>
55
#include <sstream>
66
#include <string>
7-
#include <unordered_map>
87

98
enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex };
109

1110
struct DownloadItem {
1211

1312
std::string id;
1413

15-
std::optional<std::unordered_map<std::string, std::string>> headers;
16-
1714
std::string downloadUrl;
1815

1916
/**

engine/services/download_service.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ostream>
99
#include <utility>
1010
#include "download_service.h"
11+
#include "utils/curl_utils.h"
1112
#include "utils/format_utils.h"
1213
#include "utils/logging_utils.h"
1314
#include "utils/result.hpp"
@@ -53,10 +54,7 @@ cpp::result<bool, std::string> DownloadService::AddDownloadTask(
5354
}
5455

5556
cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
56-
const std::string& url,
57-
const std::optional<
58-
std::reference_wrapper<std::unordered_map<std::string, std::string>>>&
59-
headers) const noexcept {
57+
const std::string& url) const noexcept {
6058

6159
auto curl = curl_easy_init();
6260
if (!curl) {
@@ -67,10 +65,11 @@ cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
6765
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
6866
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
6967

68+
auto headers = curl_utils::GetHeaders(url);
7069
if (headers.has_value()) {
7170
curl_slist* curl_headers = nullptr;
7271

73-
for (const auto& [key, value] : headers->get()) {
72+
for (const auto& [key, value] : headers.value()) {
7473
auto header = key + ": " + value;
7574
curl_headers = curl_slist_append(curl_headers, header.c_str());
7675
}
@@ -151,10 +150,11 @@ cpp::result<bool, std::string> DownloadService::Download(
151150
}
152151

153152
curl_easy_setopt(curl, CURLOPT_URL, download_item.downloadUrl.c_str());
154-
if (download_item.headers.has_value()) {
153+
auto headers = curl_utils::GetHeaders(download_item.downloadUrl);
154+
if (headers.has_value()) {
155155
curl_slist* curl_headers = nullptr;
156156

157-
for (const auto& [key, value] : download_item.headers.value()) {
157+
for (const auto& [key, value] : headers.value()) {
158158
auto header = key + ": " + value;
159159
curl_headers = curl_slist_append(curl_headers, header.c_str());
160160
}
@@ -248,10 +248,11 @@ void DownloadService::ProcessTask(DownloadTask& task) {
248248
});
249249
downloading_data_map_.insert(std::make_pair(item.id, dl_data_ptr));
250250

251-
if (item.headers.has_value()) {
251+
auto headers = curl_utils::GetHeaders(item.downloadUrl);
252+
if (headers.has_value()) {
252253
curl_slist* curl_headers = nullptr;
253254

254-
for (const auto& [key, value] : item.headers.value()) {
255+
for (const auto& [key, value] : headers.value()) {
255256
auto header = key + ": " + value;
256257
curl_headers = curl_slist_append(curl_headers, header.c_str());
257258
}

engine/services/download_service.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,7 @@ class DownloadService {
7373
* @param url - url to get file size
7474
*/
7575
cpp::result<uint64_t, std::string> GetFileSize(
76-
const std::string& url,
77-
const std::optional<
78-
std::reference_wrapper<std::unordered_map<std::string, std::string>>>&
79-
headers = std::nullopt) const noexcept;
76+
const std::string& url) const noexcept;
8077

8178
cpp::result<std::string, std::string> StopTask(const std::string& task_id);
8279

engine/services/engine_service.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ std::string GetEnginePath(std::string_view e) {
5555
}
5656
return kLlamaLibPath;
5757
};
58-
59-
constexpr const int k200OK = 200;
60-
constexpr const int k400BadRequest = 400;
61-
constexpr const int k409Conflict = 409;
62-
constexpr const int k500InternalServerError = 500;
6358
} // namespace
6459

6560
// cpp::result<EngineInfo, std::string> EngineService::GetEngineInfo(

engine/services/model_service.cc

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
4343
}
4444

4545
auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl);
46-
auto branch = url_obj.pathParams[3];
46+
if (url_obj.has_error()) {
47+
CTL_WRN("Error parsing url: " << ggufDownloadItem.downloadUrl);
48+
return;
49+
}
50+
51+
auto branch = url_obj->pathParams[3];
4752
CTL_INF("Adding model to modellist with branch: " << branch);
4853

4954
auto rel = file_manager_utils::ToRelativeCortexDataPath(yaml_name);
@@ -66,11 +71,10 @@ cpp::result<DownloadTask, std::string> GetDownloadTask(
6671
const std::string& modelId, const std::string& branch = "main") {
6772
url_parser::Url url = {
6873
.protocol = "https",
69-
.host = ModelService::kHuggingFaceHost,
74+
.host = kHuggingFaceHost,
7075
.pathParams = {"api", "models", "cortexso", modelId, "tree", branch}};
7176

7277
httplib::Client cli(url.GetProtocolAndHost());
73-
// TODO: namh add header here
7478
auto result = curl_utils::SimpleGetJson(url.ToFullPath());
7579
if (result.has_error()) {
7680
return cpp::fail("Model " + modelId + " not found");
@@ -89,7 +93,7 @@ cpp::result<DownloadTask, std::string> GetDownloadTask(
8993
}
9094
url_parser::Url download_url = {
9195
.protocol = "https",
92-
.host = ModelService::kHuggingFaceHost,
96+
.host = kHuggingFaceHost,
9397
.pathParams = {"cortexso", modelId, "resolve", branch, path}};
9498

9599
auto local_path = model_container_path / path;
@@ -221,21 +225,24 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
221225
cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
222226
const std::string& url, std::optional<std::string> temp_model_id) {
223227
auto url_obj = url_parser::FromUrlString(url);
228+
if (url_obj.has_error()) {
229+
return cpp::fail("Invalid url: " + url);
230+
}
224231

225-
if (url_obj.host == kHuggingFaceHost) {
226-
if (url_obj.pathParams[2] == "blob") {
227-
url_obj.pathParams[2] = "resolve";
232+
if (url_obj->host == kHuggingFaceHost) {
233+
if (url_obj->pathParams[2] == "blob") {
234+
url_obj->pathParams[2] = "resolve";
228235
}
229236
}
230-
auto author{url_obj.pathParams[0]};
231-
auto model_id{url_obj.pathParams[1]};
232-
auto file_name{url_obj.pathParams.back()};
237+
auto author{url_obj->pathParams[0]};
238+
auto model_id{url_obj->pathParams[1]};
239+
auto file_name{url_obj->pathParams.back()};
233240

234241
if (author == "cortexso") {
235242
return DownloadModelFromCortexsoAsync(model_id);
236243
}
237244

238-
if (url_obj.pathParams.size() < 5) {
245+
if (url_obj->pathParams.size() < 5) {
239246
return cpp::fail("Invalid url: " + url);
240247
}
241248

@@ -255,7 +262,7 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
255262
}
256263

257264
auto local_path{file_manager_utils::GetModelsContainerPath() /
258-
"huggingface.co" / author / model_id / file_name};
265+
kHuggingFaceHost / author / model_id / file_name};
259266

260267
try {
261268
std::filesystem::create_directories(local_path.parent_path());
@@ -265,7 +272,7 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
265272
std::filesystem::create_directories(local_path.parent_path());
266273
}
267274

268-
auto download_url = url_parser::FromUrl(url_obj);
275+
auto download_url = url_parser::FromUrl(url_obj.value());
269276
// this assumes that the model being downloaded is a single gguf file
270277
auto downloadTask{DownloadTask{.id = model_id,
271278
.type = DownloadType::Model,
@@ -287,22 +294,25 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
287294
cpp::result<std::string, std::string> ModelService::HandleUrl(
288295
const std::string& url) {
289296
auto url_obj = url_parser::FromUrlString(url);
297+
if (url_obj.has_error()) {
298+
return cpp::fail("Invalid url: " + url);
299+
}
290300

291-
if (url_obj.host == kHuggingFaceHost) {
292-
if (url_obj.pathParams[2] == "blob") {
293-
url_obj.pathParams[2] = "resolve";
301+
if (url_obj->host == kHuggingFaceHost) {
302+
if (url_obj->pathParams[2] == "blob") {
303+
url_obj->pathParams[2] = "resolve";
294304
}
295305
}
296-
auto author{url_obj.pathParams[0]};
297-
auto model_id{url_obj.pathParams[1]};
298-
auto file_name{url_obj.pathParams.back()};
306+
auto author{url_obj->pathParams[0]};
307+
auto model_id{url_obj->pathParams[1]};
308+
auto file_name{url_obj->pathParams.back()};
299309

300310
if (author == "cortexso") {
301311
return DownloadModelFromCortexso(model_id);
302312
}
303313

304-
if (url_obj.pathParams.size() < 5) {
305-
if (url_obj.pathParams.size() < 2) {
314+
if (url_obj->pathParams.size() < 5) {
315+
if (url_obj->pathParams.size() < 2) {
306316
return cpp::fail("Invalid url: " + url);
307317
}
308318
return DownloadHuggingFaceGgufModel(author, model_id, std::nullopt);
@@ -320,7 +330,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(
320330
}
321331

322332
auto local_path{file_manager_utils::GetModelsContainerPath() /
323-
"huggingface.co" / author / model_id / file_name};
333+
kHuggingFaceHost / author / model_id / file_name};
324334

325335
try {
326336
std::filesystem::create_directories(local_path.parent_path());
@@ -330,7 +340,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(
330340
std::filesystem::create_directories(local_path.parent_path());
331341
}
332342

333-
auto download_url = url_parser::FromUrl(url_obj);
343+
auto download_url = url_parser::FromUrl(url_obj.value());
334344
// this assumes that the model being downloaded is a single gguf file
335345
auto downloadTask{DownloadTask{.id = model_id,
336346
.type = DownloadType::Model,

engine/services/model_service.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@
88
#include "services/inference_service.h"
99

1010
struct StartParameterOverride {
11-
std::optional<bool> cache_enabled;
12-
std::optional<int> ngl;
13-
std::optional<int> n_parallel;
14-
std::optional<int> ctx_len;
15-
std::optional<std::string> custom_prompt_template;
16-
std::optional<std::string> cache_type;
11+
std::optional<bool> cache_enabled;
12+
std::optional<int> ngl;
13+
std::optional<int> n_parallel;
14+
std::optional<int> ctx_len;
15+
std::optional<std::string> custom_prompt_template;
16+
std::optional<std::string> cache_type;
1717
};
18+
1819
class ModelService {
1920
public:
20-
constexpr auto static kHuggingFaceHost = "huggingface.co";
21-
2221
explicit ModelService(std::shared_ptr<DownloadService> download_service)
2322
: download_service_{download_service} {};
2423

engine/test/components/test_url_parser.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ class UrlParserTestSuite : public ::testing::Test {
99
TEST_F(UrlParserTestSuite, TestParseUrlCorrectly) {
1010
auto url = url_parser::FromUrlString(kValidUrlWithOnlyPaths);
1111

12-
EXPECT_EQ(url.host, "jan.ai");
13-
EXPECT_EQ(url.protocol, "https");
14-
EXPECT_EQ(url.pathParams.size(), 2);
12+
EXPECT_EQ(url->host, "jan.ai");
13+
EXPECT_EQ(url->protocol, "https");
14+
EXPECT_EQ(url->pathParams.size(), 2);
1515
}
1616

1717
TEST_F(UrlParserTestSuite, ConstructUrlCorrectly) {

engine/utils/curl_utils.h

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
#include <yaml-cpp/node/node.h>
77
#include <yaml-cpp/node/parse.h>
88
#include <string>
9+
#include "utils/engine_constants.h"
10+
#include "utils/file_manager_utils.h"
911
#include "utils/result.hpp"
12+
#include "utils/url_parser.h"
1013

1114
namespace curl_utils {
1215
namespace {
@@ -18,12 +21,10 @@ size_t WriteCallback(void* contents, size_t size, size_t nmemb,
1821
}
1922
} // namespace
2023

21-
inline cpp::result<std::string, std::string> SimpleGet(
22-
const std::string& url,
23-
const std::optional<std::reference_wrapper<
24-
const std::unordered_map<std::string, std::string>>>& headers =
25-
std::nullopt) {
24+
inline std::optional<std::unordered_map<std::string, std::string>> GetHeaders(
25+
const std::string& url);
2626

27+
inline cpp::result<std::string, std::string> SimpleGet(const std::string& url) {
2728
// Initialize libcurl
2829
curl_global_init(CURL_GLOBAL_DEFAULT);
2930
auto curl = curl_easy_init();
@@ -32,10 +33,11 @@ inline cpp::result<std::string, std::string> SimpleGet(
3233
return cpp::fail("Failed to init CURL");
3334
}
3435

36+
auto headers = GetHeaders(url);
3537
if (headers.has_value()) {
3638
curl_slist* curl_headers = nullptr;
3739

38-
for (const auto& [key, value] : headers->get()) {
40+
for (const auto& [key, value] : headers.value()) {
3941
auto header = key + ": " + value;
4042
curl_headers = curl_slist_append(curl_headers, header.c_str());
4143
}
@@ -62,11 +64,8 @@ inline cpp::result<std::string, std::string> SimpleGet(
6264
}
6365

6466
inline cpp::result<YAML::Node, std::string> ReadRemoteYaml(
65-
const std::string& url,
66-
const std::optional<std::reference_wrapper<
67-
const std::unordered_map<std::string, std::string>>>& headers =
68-
std::nullopt) {
69-
auto result = SimpleGet(url, headers);
67+
const std::string& url) {
68+
auto result = SimpleGet(url);
7069
if (result.has_error()) {
7170
return cpp::fail(result.error());
7271
}
@@ -80,23 +79,48 @@ inline cpp::result<YAML::Node, std::string> ReadRemoteYaml(
8079
}
8180

8281
inline cpp::result<Json::Value, std::string> SimpleGetJson(
83-
const std::string& url,
84-
const std::optional<std::reference_wrapper<
85-
const std::unordered_map<std::string, std::string>>>& headers =
86-
std::nullopt) {
87-
88-
auto result = SimpleGet(url, headers);
82+
const std::string& url) {
83+
auto result = SimpleGet(url);
8984
if (result.has_error()) {
9085
return cpp::fail(result.error());
9186
}
87+
9288
Json::Value root;
9389
Json::Reader reader;
94-
9590
if (!reader.parse(result.value(), root)) {
9691
return cpp::fail("JSON from " + url +
9792
" parsing error: " + reader.getFormattedErrorMessages());
9893
}
9994

10095
return root;
10196
}
97+
98+
// Builds the extra HTTP request headers to use for `url`, chosen by the URL's host.
//
// @param url  Full URL string; it is parsed with url_parser::FromUrlString.
// @return     A header-name -> header-value map when the host equals
//             kHuggingFaceHost or kGitHubHost; std::nullopt when the URL cannot
//             be parsed or the host is neither of the two.
//
// For the Hugging Face host the map always carries "Content-Type" and carries
// "Authorization" only when a token is configured. For the GitHub host it
// carries "Accept" and "User-Agent".
inline std::optional<std::unordered_map<std::string, std::string>> GetHeaders(
99+
const std::string& url) {
100+
auto url_obj = url_parser::FromUrlString(url);
101+
if (url_obj.has_error()) {
102+
return std::nullopt;  // unparsable URL: caller gets no extra headers
103+
}
104+
105+
if (url_obj->host == kHuggingFaceHost) {
106+
std::unordered_map<std::string, std::string> headers{};
107+
headers["Content-Type"] = "application/json";
108+
auto const& token = file_manager_utils::GetCortexConfig().huggingFaceToken;  // token is read from the Cortex config on every call
109+
if (!token.empty()) {
110+
headers["Authorization"] = "Bearer " + token;  // an empty token means no Authorization header is added
111+
}
112+
// TODO(namh): print out the last 6 characters of the token
113+
114+
return headers;
115+
}
116+
117+
if (url_obj->host == kGitHubHost) {
118+
std::unordered_map<std::string, std::string> headers{};
119+
headers["Accept"] = "application/vnd.github.v3+json";
120+
headers["User-Agent"] = "cortex";
121+
return headers;
122+
}
123+
124+
return std::nullopt;  // any other host: no extra headers
125+
}
102126
} // namespace curl_utils

engine/utils/engine_constants.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,8 @@ constexpr const auto kPythonRuntimeRepo = "cortex.python";
1212
constexpr const auto kLlamaLibPath = "/engines/cortex.llamacpp";
1313
constexpr const auto kPythonRuntimeLibPath = "/engines/cortex.python";
1414
constexpr const auto kOnnxLibPath = "/engines/cortex.onnx";
15-
constexpr const auto kTensorrtLlmPath = "/engines/cortex.tensorrt-llm";
15+
constexpr const auto kTensorrtLlmPath = "/engines/cortex.tensorrt-llm";
16+
17+
// other constants
18+
constexpr auto static kHuggingFaceHost = "huggingface.co";
19+
constexpr auto static kGitHubHost = "api.github.com";

0 commit comments

Comments
 (0)