Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 6cda3ee

Browse files
authored
fix: add n_parallel to model yaml config (#1571)
* fix: add n_parallel to model yaml config * fix: models update
1 parent 106ddd4 commit 6cda3ee

File tree

5 files changed

+24
-0
lines changed

5 files changed

+24
-0
lines changed

engine/cli/command_line_parser.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) {
546546
"stream",
547547
"ngl",
548548
"ctx_len",
549+
"n_parallel",
549550
"engine",
550551
"prompt_template",
551552
"system_template",

engine/cli/commands/model_upd_cmd.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key,
223223
data["ctx_len"] = static_cast<int>(f);
224224
});
225225
}},
226+
{"n_parallel",
227+
[this](Json::Value &data, const std::string& k, const std::string& v) {
228+
UpdateNumericField(k, v, [&data](float f) {
229+
data["n_parallel"] = static_cast<int>(f);
230+
});
231+
}},
226232
{"tp",
227233
[this](Json::Value &data, const std::string& k, const std::string& v) {
228234
UpdateNumericField(k, v, [&data](float f) {

engine/config/model_config.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct ModelConfig {
2222
bool stream = std::numeric_limits<bool>::quiet_NaN();
2323
int ngl = std::numeric_limits<int>::quiet_NaN();
2424
int ctx_len = std::numeric_limits<int>::quiet_NaN();
25+
int n_parallel = 1;
2526
std::string engine;
2627
std::string prompt_template;
2728
std::string system_template;
@@ -125,6 +126,8 @@ struct ModelConfig {
125126
ngl = json["ngl"].asInt();
126127
if (json.isMember("ctx_len"))
127128
ctx_len = json["ctx_len"].asInt();
129+
if (json.isMember("n_parallel"))
130+
n_parallel = json["n_parallel"].asInt();
128131
if (json.isMember("engine"))
129132
engine = json["engine"].asString();
130133
if (json.isMember("prompt_template"))
@@ -204,6 +207,7 @@ struct ModelConfig {
204207
obj["min_keep"] = min_keep;
205208
obj["ngl"] = ngl;
206209
obj["ctx_len"] = ctx_len;
210+
obj["n_parallel"] = n_parallel;
207211
obj["engine"] = engine;
208212
obj["prompt_template"] = prompt_template;
209213
obj["system_template"] = system_template;
@@ -313,6 +317,8 @@ struct ModelConfig {
313317
if (ctx_len != std::numeric_limits<int>::quiet_NaN())
314318
oss << format_utils::print_kv("ctx_len", std::to_string(ctx_len),
315319
format_utils::MAGENTA);
320+
oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel),
321+
format_utils::MAGENTA);
316322
if (ngl != std::numeric_limits<int>::quiet_NaN())
317323
oss << format_utils::print_kv("ngl", std::to_string(ngl),
318324
format_utils::MAGENTA);

engine/config/yaml_config.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ void YamlHandler::ModelConfigFromYaml() {
113113
tmp.ngl = yaml_node_["ngl"].as<int>();
114114
if (yaml_node_["ctx_len"])
115115
tmp.ctx_len = yaml_node_["ctx_len"].as<int>();
116+
if (yaml_node_["n_parallel"])
117+
tmp.n_parallel = yaml_node_["n_parallel"].as<int>();
116118
if (yaml_node_["tp"])
117119
tmp.tp = yaml_node_["tp"].as<int>();
118120
if (yaml_node_["stream"])
@@ -216,6 +218,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
216218
yaml_node_["ngl"] = model_config_.ngl;
217219
if (!std::isnan(static_cast<double>(model_config_.ctx_len)))
218220
yaml_node_["ctx_len"] = model_config_.ctx_len;
221+
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
222+
yaml_node_["n_parallel"] = model_config_.n_parallel;
219223
if (!std::isnan(static_cast<double>(model_config_.tp)))
220224
yaml_node_["tp"] = model_config_.tp;
221225
if (!std::isnan(static_cast<double>(model_config_.stream)))
@@ -368,6 +372,7 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
368372
outFile << format_utils::writeKeyValue(
369373
"ctx_len", yaml_node_["ctx_len"],
370374
"llama.context_length | 0 or undefined = loaded from model");
375+
outFile << format_utils::writeKeyValue("n_parallel", yaml_node_["n_parallel"]);
371376
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
372377
"Undefined = loaded from model");
373378
outFile << "# END OPTIONAL\n";

engine/test/components/test_yaml_handler.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ top_p: 0.9
6262
temperature: 0.7
6363
max_tokens: 100
6464
stream: true
65+
n_parallel: 2
6566
stop:
6667
- "END"
6768
files:
@@ -82,6 +83,7 @@ stream: true
8283
EXPECT_FLOAT_EQ(config.temperature, 0.7f);
8384
EXPECT_EQ(config.max_tokens, 100);
8485
EXPECT_TRUE(config.stream);
86+
EXPECT_EQ(config.n_parallel, 2);
8587
EXPECT_EQ(config.stop.size(), 1);
8688
EXPECT_EQ(config.stop[0], "END");
8789
EXPECT_EQ(config.files.size(), 1);
@@ -101,6 +103,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
101103
new_config.temperature = 0.8f;
102104
new_config.max_tokens = 200;
103105
new_config.stream = false;
106+
new_config.n_parallel = 2;
104107
new_config.stop = {"STOP", "END"};
105108
new_config.files = {"updated_file1.gguf", "updated_file2.gguf"};
106109

@@ -116,6 +119,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
116119
EXPECT_FLOAT_EQ(config.temperature, 0.8f);
117120
EXPECT_EQ(config.max_tokens, 200);
118121
EXPECT_FALSE(config.stream);
122+
EXPECT_EQ(config.n_parallel, 2);
119123
EXPECT_EQ(config.stop.size(), 2);
120124
EXPECT_EQ(config.stop[0], "STOP");
121125
EXPECT_EQ(config.stop[1], "END");
@@ -135,6 +139,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
135139
new_config.temperature = 0.6f;
136140
new_config.max_tokens = 150;
137141
new_config.stream = true;
142+
new_config.n_parallel = 2;
138143
new_config.stop = {"HALT"};
139144
new_config.files = {"write_test_file.gguf"};
140145

@@ -158,6 +163,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
158163
EXPECT_FLOAT_EQ(read_config.temperature, 0.6f);
159164
EXPECT_EQ(read_config.max_tokens, 150);
160165
EXPECT_TRUE(read_config.stream);
166+
EXPECT_EQ(read_config.n_parallel, 2);
161167
EXPECT_EQ(read_config.stop.size(), 1);
162168
EXPECT_EQ(read_config.stop[0], "HALT");
163169
EXPECT_EQ(read_config.files.size(), 1);

0 commit comments

Comments
 (0)