Skip to content

Commit 504f752

Browse files
committed
move chat session to separate module
1 parent c29caa1 commit 504f752

File tree

4 files changed

+357
-338
lines changed

4 files changed

+357
-338
lines changed

openai/chat_session.lua

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
local cjson = require("cjson")
2+
local types
3+
types = require("tableshape").types
4+
local empty = (types["nil"] + types.literal(cjson.null)):describe("nullable")
5+
local content_format = types.string + types.array_of(types.one_of({
6+
types.shape({
7+
type = "text",
8+
text = types.string
9+
}),
10+
types.shape({
11+
type = "image_url",
12+
image_url = types.string + types.partial({
13+
url = types.string
14+
})
15+
})
16+
}))
17+
local test_message = types.one_of({
18+
types.partial({
19+
role = types.one_of({
20+
"system",
21+
"user",
22+
"assistant"
23+
}),
24+
content = empty + content_format,
25+
name = empty + types.string,
26+
function_call = empty + types.table
27+
}),
28+
types.partial({
29+
role = types.one_of({
30+
"function"
31+
}),
32+
name = types.string,
33+
content = empty + types.string
34+
})
35+
})
36+
local test_function = types.shape({
37+
name = types.string,
38+
description = types["nil"] + types.string,
39+
parameters = types["nil"] + types.table
40+
})
41+
local parse_chat_response = types.partial({
42+
usage = types.table:tag("usage"),
43+
choices = types.partial({
44+
types.partial({
45+
message = types.one_of({
46+
types.partial({
47+
role = "assistant",
48+
content = types.string + empty,
49+
function_call = types.partial({
50+
name = types.string,
51+
arguments = types.string
52+
})
53+
}),
54+
types.partial({
55+
role = "assistant",
56+
content = types.string:tag("response")
57+
})
58+
}):tag("message")
59+
})
60+
})
61+
})
62+
local parse_error_message = types.partial({
63+
error = types.partial({
64+
message = types.string:tag("message"),
65+
code = empty + types.string:tag("code")
66+
})
67+
})
68+
local ChatSession
69+
do
70+
local _class_0
71+
local _base_0 = {
72+
append_message = function(self, m, ...)
73+
assert(test_message(m))
74+
table.insert(self.messages, m)
75+
if select("#", ...) > 0 then
76+
return self:append_message(...)
77+
end
78+
end,
79+
last_message = function(self)
80+
return self.messages[#self.messages]
81+
end,
82+
send = function(self, message, stream_callback)
83+
if stream_callback == nil then
84+
stream_callback = nil
85+
end
86+
if type(message) == "string" then
87+
message = {
88+
role = "user",
89+
content = message
90+
}
91+
end
92+
self:append_message(message)
93+
return self:generate_response(true, stream_callback)
94+
end,
95+
generate_response = function(self, append_response, stream_callback)
96+
if append_response == nil then
97+
append_response = true
98+
end
99+
if stream_callback == nil then
100+
stream_callback = nil
101+
end
102+
local status, response = self.client:chat(self.messages, {
103+
function_call = self.opts.function_call,
104+
functions = self.functions,
105+
model = self.opts.model,
106+
temperature = self.opts.temperature,
107+
stream = stream_callback and true or nil,
108+
response_format = self.opts.response_format
109+
}, stream_callback)
110+
if status ~= 200 then
111+
local err_msg = "Bad status: " .. tostring(status)
112+
do
113+
local err = parse_error_message(response)
114+
if err then
115+
if err.message then
116+
err_msg = err_msg .. ": " .. tostring(err.message)
117+
end
118+
if err.code then
119+
err_msg = err_msg .. " (" .. tostring(err.code) .. ")"
120+
end
121+
end
122+
end
123+
return nil, err_msg, response
124+
end
125+
if stream_callback then
126+
assert(type(response) == "string", "Expected string response from streaming output")
127+
local parts = { }
128+
local f = self.client:create_stream_filter(function(c)
129+
return table.insert(parts, c.content)
130+
end)
131+
f(response)
132+
local message = {
133+
role = "assistant",
134+
content = table.concat(parts)
135+
}
136+
if append_response then
137+
self:append_message(message)
138+
end
139+
return message.content
140+
end
141+
local out, err = parse_chat_response(response)
142+
if not (out) then
143+
err = "Failed to parse response from server: " .. tostring(err)
144+
return nil, err, response
145+
end
146+
if append_response then
147+
self:append_message(out.message)
148+
end
149+
return out.response or out.message
150+
end
151+
}
152+
_base_0.__index = _base_0
153+
_class_0 = setmetatable({
154+
__init = function(self, client, opts)
155+
if opts == nil then
156+
opts = { }
157+
end
158+
self.client, self.opts = client, opts
159+
self.messages = { }
160+
if type(self.opts.messages) == "table" then
161+
self:append_message(unpack(self.opts.messages))
162+
end
163+
if type(self.opts.functions) == "table" then
164+
self.functions = { }
165+
local _list_0 = self.opts.functions
166+
for _index_0 = 1, #_list_0 do
167+
local func = _list_0[_index_0]
168+
assert(test_function(func))
169+
table.insert(self.functions, func)
170+
end
171+
end
172+
end,
173+
__base = _base_0,
174+
__name = "ChatSession"
175+
}, {
176+
__index = _base_0,
177+
__call = function(cls, ...)
178+
local _self_0 = setmetatable({}, _base_0)
179+
cls.__init(_self_0, ...)
180+
return _self_0
181+
end
182+
})
183+
_base_0.__class = _class_0
184+
ChatSession = _class_0
185+
end
186+
return {
187+
ChatSession = ChatSession,
188+
test_message = test_message
189+
}

openai/chat_session.moon

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
cjson = require "cjson"
2+
import types from require "tableshape"
3+
4+
empty = (types.nil + types.literal(cjson.null))\describe "nullable"
5+
6+
content_format = types.string + types.array_of types.one_of {
7+
types.shape { type: "text", text: types.string }
8+
types.shape { type: "image_url", image_url: types.string + types.partial {
9+
url: types.string
10+
}}
11+
}
12+
13+
test_message = types.one_of {
14+
types.partial {
15+
role: types.one_of {"system", "user", "assistant"}
16+
content: empty + content_format -- this can be empty when function_call is set
17+
name: empty + types.string
18+
function_call: empty + types.table
19+
}
20+
21+
-- this message type is for sending a function call response back
22+
types.partial {
23+
role: types.one_of {"function"}
24+
name: types.string
25+
content: empty + types.string
26+
}
27+
}
28+
29+
-- verify the shape of a function declaration
30+
test_function = types.shape {
31+
name: types.string
32+
description: types.nil + types.string
33+
parameters: types.nil + types.table
34+
}
35+
36+
parse_chat_response = types.partial {
37+
usage: types.table\tag "usage"
38+
choices: types.partial {
39+
types.partial {
40+
message: types.one_of({
41+
-- if function call is requested, content is not required so we tag
42+
-- nothing so we can return the whole object
43+
types.partial({
44+
role: "assistant"
45+
content: types.string + empty
46+
function_call: types.partial {
47+
name: types.string
48+
-- API returns arguments a string that should be in json format
49+
arguments: types.string
50+
}
51+
})
52+
53+
types.partial {
54+
role: "assistant"
55+
content: types.string\tag "response"
56+
}
57+
})\tag "message"
58+
}
59+
}
60+
}
61+
62+
parse_error_message = types.partial {
63+
error: types.partial {
64+
message: types.string\tag "message"
65+
code: empty + types.string\tag "code"
66+
}
67+
}
68+
69+
-- handles appending response for each call to chat
70+
-- TODO: hadle appending the streaming response to the output
71+
class ChatSession
72+
new: (@client, @opts={}) =>
73+
@messages = {}
74+
75+
if type(@opts.messages) == "table"
76+
@append_message unpack @opts.messages
77+
78+
if type(@opts.functions) == "table"
79+
@functions = {}
80+
for func in *@opts.functions
81+
assert test_function func
82+
table.insert @functions, func
83+
84+
append_message: (m, ...) =>
85+
assert test_message m
86+
table.insert @messages, m
87+
88+
if select("#", ...) > 0
89+
@append_message ...
90+
91+
last_message: =>
92+
@messages[#@messages]
93+
94+
-- append a message to the history, then trigger a completion with generate_response
95+
-- message: message object to append to history
96+
-- stream_callback: provide a function to enable streaming output. function will receive each chunk as it's generated
97+
send: (message, stream_callback=nil) =>
98+
if type(message) == "string"
99+
message = {role: "user", content: message}
100+
101+
@append_message message
102+
@generate_response true, stream_callback
103+
104+
-- call openai API to generate the next response for the stored chat history
105+
-- returns a string of the response
106+
-- append_response: should the response be appended to the chat history
107+
-- stream_callback: provide a function to enable streaming output. function will receive each chunk as it's generated
108+
generate_response: (append_response=true, stream_callback=nil) =>
109+
status, response = @client\chat @messages, {
110+
function_call: @opts.function_call -- override the default function call behavior
111+
functions: @functions
112+
model: @opts.model
113+
temperature: @opts.temperature
114+
stream: stream_callback and true or nil
115+
response_format: @opts.response_format
116+
}, stream_callback
117+
118+
if status != 200
119+
err_msg = "Bad status: #{status}"
120+
121+
if err = parse_error_message response
122+
if err.message
123+
err_msg ..= ": #{err.message}"
124+
125+
if err.code
126+
err_msg ..= " (#{err.code})"
127+
128+
return nil, err_msg, response
129+
130+
-- if we are streaming we need to pase the entire fragmented response
131+
if stream_callback
132+
assert type(response) == "string",
133+
"Expected string response from streaming output"
134+
135+
parts = {}
136+
f = @client\create_stream_filter (c) ->
137+
table.insert parts, c.content
138+
139+
f response
140+
message = {
141+
role: "assistant"
142+
content: table.concat parts
143+
}
144+
145+
if append_response
146+
@append_message message
147+
148+
return message.content
149+
150+
out, err = parse_chat_response response
151+
152+
unless out
153+
err = "Failed to parse response from server: #{err}"
154+
return nil, err, response
155+
156+
if append_response
157+
@append_message out.message
158+
159+
-- response is missing for function_calls, so we return the entire message object
160+
out.response or out.message
161+
162+
{:ChatSession, :test_message}

0 commit comments

Comments
 (0)