
Commit 7cf20dc

Michal Kulakowski authored and committed
Responses api init
1 parent d1de370 commit 7cf20dc

9 files changed

Lines changed: 1588 additions & 9 deletions


src/http_rest_api_handler.cpp

Lines changed: 1 addition & 1 deletion

@@ -531,7 +531,7 @@ static Status createV3HttpPayload(
         return Status(StatusCode::JSON_INVALID, "model field is not a string");
     }

-    bool isTextGenerationEndpoint = uri.find("completions") != std::string_view::npos;
+    bool isTextGenerationEndpoint = (uri.find("completions") != std::string_view::npos) || (uri.find("responses") != std::string_view::npos);
     if (isTextGenerationEndpoint) {
         auto streamIt = parsedJson->FindMember("stream");
         if (streamIt != parsedJson->MemberEnd()) {
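A minimal, self-contained sketch of the substring check above; the helper name isTextGenerationEndpoint matches the local variable in the diff, while the wrapper function and the test URIs are illustrative. Note that the match is deliberately loose, so both /v3/responses and /v3/v1/responses qualify, as do /v3/completions and /v3/chat/completions.

    #include <cassert>
    #include <string_view>

    // Substring match on the request URI, mirroring the predicate in the diff.
    static bool isTextGenerationEndpoint(std::string_view uri) {
        return uri.find("completions") != std::string_view::npos ||
               uri.find("responses") != std::string_view::npos;
    }

    int main() {
        assert(isTextGenerationEndpoint("/v3/chat/completions"));
        assert(isTextGenerationEndpoint("/v3/v1/responses"));
        assert(!isTextGenerationEndpoint("/v3/tokenize"));
        return 0;
    }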

src/llm/apis/openai_completions.cpp

Lines changed: 790 additions & 0 deletions
Large diffs are not rendered by default.

src/llm/apis/openai_completions.hpp

Lines changed: 5 additions & 0 deletions

@@ -47,6 +47,7 @@ namespace ovms {
 enum class Endpoint {
     CHAT_COMPLETIONS,
     COMPLETIONS,
+    RESPONSES,
     TOKENIZE,
 };

@@ -69,12 +70,16 @@ class OpenAIChatCompletionsHandler {
     std::chrono::time_point<std::chrono::system_clock> created;
     ov::genai::Tokenizer tokenizer;
     size_t processedTokens = 0;  // tracks overall number of tokens processed by the pipeline
+    size_t responsesStreamingSequenceNumber = 0;
+    bool responsesStreamingInitialized = false;
+    std::string responsesStreamingOutputText;

     // Output parser is used to parse chat completions response to extract specific fields like tool calls and reasoning.
     std::unique_ptr<OutputParser> outputParser = nullptr;

     absl::Status parseCompletionsPart();
     absl::Status parseChatCompletionsPart(std::optional<uint32_t> maxTokensLimit, std::optional<std::string> allowedLocalMediaPath, std::optional<std::vector<std::string>> allowedMediaDomains);
+    absl::Status parseResponsesPart(std::optional<uint32_t> maxTokensLimit, std::optional<std::string> allowedLocalMediaPath, std::optional<std::vector<std::string>> allowedMediaDomains);
     absl::Status parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, std::optional<uint32_t> maxModelLength);

     ParsedOutput parseOutputIfNeeded(const std::vector<int64_t>& generatedIds);
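The three new members look like per-request streaming state for the Responses API: an event sequence counter, a one-shot init flag, and the accumulated output text. Below is a hedged sketch of how they might interact; the actual logic lives in openai_completions.cpp, whose diff is not rendered above, so treat this as an assumption about intent rather than the real implementation.

    #include <cstddef>
    #include <iostream>
    #include <string>

    struct ResponsesStreamState {
        std::size_t sequenceNumber = 0;  // responsesStreamingSequenceNumber
        bool initialized = false;        // responsesStreamingInitialized
        std::string outputText;          // responsesStreamingOutputText

        // Each streamed delta bumps the sequence number and extends the
        // accumulated text; the first call flips the init flag, presumably so
        // the serializer can emit its one-time "created"-style event.
        void onDelta(const std::string& delta) {
            if (!initialized) {
                initialized = true;
            }
            ++sequenceNumber;
            outputText += delta;
        }
    };

    int main() {
        ResponsesStreamState state;
        state.onDelta("Hello");
        state.onDelta(", world");
        std::cout << state.sequenceNumber << " events, text: "
                  << state.outputText << "\n";  // 2 events, text: Hello, world
        return 0;
    }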

src/llm/language_model/continuous_batching/servable.cpp

Lines changed: 14 additions & 1 deletion

@@ -103,6 +103,15 @@ static ov::genai::GenerationOutput prepareEmptyStopReasonOutput() {
     return out;
 }

+static ov::genai::GenerationOutput prepareEmptyNoneReasonOutput() {
+    static ov::genai::GenerationOutput out = {
+        std::vector<int64_t>(),  // generated_ids
+        std::vector<float>(),    // generated_log_probs
+        0.0f,                    // score
+        ov::genai::GenerationFinishReason::NONE};
+    return out;
+}
+
 absl::Status ContinuousBatchingServable::readCompleteExecutionResults(std::shared_ptr<GenAiServableExecutionContext>& executionContext) {
     auto cbExecutionContext = std::static_pointer_cast<ContinuousBatchingServableExecutionContext>(executionContext);
     if (cbExecutionContext->payload.client->isDisconnected()) {

@@ -136,7 +145,11 @@ absl::Status ContinuousBatchingServable::readPartialExecutionResults(std::shared
     ov::genai::GenerationOutputs generationOutputs = cbExecutionContext->generationHandle->read();
     RET_CHECK(generationOutputs.size() <= 1);  // TODO: Support multiple generations
     if (generationOutputs.size() == 0) {
-        cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()};
+        if (cbExecutionContext->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING) {
+            cbExecutionContext->generationOutputs = {prepareEmptyNoneReasonOutput()};
+        } else {
+            cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()};
+        }
     } else {
         cbExecutionContext->generationOutputs = {generationOutputs.begin()->second};
     }
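The new branch distinguishes "no tokens ready yet" from "generation finished" when read() returns nothing. A minimal sketch of that status-dependent choice, with stand-in enums replacing the ov::genai types so the snippet compiles on its own:

    #include <cassert>

    enum class GenerationStatus { RUNNING, FINISHED };
    enum class FinishReason { NONE, STOP };

    // While the handle is still RUNNING, an empty read means "no tokens yet",
    // so the placeholder output carries NONE; once it is no longer running,
    // the placeholder carries STOP so the stream can terminate cleanly.
    static FinishReason emptyOutputReason(GenerationStatus status) {
        return status == GenerationStatus::RUNNING ? FinishReason::NONE
                                                   : FinishReason::STOP;
    }

    int main() {
        assert(emptyOutputReason(GenerationStatus::RUNNING) == FinishReason::NONE);
        assert(emptyOutputReason(GenerationStatus::FINISHED) == FinishReason::STOP);
        return 0;
    }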

src/llm/servable.cpp

Lines changed: 53 additions & 3 deletions

@@ -68,10 +68,12 @@ absl::Status GenAiServable::loadRequest(std::shared_ptr<GenAiServableExecutionCo
         executionContext->endpoint = Endpoint::CHAT_COMPLETIONS;
     } else if (payload.uri == "/v3/completions" || payload.uri == "/v3/v1/completions") {
         executionContext->endpoint = Endpoint::COMPLETIONS;
+    } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") {
+        executionContext->endpoint = Endpoint::RESPONSES;
     } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) {
         executionContext->endpoint = Endpoint::TOKENIZE;
     } else {
-        return absl::InvalidArgumentError("Wrong endpoint. Allowed endpoints: /v3/chat/completions, /v3/completions");
+        return absl::InvalidArgumentError("Wrong endpoint. Allowed endpoints: /v3/chat/completions, /v3/completions, /v3/responses, /v3/tokenize");
     }
     executionContext->payload = payload;
     return absl::OkStatus();

@@ -204,6 +206,50 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptr<GenAiServableExecution
             }
             break;
         }
+        case Endpoint::RESPONSES: {
+            if (executionContext->apiHandler->getChatHistory().size() > 0) {
+#if (PYTHON_DISABLE == 0)
+                bool success;
+                if (executionContext->apiHandler->getProcessedJson().size() > 0) {
+                    success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->apiHandler->getProcessedJson(), inputText);
+                } else {
+                    success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->payload.body, inputText);
+                }
+                if (!success) {
+                    return absl::Status(absl::StatusCode::kInvalidArgument, inputText);
+                }
+#else
+                ov::genai::ChatHistory& chatHistory = executionContext->apiHandler->getChatHistory();
+                constexpr bool add_generation_prompt = true;
+                auto toolsStatus = executionContext->apiHandler->parseToolsToJsonContainer();
+                if (!toolsStatus.ok()) {
+                    return toolsStatus.status();
+                }
+                const auto& tools = toolsStatus.value();
+                auto chatTemplateKwargsStatus = executionContext->apiHandler->parseChatTemplateKwargsToJsonContainer();
+                if (!chatTemplateKwargsStatus.ok()) {
+                    return chatTemplateKwargsStatus.status();
+                }
+                const auto& chatTemplateKwargs = chatTemplateKwargsStatus.value();
+                try {
+                    inputText = getProperties()->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools, chatTemplateKwargs);
+                } catch (const std::exception& e) {
+                    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what());
+                    return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one.");
+                }
+#endif
+                if (inputText.size() == 0) {
+                    return absl::Status(absl::StatusCode::kInvalidArgument, "Final prompt after applying chat template is empty");
+                }
+            } else {
+                auto prompt = executionContext->apiHandler->getPrompt();
+                if (!prompt.has_value()) {
+                    return absl::Status(absl::StatusCode::kInvalidArgument, "input is missing");
+                }
+                inputText = prompt.value();
+            }
+            break;
+        }
         case Endpoint::COMPLETIONS: {
             inputText = executionContext->apiHandler->getPrompt().value();
             break;

@@ -277,8 +323,12 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptr<GenAiServable
     if (!serializedChunk.empty()) {
         executionContext->response = wrapTextInServerSideEventMessage(serializedChunk);
     }
-    if (executionContext->apiHandler->getStreamOptions().includeUsage)
-        executionContext->response += wrapTextInServerSideEventMessage(executionContext->apiHandler->serializeStreamingUsageChunk());
+    if (executionContext->apiHandler->getStreamOptions().includeUsage) {
+        std::string usageChunk = executionContext->apiHandler->serializeStreamingUsageChunk();
+        if (!usageChunk.empty()) {
+            executionContext->response += wrapTextInServerSideEventMessage(usageChunk);
+        }
+    }

     executionContext->response += wrapTextInServerSideEventMessage("[DONE]");
src/llm/visual_language_model/continuous_batching/servable.cpp

Lines changed: 4 additions & 2 deletions

@@ -45,10 +45,12 @@ absl::Status VisualLanguageModelServable::loadRequest(std::shared_ptr<GenAiServa
     }
     if (payload.uri == "/v3/chat/completions" || payload.uri == "/v3/v1/chat/completions") {
         executionContext->endpoint = Endpoint::CHAT_COMPLETIONS;
+    } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") {
+        executionContext->endpoint = Endpoint::RESPONSES;
     } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) {
         executionContext->endpoint = Endpoint::TOKENIZE;
     } else {
-        return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions endpoint or /v3/tokenize");
+        return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions, /v3/responses endpoint or /v3/tokenize");
     }
     executionContext->payload = payload;
     return absl::OkStatus();

@@ -67,7 +69,7 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptr<GenAiSer
     if (vlmExecutionContext->apiHandler == nullptr) {
         return absl::Status(absl::StatusCode::kInvalidArgument, "API handler is not initialized");
     }
-    if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS) {
+    if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS || executionContext->endpoint == Endpoint::RESPONSES) {
         ov::genai::ChatHistory& chatHistory = vlmExecutionContext->apiHandler->getChatHistory();

         for (size_t i = 0; i < chatHistory.size(); i++) {
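Both VLM servables now route RESPONSES through the same branch as CHAT_COMPLETIONS, since both endpoints build the prompt from chat history. A tiny sketch of that union check; the helper name usesChatHistoryPath is illustrative, as the commit inlines the comparison:

    #include <cassert>

    enum class Endpoint { CHAT_COMPLETIONS, COMPLETIONS, RESPONSES, TOKENIZE };

    // RESPONSES reuses the chat-completions input path (chat history plus
    // chat template), so both endpoints fall into the same branch.
    static bool usesChatHistoryPath(Endpoint e) {
        return e == Endpoint::CHAT_COMPLETIONS || e == Endpoint::RESPONSES;
    }

    int main() {
        assert(usesChatHistoryPath(Endpoint::RESPONSES));
        assert(!usesChatHistoryPath(Endpoint::COMPLETIONS));
        return 0;
    }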

src/llm/visual_language_model/legacy/servable.cpp

Lines changed: 4 additions & 2 deletions

@@ -53,10 +53,12 @@ absl::Status VisualLanguageModelLegacyServable::loadRequest(std::shared_ptr<GenA
     }
     if (payload.uri == "/v3/chat/completions" || payload.uri == "/v3/v1/chat/completions") {
         executionContext->endpoint = Endpoint::CHAT_COMPLETIONS;
+    } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") {
+        executionContext->endpoint = Endpoint::RESPONSES;
     } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) {
         executionContext->endpoint = Endpoint::TOKENIZE;
     } else {
-        return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions endpoint or /v3/tokenize");
+        return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions, /v3/responses endpoint or /v3/tokenize");
     }
     executionContext->payload = payload;
     return absl::OkStatus();

@@ -237,7 +239,7 @@ absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptr<Ge
     if (vlmExecutionContext->apiHandler == nullptr) {
         return absl::Status(absl::StatusCode::kInvalidArgument, "API handler is not initialized");
     }
-    if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS) {
+    if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS || executionContext->endpoint == Endpoint::RESPONSES) {
         ov::genai::ChatHistory& chatHistory = vlmExecutionContext->apiHandler->getChatHistory();

         for (size_t i = 0; i < chatHistory.size(); i++) {
