This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 53e8f20

feat: add customize parameters to /v1/models/start (#1544)

* feat: add customize parameters to /v1/models/start
* fix: ngl

1 parent 4486408
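For context, this change lets clients pass per-start overrides in the /v1/models/start request body. A minimal sketch of such a body, built with jsoncpp (the same JSON library the controller uses); the model handle and every value here are illustrative assumptions, not documented defaults:

// Sketch only: builds an example /v1/models/start body exercising the new
// override keys. Model handle and all values are illustrative assumptions.
#include <iostream>
#include <json/json.h>

int main() {
  Json::Value body;
  body["model"] = "llama3:8b-gguf";  // hypothetical model handle
  body["prompt_template"] = "{system_message} {prompt}";  // illustrative
  body["cache_enabled"] = true;
  body["ngl"] = 33;        // GPU layers to offload
  body["n_parallel"] = 2;  // concurrent sequences
  body["ctx_len"] = 4096;  // context length
  body["cache_type"] = "f16";
  std::cout << body.toStyledString();  // POST this JSON to /v1/models/start
  return 0;
}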

File tree: 3 files changed, +53 -10 lines changed

* engine/controllers/models.cc
* engine/services/model_service.cc
* engine/services/model_service.h

engine/controllers/models.cc

Lines changed: 28 additions & 5 deletions
@@ -345,8 +345,31 @@ void Models::StartModel(
     return;
   auto config = file_manager_utils::GetCortexConfig();
   auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
-  auto custom_prompt_template =
-      (*(req->getJsonObject())).get("prompt_template", "").asString();
+  StartParameterOverride params_override;
+  if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
+    params_override.custom_prompt_template = o.asString();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
+    params_override.cache_enabled = o.asBool();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
+    params_override.ngl = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
+    params_override.n_parallel = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
+    params_override.ctx_len = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
+    params_override.cache_type = o.asString();
+  }
+
   auto model_entry = model_service_->GetDownloadedModel(model_handle);
   if (!model_entry.has_value()) {
     Json::Value ret;
@@ -375,9 +398,9 @@ void Models::StartModel(
     return;
   }

-  auto result = model_service_->StartModel(
-      config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
-      custom_prompt_template);
+  auto result = model_service_->StartModel(config.apiServerHost,
+                                           std::stoi(config.apiServerPort),
+                                           model_handle, params_override);
   if (result.has_error()) {
     Json::Value ret;
     ret["message"] = result.error();

engine/services/model_service.cc

Lines changed: 16 additions & 4 deletions
@@ -570,7 +570,7 @@ cpp::result<void, std::string> ModelService::DeleteModel(

 cpp::result<bool, std::string> ModelService::StartModel(
     const std::string& host, int port, const std::string& model_handle,
-    std::optional<std::string> custom_prompt_template) {
+    const StartParameterOverride& params_override) {
   namespace fs = std::filesystem;
   namespace fmu = file_manager_utils;
   cortex::db::Models modellist_handler;
@@ -600,9 +600,9 @@ cpp::result<bool, std::string> ModelService::StartModel(
     return false;
   }
   json_data["model"] = model_handle;
-  if (!custom_prompt_template.value_or("").empty()) {
-    auto parse_prompt_result =
-        string_utils::ParsePrompt(custom_prompt_template.value());
+  if (auto& cpt = params_override.custom_prompt_template;
+      !cpt.value_or("").empty()) {
+    auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
     json_data["system_prompt"] = parse_prompt_result.system_prompt;
     json_data["user_prompt"] = parse_prompt_result.user_prompt;
     json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
@@ -612,6 +612,18 @@ cpp::result<bool, std::string> ModelService::StartModel(
     json_data["ai_prompt"] = mc.ai_template;
   }

+#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
+  if (param_override.param_name) {                              \
+    json_obj[#param_name] = param_override.param_name.value();  \
+  }
+
+  ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
+  ASSIGN_IF_PRESENT(json_data, params_override, ngl);
+  ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
+  ASSIGN_IF_PRESENT(json_data, params_override, ctx_len);
+  ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
+#undef ASSIGN_IF_PRESENT
+
   CTL_INF(json_data.toStyledString());
   assert(!!inference_svc_);
   auto ir =
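For reference, each ASSIGN_IF_PRESENT call copies an override into the request payload only when its std::optional is engaged, using # to stringize the field name into the JSON key. A trimmed, self-contained sketch (the two-field struct is a stand-in, not the real StartParameterOverride):

#include <iostream>
#include <optional>
#include <json/json.h>

struct Overrides {  // stand-in for StartParameterOverride
  std::optional<int> ngl;
  std::optional<int> ctx_len;
};

#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
  if (param_override.param_name) {                              \
    json_obj[#param_name] = param_override.param_name.value();  \
  }

int main() {
  Overrides params_override;
  params_override.ngl = 33;  // ctx_len stays std::nullopt

  Json::Value json_data;
  // Expands to: if (params_override.ngl) { json_data["ngl"] = ...; }
  ASSIGN_IF_PRESENT(json_data, params_override, ngl);
  ASSIGN_IF_PRESENT(json_data, params_override, ctx_len);  // no-op
#undef ASSIGN_IF_PRESENT

  std::cout << json_data.toStyledString();  // prints only {"ngl" : 33}
  return 0;
}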

engine/services/model_service.h

Lines changed: 9 additions & 1 deletion
@@ -7,6 +7,14 @@
 #include "services/download_service.h"
 #include "services/inference_service.h"

+struct StartParameterOverride {
+  std::optional<bool> cache_enabled;
+  std::optional<int> ngl;
+  std::optional<int> n_parallel;
+  std::optional<int> ctx_len;
+  std::optional<std::string> custom_prompt_template;
+  std::optional<std::string> cache_type;
+};
 class ModelService {
  public:
   constexpr auto static kHuggingFaceHost = "huggingface.co";
@@ -46,7 +54,7 @@ class ModelService {

   cpp::result<bool, std::string> StartModel(
       const std::string& host, int port, const std::string& model_handle,
-      std::optional<std::string> custom_prompt_template = std::nullopt);
+      const StartParameterOverride& params_override);

   cpp::result<bool, std::string> StopModel(const std::string& host, int port,
                                            const std::string& model_handle);
