Skip to content

Commit 7690d40

Browse files
committed
migrate the chat completions api details into separate module
1 parent a66eed6 commit 7690d40

File tree

5 files changed

+170
-151
lines changed

5 files changed

+170
-151
lines changed

lua-openai-dev-1.rockspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ build = {
2424
type = "builtin",
2525
modules = {
2626
["openai"] = "openai/init.lua",
27-
["openai.chat_session"] = "openai/chat_session.lua",
27+
["openai.chat_completions"] = "openai/chat_completions.lua",
2828
["openai.responses"] = "openai/responses.lua"
2929
}
3030
}
Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,65 @@ local parse_error_message = types.partial({
6565
code = empty + types.string:tag("code")
6666
})
6767
})
68+
-- Strict tableshape validator for one SSE chunk from the chat completions
-- API. Only plain "append" chunks are accepted: a delta carrying string
-- content plus the index of the choice it belongs to; anything else
-- (e.g. finish/role-only chunks) fails validation and is skipped upstream.
local parse_completion_chunk = types.partial({
  object = "chat.completion.chunk",
  choices = types.shape({
    types.partial({
      index = types.number:tag("index"),
      delta = types.partial({
        content = types.string:tag("content")
      })
    })
  })
})
79+
-- lpeg pattern that reads one JSON data block off the front of a buffer:
-- skips leading whitespace, requires the SSE "data: " prefix, then captures
-- the JSON text and the unconsumed remainder of the string.
local consume_json_head
do
  local lpeg = require("lpeg")
  local C, S, P = lpeg.C, lpeg.S, lpeg.P

  -- Match function: grow a candidate prefix one byte at a time until it
  -- decodes as a complete JSON document, then succeed at the position just
  -- past it. Fails (returns nil) when no complete document is present yet.
  local consume_json = P(function(subject, start)
    for stop = start + 1, #subject do
      local ok, value = pcall(cjson.decode, subject:sub(start, stop))
      if ok and value then
        return stop + 1
      end
    end
    return nil
  end)

  consume_json_head = S("\t\n\r ") ^ 0 * P("data: ") * C(consume_json) * C(P(1) ^ 0)
end
102+
-- Builds an ltn12-compatible filter for a streaming chat completions
-- response: it accumulates incoming string fragments, drains every complete
-- "data: {json}" server-sent-events record from the buffer, and invokes
-- chunk_callback for each record that validates as a completion chunk.
-- The filter passes its input through unchanged so it can sit in a
-- sink/filter chain.
local create_chat_stream_filter
create_chat_stream_filter = function(chunk_callback)
  assert(types["function"](chunk_callback), "Must provide chunk_callback function when streaming response")
  local buffer = ""
  return function(...)
    local incoming = ...
    if type(incoming) == "string" then
      buffer = buffer .. incoming
      -- Drain all complete records currently available; partial records
      -- stay buffered until more data arrives.
      local json_blob, rest = consume_json_head:match(buffer)
      while json_blob do
        buffer = rest
        local parsed = parse_completion_chunk(cjson.decode(json_blob))
        if parsed then
          chunk_callback(parsed)
        end
        json_blob, rest = consume_json_head:match(buffer)
      end
    end
    return ...
  end
end
68127
local ChatSession
69128
do
70129
local _class_0
@@ -125,7 +184,7 @@ do
125184
if stream_callback then
126185
assert(type(response) == "string", "Expected string response from streaming output")
127186
local parts = { }
128-
local f = self.client:create_stream_filter(function(c)
187+
local f = create_chat_stream_filter(function(c)
129188
return table.insert(parts, c.content)
130189
end)
131190
f(response)
@@ -185,5 +244,11 @@ do
185244
end
186245
return {
187246
ChatSession = ChatSession,
188-
test_message = test_message
247+
test_message = test_message,
248+
test_function = test_function,
249+
parse_chat_response = parse_chat_response,
250+
parse_error_message = parse_error_message,
251+
parse_completion_chunk = parse_completion_chunk,
252+
consume_json_head = consume_json_head,
253+
create_chat_stream_filter = create_chat_stream_filter
189254
}
Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
-- This is the legacy API https://platform.openai.com/docs/api-reference/chat
2+
13
cjson = require "cjson"
24
import types from require "tableshape"
35

@@ -66,6 +68,82 @@ parse_error_message = types.partial {
6668
}
6769
}
6870

71+
-- sse streaming chunk format from chat completions API
72+
-- {
73+
-- "id": "chatcmpl-XXX",
74+
-- "object": "chat.completion.chunk",
75+
-- "created": 1682979397,
76+
-- "model": "gpt-3.5-turbo-0301",
77+
-- "choices": [
78+
-- {
79+
-- "delta": {
80+
-- "content": " hello"
81+
-- },
82+
-- "index": 0,
83+
-- "finish_reason": null
84+
-- }
85+
-- ]
86+
-- }
87+
88+
-- strictly validates a single "append" chunk: the choice index plus a delta
-- carrying string content; other chunk shapes fail and are skipped upstream
parse_completion_chunk = types.partial {
  object: "chat.completion.chunk"
  choices: types.shape {
    types.partial {
      index: types.number\tag "index"
      delta: types.partial {
        content: types.string\tag "content"
      }
    }
  }
}
100+
101+
-- lpeg pattern to read a json data block from the front of a string, returns
-- the json blob and the rest of the string if it could parse one
consume_json_head = do
  import C, S, P from require "lpeg"

  -- match function: grow a candidate prefix one byte at a time until it
  -- decodes as a complete JSON document, then succeed just past it
  consume_json = P (subject, start) ->
    for stop = start + 1, #subject
      ok, value = pcall cjson.decode, subject\sub(start, stop)
      if ok and value
        return stop + 1

    return nil -- fail

  S("\t\n\r ")^0 * P("data: ") * C(consume_json) * C(P(1)^0)
