Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 0aad82c

Browse files
committed
feat(#1512): simplify cortex run
1 parent a6b64f8 commit 0aad82c

File tree

12 files changed

+220
-94
lines changed

12 files changed

+220
-94
lines changed

engine/cli/command_line_parser.cc

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "command_line_parser.h"
22
#include <memory>
3+
#include <optional>
34
#include "commands/chat_cmd.h"
45
#include "commands/chat_completion_cmd.h"
56
#include "commands/cortex_upd_cmd.h"
@@ -82,27 +83,30 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
8283
// Check new update
8384
#ifdef CORTEX_CPP_VERSION
8485
if (cml_data_.check_upd) {
85-
// TODO(sang) find a better way to handle
86-
// This is an extremely ungly way to deal with connection
87-
// hang when network down
88-
std::atomic<bool> done = false;
89-
std::thread t([&]() {
90-
if (auto latest_version =
91-
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
92-
latest_version.has_value() && *latest_version != CORTEX_CPP_VERSION) {
93-
CLI_LOG("\nA new release of cortex is available: "
94-
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
95-
CLI_LOG("To upgrade, run: " << commands::GetRole()
96-
<< commands::GetCortexBinary()
97-
<< " update");
86+
if (strcmp(CORTEX_CPP_VERSION, "default_version") != 0) {
87+
// TODO(sang) find a better way to handle
88+
// This is an extremely ugly way to deal with connection
89+
// hang when network down
90+
std::atomic<bool> done = false;
91+
std::thread t([&]() {
92+
if (auto latest_version =
93+
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
94+
latest_version.has_value() &&
95+
*latest_version != CORTEX_CPP_VERSION) {
96+
CLI_LOG("\nA new release of cortex is available: "
97+
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
98+
CLI_LOG("To upgrade, run: " << commands::GetRole()
99+
<< commands::GetCortexBinary()
100+
<< " update");
101+
}
102+
done = true;
103+
});
104+
// Do not wait for http connection timeout
105+
t.detach();
106+
int retry = 10;
107+
while (!done && retry--) {
108+
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
98109
}
99-
done = true;
100-
});
101-
// Do not wait for http connection timeout
102-
t.detach();
103-
int retry = 10;
104-
while (!done && retry--) {
105-
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
106110
}
107111
}
108112
#endif
@@ -143,11 +147,6 @@ void CommandLineParser::SetupCommonCommands() {
143147
run_cmd->callback([this, run_cmd] {
144148
if (std::exchange(executed_, true))
145149
return;
146-
if (cml_data_.model_id.empty()) {
147-
CLI_LOG("[model_id] is required\n");
148-
CLI_LOG(run_cmd->help());
149-
return;
150-
}
151150
commands::RunCmd rc(cml_data_.config.apiServerHost,
152151
std::stoi(cml_data_.config.apiServerPort),
153152
cml_data_.model_id, download_service_);
@@ -247,12 +246,19 @@ void CommandLineParser::SetupModelCommands() {
247246

248247
auto list_models_cmd =
249248
models_cmd->add_subcommand("list", "List all models locally");
249+
list_models_cmd->add_option("filter", cml_data_.filter, "Filter model id");
250+
list_models_cmd->add_flag("-e,--engine", cml_data_.display_engine,
251+
"Display engine");
252+
list_models_cmd->add_flag("-v,--version", cml_data_.display_version,
253+
"Display version");
250254
list_models_cmd->group(kSubcommands);
251255
list_models_cmd->callback([this]() {
252256
if (std::exchange(executed_, true))
253257
return;
254258
commands::ModelListCmd().Exec(cml_data_.config.apiServerHost,
255-
std::stoi(cml_data_.config.apiServerPort));
259+
std::stoi(cml_data_.config.apiServerPort),
260+
cml_data_.filter, cml_data_.display_engine,
261+
cml_data_.display_version);
256262
});
257263

258264
auto get_models_cmd =

engine/cli/command_line_parser.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ class CommandLineParser {
4545
std::string cortex_version;
4646
bool check_upd = true;
4747
bool run_detach = false;
48+
49+
// for model list
50+
bool display_engine = false;
51+
bool display_version = false;
52+
std::string filter = "";
53+
4854
int port;
4955
config_yaml_utils::CortexConfig config;
5056
std::unordered_map<std::string, std::string> model_update_options;

engine/cli/commands/model_list_cmd.cc

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
#include "model_list_cmd.h"
2+
#include <json/reader.h>
3+
#include <json/value.h>
24
#include <iostream>
35

46
#include <vector>
57
#include "httplib.h"
6-
#include "json/json.h"
78
#include "server_start_cmd.h"
89
#include "utils/logging_utils.h"
10+
#include "utils/string_utils.h"
911
// clang-format off
1012
#include <tabulate/table.hpp>
1113
// clang-format on
14+
1215
namespace commands {
16+
using namespace tabulate;
17+
using Row_t =
18+
std::vector<variant<std::string, const char*, string_view, Table>>;
1319

14-
void ModelListCmd::Exec(const std::string& host, int port) {
20+
void ModelListCmd::Exec(const std::string& host, int port, std::string filter,
21+
bool display_engine, bool display_version) {
1522
// Start server if server is not started yet
1623
if (!commands::IsServerAlive(host, port)) {
1724
CLI_LOG("Starting server ...");
@@ -21,27 +28,48 @@ void ModelListCmd::Exec(const std::string& host, int port) {
2128
}
2229
}
2330

24-
tabulate::Table table;
31+
Table table;
32+
std::vector<std::string> column_headers{"(Index)", "ID"};
33+
if (display_engine) {
34+
column_headers.push_back("Engine");
35+
}
36+
if (display_version) {
37+
column_headers.push_back("Version");
38+
}
2539

26-
table.add_row({"(Index)", "ID", "model alias", "engine", "version"});
27-
table.format().font_color(tabulate::Color::green);
40+
Row_t header{column_headers.begin(), column_headers.end()};
41+
table.add_row(header);
42+
table.format().font_color(Color::green);
2843
int count = 0;
2944
// Iterate through directory
3045

3146
httplib::Client cli(host + ":" + std::to_string(port));
3247
auto res = cli.Get("/v1/models");
3348
if (res) {
3449
if (res->status == httplib::StatusCode::OK_200) {
35-
// CLI_LOG(res->body);
3650
Json::Value body;
3751
Json::Reader reader;
3852
reader.parse(res->body, body);
3953
if (!body["data"].isNull()) {
4054
for (auto const& v : body["data"]) {
55+
auto model_id = v["model"].asString();
56+
if (!filter.empty() &&
57+
!string_utils::StringContainsIgnoreCase(model_id, filter)) {
58+
continue;
59+
}
60+
4161
count += 1;
42-
table.add_row({std::to_string(count), v["model"].asString(),
43-
v["model_alias"].asString(), v["engine"].asString(),
44-
v["version"].asString()});
62+
63+
std::vector<std::string> row = {std::to_string(count),
64+
v["model"].asString()};
65+
if (display_engine) {
66+
row.push_back(v["engine"].asString());
67+
}
68+
if (display_version) {
69+
row.push_back(v["version"].asString());
70+
}
71+
72+
table.add_row({row.begin(), row.end()});
4573
}
4674
}
4775
} else {
@@ -54,24 +82,6 @@ void ModelListCmd::Exec(const std::string& host, int port) {
5482
return;
5583
}
5684

57-
for (int i = 0; i < 5; i++) {
58-
table[0][i]
59-
.format()
60-
.font_color(tabulate::Color::white) // Set font color
61-
.font_style({tabulate::FontStyle::bold})
62-
.font_align(tabulate::FontAlign::center);
63-
}
64-
for (int i = 1; i <= count; i++) {
65-
table[i][0] //index value
66-
.format()
67-
.font_color(tabulate::Color::white) // Set font color
68-
.font_align(tabulate::FontAlign::center);
69-
table[i][4] //version value
70-
.format()
71-
.font_align(tabulate::FontAlign::center);
72-
}
7385
std::cout << table << std::endl;
7486
}
75-
}
76-
77-
; // namespace commands
87+
}; // namespace commands

engine/cli/commands/model_list_cmd.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#pragma once
2+
23
#include <string>
34

45
namespace commands {
56

67
class ModelListCmd {
78
public:
8-
void Exec(const std::string& host, int port);
9+
void Exec(const std::string& host, int port, std::string filter,
10+
bool display_engine = false, bool display_version = false);
911
};
1012
} // namespace commands

engine/cli/commands/run_cmd.cc

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,54 @@ void RunCmd::Exec(bool run_detach) {
3030
config::YamlHandler yaml_handler;
3131
auto address = host_ + ":" + std::to_string(port_);
3232

33-
// Download model if it does not exist
3433
{
35-
auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_);
36-
if (related_models_ids.has_error() || related_models_ids.value().empty()) {
37-
auto result = model_service_.DownloadModel(model_handle_);
38-
model_id = result.value();
39-
CTL_INF("model_id: " << model_id.value());
40-
} else if (related_models_ids.value().size() == 1) {
41-
model_id = related_models_ids.value().front();
42-
} else { // multiple models with nearly same name found
43-
auto selection = cli_selection_utils::PrintSelection(
44-
related_models_ids.value(), "Local Models: (press enter to select)");
45-
if (!selection.has_value()) {
34+
if (model_handle_.empty()) {
35+
auto all_local_models = modellist_handler.LoadModelList();
36+
if (all_local_models.has_error() || all_local_models.value().empty()) {
37+
CLI_LOG("No local models available!");
4638
return;
4739
}
48-
model_id = selection.value();
49-
CLI_LOG("Selected: " << selection.value());
40+
41+
if (all_local_models.value().size() == 1) {
42+
model_id = all_local_models.value().front().model;
43+
} else {
44+
std::vector<std::string> model_id_list{};
45+
for (const auto& model : all_local_models.value()) {
46+
model_id_list.push_back(model.model);
47+
}
48+
49+
auto selection = cli_selection_utils::PrintSelection(
50+
model_id_list, "Please select an option");
51+
if (!selection.has_value()) {
52+
return;
53+
}
54+
model_id = selection.value();
55+
CLI_LOG("Selected: " << selection.value());
56+
}
57+
} else {
58+
auto related_models_ids =
59+
modellist_handler.FindRelatedModel(model_handle_);
60+
if (related_models_ids.has_error() ||
61+
related_models_ids.value().empty()) {
62+
auto result = model_service_.DownloadModel(model_handle_);
63+
if (result.has_error()) {
64+
CLI_LOG("Model " << model_handle_ << " not found!");
65+
return;
66+
}
67+
model_id = result.value();
68+
CTL_INF("model_id: " << model_id.value());
69+
} else if (related_models_ids.value().size() == 1) {
70+
model_id = related_models_ids.value().front();
71+
} else { // multiple models with nearly same name found
72+
auto selection = cli_selection_utils::PrintSelection(
73+
related_models_ids.value(),
74+
"Local Models: (press enter to select)");
75+
if (!selection.has_value()) {
76+
return;
77+
}
78+
model_id = selection.value();
79+
CLI_LOG("Selected: " << selection.value());
80+
}
5081
}
5182
}
5283

engine/cli/commands/server_start_cmd.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
#include "server_start_cmd.h"
22
#include "commands/cortex_upd_cmd.h"
3-
#include "httplib.h"
4-
#include "trantor/utils/Logger.h"
53
#include "utils/cortex_utils.h"
64
#include "utils/file_manager_utils.h"
7-
#include "utils/logging_utils.h"
85

96
namespace commands {
107

@@ -124,4 +121,4 @@ bool ServerStartCmd::Exec(const std::string& host, int port) {
124121
return true;
125122
}
126123

127-
}; // namespace commands
124+
}; // namespace commands

engine/cli/commands/server_start_cmd.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
23
#include <string>
34
#include "httplib.h"
45

@@ -18,4 +19,4 @@ class ServerStartCmd {
1819
ServerStartCmd();
1920
bool Exec(const std::string& host, int port);
2021
};
21-
} // namespace commands
22+
} // namespace commands

engine/database/models.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,11 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(
275275

276276
cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
277277
const std::string& identifier) const {
278-
// TODO (namh): add check for alias as well
279278
try {
280279
std::vector<std::string> related_models;
281280
SQLite::Statement query(
282-
db_,
283-
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
284-
"OR model_id LIKE ? OR model_id LIKE ?");
285-
query.bind(1, identifier + ":%");
286-
query.bind(2, "%:" + identifier);
287-
query.bind(3, "%:" + identifier + ":%");
288-
query.bind(4, identifier);
281+
db_, "SELECT model_id FROM models WHERE model_id LIKE ?");
282+
query.bind(1, "%" + identifier + "%");
289283

290284
while (query.executeStep()) {
291285
related_models.push_back(query.getColumn(0).getString());

0 commit comments

Comments
 (0)