Skip to content

Commit aec6edc

Browse files
committed
make create_chat_completion stream include all events; filter to parsed chunk deltas only on the chat legacy alias
1 parent 4e0115d commit aec6edc

File tree

6 files changed

+70
-18
lines changed

6 files changed

+70
-18
lines changed

openai/chat_completions.lua

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,7 @@ create_chat_stream_filter = function(chunk_callback)
114114
break
115115
end
116116
accumulation_buffer = rest
117-
do
118-
chunk = parse_completion_chunk(cjson.decode(json_blob))
119-
if chunk then
120-
chunk_callback(chunk)
121-
end
122-
end
117+
chunk_callback(cjson.decode(json_blob))
123118
end
124119
end
125120
return ...
@@ -159,7 +154,7 @@ do
159154
if stream_callback == nil then
160155
stream_callback = nil
161156
end
162-
local status, response = self.client:create_chat_completion(self.messages, {
157+
local status, response = self.client:chat(self.messages, {
163158
function_call = self.opts.function_call,
164159
functions = self.functions,
165160
model = self.opts.model,
@@ -186,7 +181,12 @@ do
186181
assert(type(response) == "string", "Expected string response from streaming output")
187182
local parts = { }
188183
local f = create_chat_stream_filter(function(c)
189-
return table.insert(parts, c.content)
184+
do
185+
local parsed = parse_completion_chunk(c)
186+
if parsed then
187+
return table.insert(parts, parsed.content)
188+
end
189+
end
190190
end)
191191
f(response)
192192
local message = {
@@ -253,5 +253,6 @@ end
253253
return {
254254
ChatSession = ChatSession,
255255
test_message = test_message,
256-
create_chat_stream_filter = create_chat_stream_filter
256+
create_chat_stream_filter = create_chat_stream_filter,
257+
parse_completion_chunk = parse_completion_chunk
257258
}

openai/chat_completions.moon

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ create_chat_stream_filter = (chunk_callback) ->
140140
break
141141

142142
accumulation_buffer = rest
143-
if chunk = parse_completion_chunk cjson.decode json_blob
144-
chunk_callback chunk
143+
chunk_callback cjson.decode json_blob
144+
-- if chunk = parse_completion_chunk cjson.decode json_blob
145+
-- chunk_callback chunk
145146

146147
...
147148

@@ -186,7 +187,7 @@ class ChatSession
186187
-- append_response: should the response be appended to the chat history
187188
-- stream_callback: provide a function to enable streaming output. function will receive each chunk as it's generated
188189
generate_response: (append_response=true, stream_callback=nil) =>
189-
status, response = @client\create_chat_completion @messages, {
190+
status, response = @client\chat @messages, {
190191
function_call: @opts.function_call -- override the default function call behavior
191192
functions: @functions
192193
model: @opts.model
@@ -214,7 +215,8 @@ class ChatSession
214215

215216
parts = {}
216217
f = create_chat_stream_filter (c) ->
217-
table.insert parts, c.content
218+
if parsed = parse_completion_chunk c
219+
table.insert parts, parsed.content
218220

219221
f response
220222
message = {
@@ -250,4 +252,5 @@ class ChatSession
250252
:ChatSession
251253
:test_message
252254
:create_chat_stream_filter
255+
:parse_completion_chunk
253256
}

openai/init.lua

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,26 @@ do
5050
end
5151
return self:_request("POST", "/chat/completions", payload, nil, stream_filter)
5252
end,
53-
chat = function(self, ...)
54-
return self:create_chat_completion(...)
53+
chat = function(self, messages, opts, chunk_callback)
54+
if chunk_callback == nil then
55+
chunk_callback = nil
56+
end
57+
do
58+
local cb = chunk_callback
59+
if cb then
60+
local parse_completion_chunk
61+
parse_completion_chunk = require("openai.chat_completions").parse_completion_chunk
62+
chunk_callback = function(chunk)
63+
do
64+
local delta = parse_completion_chunk(chunk)
65+
if delta then
66+
return cb(delta)
67+
end
68+
end
69+
end
70+
end
71+
end
72+
return self:create_chat_completion(messages, opts, chunk_callback)
5573
end,
5674
completion = function(self, prompt, opts)
5775
local payload = {

openai/init.moon

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,17 @@ class OpenAI
5858
@_request "POST", "/chat/completions", payload, nil, stream_filter
5959

6060
-- legacy alias for create_chat_completion (for backward compatibility)
61-
chat: (...) => @create_chat_completion ...
61+
-- the legacy method filters chunk responses instead of pushing
62+
-- every event through the callback
63+
chat: (messages, opts, chunk_callback=nil) =>
64+
if cb = chunk_callback
65+
import parse_completion_chunk from require "openai.chat_completions"
66+
chunk_callback = (chunk) ->
67+
-- filter chunk to only pass through chat.completion.chunk with parsed delta
68+
if delta = parse_completion_chunk chunk
69+
cb delta
70+
71+
@create_chat_completion messages, opts, chunk_callback
6272

6373
-- call /completions
6474
-- opts: additional parameters as described in https://platform.openai.com/docs/api-reference/completions

openai/responses.moon

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ parse_response_stream_chunk = (chunk) ->
130130
raw: chunk
131131
}
132132

133-
-- takes a stream of string chunks and extracts SSE json objects out of the
134-
-- stream, calling chunk_callback when a parsed object is found
133+
134+
-- creates a ltn12 compatible filter function that will call chunk_callback
135+
-- for each parsed json chunk from the server-sent events api response
135136
create_response_stream_filter = (chunk_callback) ->
136137
assert types.function(chunk_callback), "Must provide chunk_callback function when streaming response"
137138

spec/openai_spec.moon

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,25 @@ describe "OpenAI API Client", ->
402402

403403
assert.same "This is a chat response.", response
404404

405+
it "processes streaming chunks with create_chat_completion (raw)", ->
406+
client = OpenAI "test-api-key"
407+
408+
chunks_received = {}
409+
stream_callback = (chunk) ->
410+
table.insert chunks_received, chunk
411+
412+
status, response = client\create_chat_completion {
413+
{role: "user", content: "tell me a joke"}
414+
}, {stream: true}, stream_callback
415+
416+
assert.same 200, status
417+
-- create_chat_completion passes raw JSON chunks
418+
assert.same {
419+
{object: "chat.completion.chunk", choices: {{delta: {content: "This is "}, index: 0}}}
420+
{object: "chat.completion.chunk", choices: {{delta: {content: "a chat "}, index: 1}}}
421+
{object: "chat.completion.chunk", choices: {{delta: {content: "response."}, index: 2}}}
422+
}, chunks_received
423+
405424
describe "responses", ->
406425
it "creates a response (raw API)", ->
407426
client = OpenAI "test-api-key"

0 commit comments

Comments
 (0)