120+
121+
122+
-- creates a ltn12 compatible filter function that will call chunk_callback
-- for each parsed json chunk from the server-sent events api response;
-- input is passed through unchanged so the filter can sit in a sink chain
create_chat_stream_filter = (chunk_callback) ->
  assert types.function(chunk_callback), "Must provide chunk_callback function when streaming response"

  buffer = ""

  (...) ->
    incoming = ...

    if type(incoming) == "string"
      buffer ..= incoming

      -- drain every complete "data: {json}" record currently buffered;
      -- partial records wait for more data
      while true
        json_blob, rest = consume_json_head\match buffer
        break unless json_blob

        buffer = rest
        if parsed = parse_completion_chunk cjson.decode json_blob
          chunk_callback parsed

    ...
145+
146+
69147
-- handles appending response for each call to chat
70148
-- TODO: handle appending the streaming response to the output
71149
class ChatSession
@@ -133,7 +211,7 @@ class ChatSession
133211
"Expected string response from streaming output"
134212

135213
parts = {}
136-
f = @client\create_stream_filter (c) ->
214+
f = create_chat_stream_filter (c) ->
137215
table.insert parts, c.content
138216

139217
f response
@@ -159,4 +237,13 @@ class ChatSession
159237
-- response is missing for function_calls, so we return the entire message object
160238
out.response or out.message
161239

162-
{:ChatSession, :test_message}
240+
{
241+
:ChatSession
242+
:test_message
243+
:test_function
244+
:parse_chat_response
245+
:parse_error_message
246+
:parse_completion_chunk
247+
:consume_json_head
248+
:create_chat_stream_filter
249+
}

openai/init.lua

Lines changed: 8 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,88 +4,32 @@ local cjson = require("cjson")
44
local unpack = table.unpack or unpack
55
local types
66
types = require("tableshape").types
7-
local ChatSession, test_message
8-
do
9-
local _obj_0 = require("openai.chat_session")
10-
ChatSession, test_message = _obj_0.ChatSession, _obj_0.test_message
11-
end
127
local parse_url = require("socket.url").parse
13-
local parse_completion_chunk = types.partial({
14-
object = "chat.completion.chunk",
15-
choices = types.shape({
16-
types.partial({
17-
delta = types.partial({
18-
["content"] = types.string:tag("content")
19-
}),
20-
index = types.number:tag("index")
21-
})
22-
})
23-
})
24-
local consume_json_head
25-
do
26-
local C, S, P
27-
do
28-
local _obj_0 = require("lpeg")
29-
C, S, P = _obj_0.C, _obj_0.S, _obj_0.P
30-
end
31-
local consume_json = P(function(str, pos)
32-
local str_len = #str
33-
for k = pos + 1, str_len do
34-
local candidate = str:sub(pos, k)
35-
local parsed = false
36-
pcall(function()
37-
parsed = cjson.decode(candidate)
38-
end)
39-
if parsed then
40-
return k + 1
41-
end
42-
end
43-
return nil
44-
end)
45-
consume_json_head = S("\t\n\r ") ^ 0 * P("data: ") * C(consume_json) * C(P(1) ^ 0)
46-
end
478
local OpenAI
489
do
4910
local _class_0
5011
local _base_0 = {
5112
api_base = "https://api.openai.com/v1",
5213
default_model = "gpt-4.1",
5314
new_chat_session = function(self, ...)
15+
local ChatSession
16+
ChatSession = require("openai.chat_completions").ChatSession
5417
return ChatSession(self, ...)
5518
end,
5619
new_response_chat_session = function(self, ...)
5720
local ResponsesChatSession
5821
ResponsesChatSession = require("openai.responses").ResponsesChatSession
5922
return ResponsesChatSession(self, ...)
6023
end,
61-
create_stream_filter = function(self, chunk_callback)
62-
assert(types["function"](chunk_callback), "Must provide chunk_callback function when streaming response")
63-
local accumulation_buffer = ""
64-
return function(...)
65-
local chunk = ...
66-
if type(chunk) == "string" then
67-
accumulation_buffer = accumulation_buffer .. chunk
68-
while true do
69-
local json_blob, rest = consume_json_head:match(accumulation_buffer)
70-
if not (json_blob) then
71-
break
72-
end
73-
accumulation_buffer = rest
74-
do
75-
chunk = parse_completion_chunk(cjson.decode(json_blob))
76-
if chunk then
77-
chunk_callback(chunk)
78-
end
79-
end
80-
end
81-
end
82-
return ...
83-
end
84-
end,
8524
chat = function(self, messages, opts, chunk_callback)
8625
if chunk_callback == nil then
8726
chunk_callback = nil
8827
end
28+
local test_message, create_chat_stream_filter
29+
do
30+
local _obj_0 = require("openai.chat_completions")
31+
test_message, create_chat_stream_filter = _obj_0.test_message, _obj_0.create_chat_stream_filter
32+
end
8933
local test_messages = types.array_of(test_message)
9034
assert(test_messages(messages))
9135
local payload = {
@@ -99,7 +43,7 @@ do
9943
end
10044
local stream_filter
10145
if payload.stream then
102-
stream_filter = self:create_stream_filter(chunk_callback)
46+
stream_filter = create_chat_stream_filter(chunk_callback)
10347
end
10448
return self:_request("POST", "/chat/completions", payload, nil, stream_filter)
10549
end,
@@ -294,7 +238,6 @@ do
294238
end
295239
return {
296240
OpenAI = OpenAI,
297-
ChatSession = ChatSession,
298241
VERSION = VERSION,
299242
new = OpenAI
300243
}

0 commit comments

Comments
 (0)