Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit c8eab8a

Browse files
chore: add embedding capabilities
1 parent f648f63 commit c8eab8a

File tree

4 files changed

+402
-3
lines changed

4 files changed

+402
-3
lines changed

docs/docs/capabilities/embeddings.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,98 @@ title: Embeddings
66
:::
77

88
cortex.cpp now support embeddings endpoint with fully OpenAI compatible.
9+
10+
For embeddings API usage please refer to [API references](/api-reference#tag/chat/POST/v1/embeddings). This tutorial show you how to use embeddings in cortex with openai python SDK.
11+
12+
## Embedding with openai compatible
13+
14+
### 1. Start server and run model
15+
16+
```
17+
cortex run llama3.1:8b-gguf-q4-km
18+
```
19+
20+
### 2. Create script `embeddings.py` with this content
21+
22+
```
23+
from datetime import datetime
24+
from openai import OpenAI
25+
from pydantic import BaseModel
26+
ENDPOINT = "http://localhost:39281/v1"
27+
MODEL = "llama3.1:8bb-gguf-q4-km"
28+
client = OpenAI(
29+
base_url=ENDPOINT,
30+
api_key="not-needed"
31+
)
32+
```
33+
34+
### 3. Create embeddings
35+
36+
```
37+
response = client.embeddings.create(input = "embedding", model=MODEL, encoding_format="base64")
38+
print(response)
39+
```
40+
41+
The reponse will be like this
42+
43+
```
44+
CreateEmbeddingResponse(
45+
data=[
46+
Embedding(
47+
embedding='hjuAPOD8TryuPU8...',
48+
index=0,
49+
object='embedding'
50+
)
51+
],
52+
model='meta-llama3.1-8b-instruct',
53+
object='list',
54+
usage=Usage(
55+
prompt_tokens=2,
56+
total_tokens=2
57+
)
58+
)
59+
```
60+
61+
62+
The output embeddings is encoded as base64 string. Default the model will output the embeddings in float mode.
63+
64+
```
65+
response = client.embeddings.create(input = "embedding", model=MODEL)
66+
print(response)
67+
```
68+
69+
Result will be
70+
71+
```
72+
CreateEmbeddingResponse(
73+
data=[
74+
Embedding(
75+
embedding=[0.1, 0.3, 0.4 ....],
76+
index=0,
77+
object='embedding'
78+
)
79+
],
80+
model='meta-llama3.1-8b-instruct',
81+
object='list',
82+
usage=Usage(
83+
prompt_tokens=2,
84+
total_tokens=2
85+
)
86+
)
87+
```
88+
89+
Cortex also supports all input types as [OpenAI](https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-input).
90+
91+
```sh
92+
# input as string
93+
response = client.embeddings.create(input = "embedding", model=MODEL)
94+
95+
# input as array of string
96+
response = client.embeddings.create(input = ["embedding"], model=MODEL)
97+
98+
# input as array of tokens
99+
response = client.embeddings.create(input = [12,44,123], model=MODEL)
100+
101+
# input as array of arrays contain tokens
102+
response = client.embeddings.create(input = [[912,312,54],[12,433,1241]], model=MODEL)
103+
```

docs/static/openapi/cortex.json

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@
190190
]
191191
}
192192
},
193-
"v1/embeddings": {
193+
"/v1/embeddings": {
194194
"post": {
195195
"summary": "Create embeddings",
196196
"description": "Creates an embedding vector representing the input text.",
@@ -204,22 +204,27 @@
204204
"input": {
205205
"oneOf": [
206206
{
207-
"type": "string"
207+
"type": "string",
208+
"description":"The string that will be turned into an embedding."
208209
},
209210
{
210211
"type": "array",
212+
"description" : "The array of strings that will be turned into an embedding.",
211213
"items": {
212214
"type": "string"
213215
}
214216
},
215217
{
216218
"type": "array",
219+
"description": "The array of integers that will be turned into an embedding.",
217220
"items": {
218221
"type": "integer"
222+
219223
}
220224
},
221225
{
222226
"type": "array",
227+
"description" : "The array of arrays containing integers that will be turned into an embedding.",
223228
"items": {
224229
"type": "array",
225230
"items": {
@@ -290,7 +295,10 @@
290295
}
291296
}
292297
}
293-
}
298+
},
299+
"tags": [
300+
"Embeddings"
301+
]
294302
}
295303
},
296304
"/v1/chat/completions": {
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
#include "remote_engine.h"
2+
#include <sstream>
3+
4+
// Static callback function for CURL
5+
static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, std::string* data) {
6+
data->append(ptr, size * nmemb);
7+
return size * nmemb;
8+
}
9+
10+
RemoteEngine::RemoteEngine() {
11+
curl_global_init(CURL_GLOBAL_ALL);
12+
}
13+
14+
RemoteEngine::~RemoteEngine() {
15+
curl_global_cleanup();
16+
}
17+
18+
CurlResponse RemoteEngine::makeRequest(const std::string& url,
19+
const std::string& api_key,
20+
const std::string& body,
21+
const std::string& method) {
22+
CURL* curl = curl_easy_init();
23+
CurlResponse response;
24+
25+
if (!curl) {
26+
response.error = true;
27+
response.error_message = "Failed to initialize CURL";
28+
return response;
29+
}
30+
31+
// Set up headers
32+
struct curl_slist* headers = nullptr;
33+
if (!api_key.empty()) {
34+
std::string auth_header = renderTemplate(config_.api_key_template, {{"api_key", api_key}});
35+
headers = curl_slist_append(headers, auth_header.c_str());
36+
}
37+
headers = curl_slist_append(headers, "Content-Type: application/json");
38+
39+
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
40+
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
41+
42+
if (method == "POST") {
43+
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
44+
}
45+
46+
std::string response_string;
47+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
48+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string);
49+
50+
CURLcode res = curl_easy_perform(curl);
51+
if (res != CURLE_OK) {
52+
response.error = true;
53+
response.error_message = curl_easy_strerror(res);
54+
} else {
55+
response.body = response_string;
56+
}
57+
58+
curl_slist_free_all(headers);
59+
curl_easy_cleanup(curl);
60+
return response;
61+
}
62+
63+
std::string RemoteEngine::renderTemplate(const std::string& templ,
64+
const std::unordered_map<std::string, std::string>& values) {
65+
std::string result = templ;
66+
for (const auto& [key, value] : values) {
67+
std::string placeholder = "{{" + key + "}}";
68+
size_t pos = result.find(placeholder);
69+
if (pos != std::string::npos) {
70+
result.replace(pos, placeholder.length(), value);
71+
}
72+
}
73+
return result;
74+
}
75+
76+
Json::Value RemoteEngine::transformRequest(const Json::Value& input, const std::string& type) {
77+
if (!config_.transform_req.isMember(type)) {
78+
return input;
79+
}
80+
81+
Json::Value output = input;
82+
const Json::Value& transforms = config_.transform_req[type];
83+
84+
for (const auto& transform : transforms) {
85+
if (transform.isString()) {
86+
// Handle template-based transformation
87+
if (transform.asString().find("template") != std::string::npos) {
88+
// Implement template rendering logic here
89+
continue;
90+
}
91+
} else if (transform.isObject()) {
92+
// Handle key mapping transformations
93+
for (const auto& key : transform.getMemberNames()) {
94+
if (input.isMember(key)) {
95+
output[transform[key].asString()] = input[key];
96+
output.removeMember(key);
97+
}
98+
}
99+
}
100+
}
101+
return output;
102+
}
103+
104+
void RemoteEngine::GetModels(std::shared_ptr<Json::Value> json_body,
105+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
106+
if (!json_body->isMember("url") || !json_body->isMember("api_key")) {
107+
Json::Value error;
108+
error["error"] = "Missing required fields: url or api_key";
109+
callback(Json::Value(), std::move(error));
110+
return;
111+
}
112+
113+
const std::string& url = (*json_body)["url"].asString();
114+
const std::string& api_key = (*json_body)["api_key"].asString();
115+
116+
auto response = makeRequest(url, api_key, "", "GET");
117+
118+
if (response.error) {
119+
Json::Value error;
120+
error["error"] = response.error_message;
121+
callback(Json::Value(), std::move(error));
122+
return;
123+
}
124+
125+
Json::Value response_json;
126+
Json::Reader reader;
127+
if (!reader.parse(response.body, response_json)) {
128+
Json::Value error;
129+
error["error"] = "Failed to parse response";
130+
callback(Json::Value(), std::move(error));
131+
return;
132+
}
133+
134+
callback(std::move(response_json), Json::Value());
135+
}
136+
137+
void RemoteEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body,
138+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
139+
if (!json_body->isMember("url") || !json_body->isMember("api_key") ||
140+
!json_body->isMember("request_body")) {
141+
Json::Value error;
142+
error["error"] = "Missing required fields: url, api_key, or request_body";
143+
callback(Json::Value(), std::move(error));
144+
return;
145+
}
146+
147+
const std::string& url = (*json_body)["url"].asString();
148+
const std::string& api_key = (*json_body)["api_key"].asString();
149+
150+
Json::Value transformed_request = transformRequest((*json_body)["request_body"], "chat_completion");
151+
152+
Json::FastWriter writer;
153+
std::string request_body = writer.write(transformed_request);
154+
155+
auto response = makeRequest(url, api_key, request_body);
156+
157+
if (response.error) {
158+
Json::Value error;
159+
error["error"] = response.error_message;
160+
callback(Json::Value(), std::move(error));
161+
return;
162+
}
163+
164+
Json::Value response_json;
165+
Json::Reader reader;
166+
if (!reader.parse(response.body, response_json)) {
167+
Json::Value error;
168+
error["error"] = "Failed to parse response";
169+
callback(Json::Value(), std::move(error));
170+
return;
171+
}
172+
173+
callback(std::move(response_json), Json::Value());
174+
}
175+
176+
bool RemoteEngine::LoadConfig(const std::string& yaml_path) {
177+
try {
178+
YAML::Node config = YAML::LoadFile(yaml_path);
179+
180+
if (config["api_key_template"]) {
181+
config_.api_key_template = config["api_key_template"].as<std::string>();
182+
}
183+
184+
if (config["TransformReq"]) {
185+
Json::Reader reader;
186+
reader.parse(config["TransformReq"].as<std::string>(), config_.transform_req);
187+
}
188+
189+
if (config["TransformResp"]) {
190+
Json::Reader reader;
191+
reader.parse(config["TransformResp"].as<std::string>(), config_.transform_resp);
192+
}
193+
194+
return true;
195+
} catch (const YAML::Exception& e) {
196+
LOG_ERROR << "Failed to load config: " << e.what();
197+
return false;
198+
}
199+
}
200+
201+
// Implement other virtual functions with minimal functionality
202+
void RemoteEngine::HandleEmbedding(std::shared_ptr<Json::Value>,
203+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
204+
callback(Json::Value(), Json::Value());
205+
}
206+
207+
void RemoteEngine::LoadModel(std::shared_ptr<Json::Value>,
208+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
209+
callback(Json::Value(), Json::Value());
210+
}
211+
212+
void RemoteEngine::UnloadModel(std::shared_ptr<Json::Value>,
213+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
214+
callback(Json::Value(), Json::Value());
215+
}
216+
217+
void RemoteEngine::GetModelStatus(std::shared_ptr<Json::Value>,
218+
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
219+
callback(Json::Value(), Json::Value());
220+
}
221+
222+
bool RemoteEngine::IsSupported(const std::string&) {
223+
return true;
224+
}
225+
226+
bool RemoteEngine::SetFileLogger(int, const std::string&) {
227+
return true;
228+
}
229+
230+
void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel) {
231+
}

0 commit comments

Comments
 (0)