Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/services/llm/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages [
bufReader := bufio.NewReaderSize(body, 1024)

var currentToolCalls []ToolCall
var finishReason string
var finishReason FinishReason
var lines int

for {
Expand Down Expand Up @@ -198,7 +198,7 @@ func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages [
} `json:"function"`
} `json:"tool_calls"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
FinishReason FinishReason `json:"finish_reason"`
} `json:"choices"`
}

Expand Down
13 changes: 12 additions & 1 deletion pkg/services/llm/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ import (
"strings"
)

type FinishReason string

const (
FinishReasonStop FinishReason = "stop"
FinishReasonLength FinishReason = "length"
FinishReasonFunctionCall FinishReason = "function_call"
FinishReasonToolCalls FinishReason = "tool_calls"
FinishReasonContentFilter FinishReason = "content_filter"
FinishReasonNull FinishReason = "null"
)

// Message 表示聊天消息
type Message struct {
Role string `json:"role"`
Expand Down Expand Up @@ -77,7 +88,7 @@ type StreamResult struct {
Delta string
ToolCalls []ToolCall
Done bool `json:",omitempty"`
FinishReason string
FinishReason FinishReason
Error error `json:",omitempty"`
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/settings/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type Config struct {

type Provider struct {
APIKey string `envconfig:"Api_Key" required:"true"`
URL string `envconfig:"url"`
URL string `envconfig:"url" required:"true"`
Model string `envconfig:"MODEL" required:"true"`
Type string `envconfig:"type" default:"openai" desc:"provider type: openai, anthropic, openrouter, ollama"`
}
Expand Down
61 changes: 31 additions & 30 deletions pkg/web/api/handle_convo.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func init() {

// chatRequest 内部聊天请求结构
type chatRequest struct {
messages []llm.Message
tools []llm.ToolDefinition
isSSE bool
cs stores.Conversation
hi *aigc.HistoryItem
messages []llm.Message
tools []llm.ToolDefinition
isSSE bool
cs stores.Conversation
hi *aigc.HistoryItem
chunkIdx int // 全局 chunk 计数器,用于 SSE 事件序号
}

// convertMCPToolsToLLMTools 将 MCP 工具描述转换为 LLM 工具定义
Expand Down Expand Up @@ -235,7 +236,7 @@ func (a *api) postChat(w http.ResponseWriter, r *http.Request) {
logger().Infow("chat", "answer", answer)

var cm ChatMessage
cm.ID = ccr.cs.GetID()
cm.ID = ccr.cs.GetID() // TODO: deprecated by new message id
cm.Text = answer
render.JSON(w, r, &cm)
}
Expand Down Expand Up @@ -267,6 +268,21 @@ func writeEvent(w io.Writer, id string, m any) bool {

// chatStreamResponseLoop 循环处理流式响应,支持工具调用循环
func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r *http.Request) (res ChatResponse) {
// 预先设置 HTTP 头信息(只设置一次)
if _, ok := w.(http.Flusher); !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return ChatResponse{}
}

w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
if ccr.isSSE {
w.Header().Set("Content-Type", "text/event-stream")
} else {
w.Header().Add("Content-type", "application/octet-stream")
}
w.Header().Add("Conversation-ID", ccr.cs.GetID())

for {
// 调用流式响应处理
streamRes := a.doChatStream(ccr, w, r)
Expand Down Expand Up @@ -301,21 +317,6 @@ func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r

// doChatStream 执行一次流式调用,返回累积的 answer 和 toolCalls
func (a *api) doChatStream(ccr *chatRequest, w http.ResponseWriter, r *http.Request) ChatResponse {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return ChatResponse{}
}

w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
if ccr.isSSE {
w.Header().Set("Content-Type", "text/event-stream")
} else {
w.Header().Add("Content-type", "application/octet-stream")
}
w.Header().Add("Conversation-ID", ccr.cs.GetID())

stream, err := a.llm.StreamChat(r.Context(), ccr.messages, ccr.tools)
if err != nil {
logger().Infow("call chat stream fail", "err", err)
Expand All @@ -325,15 +326,13 @@ func (a *api) doChatStream(ccr *chatRequest, w http.ResponseWriter, r *http.Requ

var cm ChatMessage
if !ccr.isSSE {
cm.ConversationID = cm.ID
cm.ConversationID = ccr.cs.GetID()
}

var chunkIdx int
var res ChatResponse
var lastWriteEmpty bool // 标记上一次是否写入了空消息

for result := range stream {
chunkIdx++
if result.Error != nil {
logger().Infow("stream error", "err", result.Error)
break
Expand All @@ -343,35 +342,38 @@ func (a *api) doChatStream(ccr *chatRequest, w http.ResponseWriter, r *http.Requ

cm.Delta = result.Delta
res.answer += result.Delta
if len(result.ToolCalls) > 0 {
if len(result.ToolCalls) > 0 && result.FinishReason == llm.FinishReasonToolCalls {
cm.ToolCalls = convertToolCallsForJSON(result.ToolCalls)
}

if ccr.isSSE {
if result.Done {
ccr.chunkIdx++
cm.ConversationID = ccr.cs.GetID()
cm.FinishReason = result.FinishReason
_ = writeEvent(w, strconv.Itoa(chunkIdx), &cm)
cm.FinishReason = string(result.FinishReason)
_ = writeEvent(w, strconv.Itoa(ccr.chunkIdx), &cm)
} else {
// 判断当前是否为空消息
isEmpty := result.Delta == "" && len(cm.ToolCalls) == 0
if !isEmpty || !lastWriteEmpty {
// 有内容,或者上一次不是空的,则输出
if wrote = writeEvent(w, strconv.Itoa(chunkIdx), &cm); !wrote {
ccr.chunkIdx++
if wrote = writeEvent(w, strconv.Itoa(ccr.chunkIdx), &cm); !wrote {
break
}
lastWriteEmpty = isEmpty
}
// 如果当前是空的且上一次也是空的,跳过(连续空消息只保留第一个)
}
} else {
ccr.chunkIdx++
cm.Text += result.Delta
if err = json.NewEncoder(w).Encode(&cm); err != nil {
logger().Infow("json encode fail", "err", err)
break
}
}
flusher.Flush()
w.(http.Flusher).Flush()

if result.Done {
res.toolCalls = result.ToolCalls
Expand Down Expand Up @@ -543,7 +545,6 @@ func convertToolCallsForJSON(tcs []llm.ToolCall) []map[string]any {
return result
}


// chatExecutor 定义聊天执行函数类型,支持流式/非流式
type chatExecutor func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error)

Expand Down
Loading