Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 5f9a0a4

Browse files
committed
fix: Final update for embedding to support both a single input string and a vector of input strings
1 parent 9218cd9 commit 5f9a0a4

File tree

1 file changed

+32
-39
lines changed

1 file changed

+32
-39
lines changed

controllers/llamaCPP.cc

Lines changed: 32 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,8 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
2121

2222
// --------------------------------------------
2323

24-
std::string create_embedding_payload(const std::vector<float> &embedding,
24+
Json::Value create_embedding_payload(const std::vector<float> &embedding,
2525
int prompt_tokens) {
26-
Json::Value root;
27-
28-
root["object"] = "list";
29-
30-
Json::Value dataArray(Json::arrayValue);
3126
Json::Value dataItem;
3227

3328
dataItem["object"] = "embedding";
@@ -39,20 +34,7 @@ std::string create_embedding_payload(const std::vector<float> &embedding,
3934
dataItem["embedding"] = embeddingArray;
4035
dataItem["index"] = 0;
4136

42-
dataArray.append(dataItem);
43-
root["data"] = dataArray;
44-
45-
root["model"] = "_";
46-
47-
Json::Value usage;
48-
usage["prompt_tokens"] = prompt_tokens;
49-
usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
50-
// tokens in this context
51-
root["usage"] = usage;
52-
53-
Json::StreamWriterBuilder writer;
54-
writer["indentation"] = ""; // Compact output
55-
return Json::writeString(writer, root);
37+
return dataItem;
5638
}
5739

5840
std::string create_full_return_json(const std::string &id,
@@ -406,31 +388,42 @@ void llamaCPP::embedding(
406388
std::function<void(const HttpResponsePtr &)> &&callback) {
407389
const auto &jsonBody = req->getJsonObject();
408390

409-
json prompt;
410-
if (jsonBody->isMember("input") != 0) {
411-
if ((*jsonBody)["input"].isString()) {
412-
prompt = (*jsonBody)["input"].asString();
413-
} else if ((*jsonBody)["input"].isArray()) {
414-
const auto &inputArray = (*jsonBody)["input"];
415-
std::vector<std::string> inputStrings;
416-
for (const auto &input : inputArray) {
417-
if (input.isString()) {
418-
inputStrings.push_back(input.asString());
391+
Json::Value responseData(Json::arrayValue);
392+
393+
if (jsonBody->isMember("input")) {
394+
const Json::Value &input = (*jsonBody)["input"];
395+
if (input.isString()) {
396+
// Process the single string input
397+
const int task_id = llama.request_completion(
398+
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
399+
task_result result = llama.next_result(task_id);
400+
std::vector<float> embedding_result = result.result_json["embedding"];
401+
responseData.append(create_embedding_payload(embedding_result, 0));
402+
} else if (input.isArray()) {
403+
// Process each element in the array input
404+
for (const auto &elem : input) {
405+
if (elem.isString()) {
406+
const int task_id = llama.request_completion(
407+
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
408+
task_result result = llama.next_result(task_id);
409+
std::vector<float> embedding_result = result.result_json["embedding"];
410+
responseData.append(create_embedding_payload(embedding_result, 0));
419411
}
420412
}
421-
prompt = inputStrings;
422413
}
423-
} else {
424-
prompt = "";
425414
}
426415

427-
const int task_id = llama.request_completion(
428-
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
429-
task_result result = llama.next_result(task_id);
430-
std::vector<float> embedding_result = result.result_json["embedding"];
431416
auto resp = nitro_utils::nitroHttpResponse();
432-
std::string embedding_resp = create_embedding_payload(embedding_result, 0);
433-
resp->setBody(embedding_resp);
417+
Json::Value root;
418+
root["data"] = responseData;
419+
root["model"] = "_";
420+
root["object"] = "list";
421+
Json::Value usage;
422+
usage["prompt_tokens"] = 0;
423+
usage["total_tokens"] = 0;
424+
root["usage"] = usage;
425+
426+
resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
434427
resp->setContentTypeString("application/json");
435428
callback(resp);
436429
return;

0 commit comments

Comments
 (0)