From 46e46357f28339e708eba141340c9928618fa3a1 Mon Sep 17 00:00:00 2001 From: liut Date: Wed, 11 Mar 2026 15:25:31 +0800 Subject: [PATCH] feat: add conversation summary and title generation - Add GetSummary function for generating chat title from history - Add HistoryItems.ToText() method for converting history to text - Add PATCH /conversation/{csid}/title API endpoint - Implement Conversation.Save() and CountHistory() methods - Add unit tests for ToText and CountHistory - Add /api/summary endpoint for text summarization - Refine rate limiter key generation with IP + path - Adjust default AskRateLimit to 20-H --- docs/convo.yaml | 13 ++- pkg/models/aigc/history.go | 20 +++++ pkg/models/aigc/history_test.go | 67 ++++++++++++++ pkg/models/convo/convo_gen.go | 24 ++++- pkg/services/llm/types.go | 10 +++ pkg/services/stores/cob_x.go | 23 +---- pkg/services/stores/conversation.go | 32 ++++++- pkg/services/stores/conversation_test.go | 40 +++++++++ pkg/services/stores/convo_x.go | 17 ++++ pkg/services/stores/llm.go | 42 +++++++++ pkg/services/stores/wrap.go | 2 + pkg/settings/config.go | 2 +- pkg/web/api/api.go | 12 ++- pkg/web/api/handle_convo.go | 108 +++++++++++++++++++++-- 14 files changed, 376 insertions(+), 36 deletions(-) diff --git a/docs/convo.yaml b/docs/convo.yaml index 18509d9..3ee70eb 100644 --- a/docs/convo.yaml +++ b/docs/convo.yaml @@ -32,15 +32,26 @@ models: - comment: 标题 name: Title type: string - tags: {bson: 'title', json: 'title', pg: ',notnull', binding: 'required'} + tags: {bson: 'title', json: 'title', pg: ',notnull'} isset: true query: 'match' + - comment: 消息数 + name: MessageCount + type: int + tags: {bson: 'msgCount', json: 'msgCount', pg: 'msg_count,notnull,type:smallint'} + isset: true - comment: '状态' name: Status type: SessionStatus tags: {bson: 'status', json: 'status', pg: ',notnull,type:smallint'} isset: true query: 'equal' + - comment: 工具 + name: Tools + type: '[]string' + tags: {bson: 'tools', json: 'tools', pg: ",notnull,default:'[]'"} + isset: true + - type: comm.MetaField - type: comm.OwnerField oidcat: event diff --git a/pkg/models/aigc/history.go b/pkg/models/aigc/history.go index 87ef21e..9a53a5d 100644 --- a/pkg/models/aigc/history.go +++ b/pkg/models/aigc/history.go @@ -29,6 +29,26 @@ func (z *HistoryItem) calcTokens() (c int) { type HistoryItems []HistoryItem +// ToText 将历史记录转换为纯文本格式 +func (z HistoryItems) ToText() string { + var sb strings.Builder + for _, item := range z { + if item.ChatItem != nil { + if item.ChatItem.User != "" { + sb.WriteString("用户: ") + sb.WriteString(item.ChatItem.User) + sb.WriteString("\n") + } + if item.ChatItem.Assistant != "" { + sb.WriteString("助手: ") + sb.WriteString(item.ChatItem.Assistant) + sb.WriteString("\n") + } + } + } + return sb.String() +} + // MarshalBinary implements the encoding.BinaryMarshaler interface. func (z *HistoryItem) MarshalBinary() (data []byte, err error) { data, err = json.Marshal(z) diff --git a/pkg/models/aigc/history_test.go b/pkg/models/aigc/history_test.go index 13df757..cf3217b 100644 --- a/pkg/models/aigc/history_test.go +++ b/pkg/models/aigc/history_test.go @@ -334,3 +334,70 @@ func TestHiAscend(t *testing.T) { t.Errorf("Len got %d, want 3", ha.Len()) } } + +func TestHistoryItems_ToText(t *testing.T) { + tests := []struct { + name string + items HistoryItems + expected string + }{ + { + name: "empty", + items: HistoryItems{}, + expected: "", + }, + { + name: "only user message", + items: HistoryItems{ + {ChatItem: &HistoryChatItem{User: "Hello"}}, + }, + expected: "用户: Hello\n", + }, + { + name: "only assistant message", + items: HistoryItems{ + {ChatItem: &HistoryChatItem{Assistant: "Hi there"}}, + }, + expected: "助手: Hi there\n", + }, + { + name: "user and assistant", + items: HistoryItems{ + {ChatItem: &HistoryChatItem{User: "Hello", Assistant: "Hi there"}}, + }, + expected: "用户: Hello\n助手: Hi there\n", + }, + { + name: "multiple messages", + items: HistoryItems{ + {ChatItem: &HistoryChatItem{User: "First question"}}, + {ChatItem: &HistoryChatItem{Assistant: "First answer"}}, + {ChatItem: &HistoryChatItem{User: "Second question"}}, + }, + expected: "用户: First question\n助手: First answer\n用户: Second question\n", + }, + { + name: "nil chat item", + items: HistoryItems{ + {}, + }, + expected: "", + }, + { + name: "empty chat item", + items: HistoryItems{ + {ChatItem: &HistoryChatItem{}}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.items.ToText() + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/pkg/models/convo/convo_gen.go b/pkg/models/convo/convo_gen.go index a4b7db8..d513208 100644 --- a/pkg/models/convo/convo_gen.go +++ b/pkg/models/convo/convo_gen.go @@ -72,11 +72,15 @@ type Session struct { type SessionBasic struct { // 标题 - Title string `binding:"required" bson:"title" bun:",notnull" extensions:"x-order=A" form:"title" json:"title" pg:",notnull"` + Title string `bson:"title" bun:",notnull" extensions:"x-order=A" form:"title" json:"title" pg:",notnull"` + // 消息数 + MessageCount int `bson:"msgCount" bun:"msg_count,notnull,type:smallint" extensions:"x-order=B" form:"msgCount" json:"msgCount" pg:"msg_count,notnull,type:smallint"` // 状态 // * `open` - 开启 // * `closed` - 关闭 - Status SessionStatus `bson:"status" bun:",notnull,type:smallint" enums:"open,closed" extensions:"x-order=B" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"` + Status SessionStatus `bson:"status" bun:",notnull,type:smallint" enums:"open,closed" extensions:"x-order=C" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"` + // 工具 + Tools []string `bson:"tools" bun:",notnull,default:'[]'" extensions:"x-order=D" json:"tools" pg:",notnull,default:'[]'"` // for meta update MetaDiff *comm.MetaDiff `bson:"-" bun:"-" json:"metaUp,omitempty" pg:"-" swaggerignore:"true"` } // @name convoSessionBasic @@ -115,14 +119,18 @@ func (_ *Session) IdentityAlias() string { return SessionAlias } type SessionSet struct { // 标题 Title *string `extensions:"x-order=A" json:"title"` + // 消息数 + MessageCount *int `extensions:"x-order=B" json:"msgCount"` // 状态 // * `open` - 开启 // * `closed` - 关闭 - Status *SessionStatus `enums:"open,closed" extensions:"x-order=B" json:"status" swaggertype:"string"` + Status *SessionStatus `enums:"open,closed" extensions:"x-order=C" json:"status" swaggertype:"string"` + // 工具 + Tools *[]string `extensions:"x-order=D" json:"tools"` // for meta update MetaDiff *comm.MetaDiff `json:"metaUp,omitempty" swaggerignore:"true"` // 仅用于更新所有者(负责人) - OwnerID *string `extensions:"x-order=C" json:"ownerID,omitempty"` + OwnerID *string `extensions:"x-order=E" json:"ownerID,omitempty"` } // @name convoSessionSet func (z *Session) SetWith(o SessionSet) { @@ -130,10 +138,18 @@ func (z *Session) SetWith(o SessionSet) { z.LogChangeValue("title", z.Title, o.Title) z.Title = *o.Title } + if o.MessageCount != nil && z.MessageCount != *o.MessageCount { + z.LogChangeValue("msg_count", z.MessageCount, o.MessageCount) + z.MessageCount = *o.MessageCount + } if o.Status != nil && z.Status != *o.Status { z.LogChangeValue("status", z.Status, o.Status) z.Status = *o.Status } + if o.Tools != nil { + z.LogChangeValue("tools", z.Tools, o.Tools) + z.Tools = *o.Tools + } if o.MetaDiff != nil && z.MetaUp(o.MetaDiff) { z.SetChange("meta") } diff --git a/pkg/services/llm/types.go b/pkg/services/llm/types.go index dfce231..1745166 100644 --- a/pkg/services/llm/types.go +++ b/pkg/services/llm/types.go @@ -71,6 +71,16 @@ type FunctionDefinition struct { Parameters any `json:"parameters,omitempty"` } +type Tools []ToolDefinition + +func (z Tools) Names() []string { + out := make([]string, len(z)) + for i := range z { + out[i] = z[i].Function.Name + } + return out +} + // ChatResult 聊天结果 type ChatResult struct { Content string diff --git a/pkg/services/stores/cob_x.go b/pkg/services/stores/cob_x.go index 2322557..7297484 100644 --- a/pkg/services/stores/cob_x.go +++ b/pkg/services/stores/cob_x.go @@ -15,12 +15,9 @@ import ( ) const ( - Separator = "\n* " - AnswerStop = " END" + Separator = "\n* " - tplKeyword = "总结下面的文字内容,提炼出关键字句,如果是疑问句,则忽略问话的形式,只罗列出重点关键词,去除疑问形式,不考虑疑问表达,也不要返回多余内容,只关注最重要的词语,例如如果文字内容是问“什么”“为什么”“有什么”“怎么样”等等类似的语句,这些问话形式一律忽略,只返回关键字句,如果关键字句不成语句,则以关键字列表的形式返回,且用空格分隔,仅占一行,不要多行:\n\n%s\n\n" - tplQaCtx = "根据以下文本编写尽可能多一些的问题及回答: \n\n文本:\n%s\n\n" - maxQaTokens = 1024 + tplKeyword = "总结下面的文字内容,提炼出关键字句,如果是疑问句,则忽略问话的形式,只罗列出重点关键词,去除疑问形式,不考虑疑问表达,也不要返回多余内容,只关注最重要的词语,例如如果文字内容是问“什么”“为什么”“有什么”“怎么样”等等类似的语句,这些问话形式一律忽略,只返回关键字句,如果关键字句不成语句,则以关键字列表的形式返回,且用空格分隔,仅占一行,不要多行:\n\n%s\n\n" ) var ( @@ -193,22 +190,6 @@ func GetEmbedding(ctx context.Context, text string) (vec corpus.Vector, err erro return } -func GetKeywords(ctx context.Context, text string) (kw string, err error) { - if len(text) == 0 { - err = ErrEmptyParam - return - } - prompt := fmt.Sprintf(tplKeyword, text) - result, _, err := llmSu.Generate(ctx, prompt) - if err != nil { - logger().Infow("summarize fail", "text", text, "err", err) - return - } - kw = strings.TrimSpace(result) - logger().Infow("summarize ok", "text", text, "kw", kw) - return -} - func (s *cobStore) ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt string, err error) { var docs corpus.Documents docs, err = s.MatchDocments(ctx, ms) diff --git a/pkg/services/stores/conversation.go b/pkg/services/stores/conversation.go index dcbe215..c824f06 100644 --- a/pkg/services/stores/conversation.go +++ b/pkg/services/stores/conversation.go @@ -21,6 +21,9 @@ const ( type Conversation interface { GetID() string GetOID() oid.OID + SetTools(names ...string) + Save(ctx context.Context) error + CountHistory(ctx context.Context) int AddHistory(ctx context.Context, item *aigc.HistoryItem) error ListHistory(ctx context.Context) (aigc.HistoryItems, error) ClearHistory(ctx context.Context) error @@ -71,7 +74,30 @@ func (s *conversation) GetOID() oid.OID { return s.id } -// TODO: AddMessages(), Summary() +func (s *conversation) SetTools(names ...string) { + if len(names) > 0 { + s.sess.Tools = names + } +} + +// 保存聊天会话 +func (s *conversation) Save(ctx context.Context) error { + count := s.CountHistory(ctx) + s.sess.MessageCount = count + return s.sto.Convo().SaveSession(ctx, s.sess) +} + +func (s *conversation) CountHistory(ctx context.Context) int { + key := s.getKey() + n, err := s.rc.LLen(ctx, key).Result() + if err != nil { + logger().Infow("llen fail", "key", key, "err", err) + return 0 + } + return int(n) +} + +// TODO: AddMessages() func (s *conversation) AddHistory(ctx context.Context, item *aigc.HistoryItem) error { key := s.getKey() @@ -107,8 +133,10 @@ func (s *conversation) AddHistory(ctx context.Context, item *aigc.HistoryItem) e } if err != nil { logger().Infow("add history fail", "key", key, "err", err) + return err } - return err + + return s.Save(ctx) } // getLastUserMessage 获取列表中最后一条消息 diff --git a/pkg/services/stores/conversation_test.go b/pkg/services/stores/conversation_test.go index 27e8856..7ca37de 100644 --- a/pkg/services/stores/conversation_test.go +++ b/pkg/services/stores/conversation_test.go @@ -276,3 +276,43 @@ func TestHistoryMaxLength(t *testing.T) { t.Errorf("expected %d items, got %d", historyMaxLength, len(history)) } } + +func TestCountHistory(t *testing.T) { + mr, conv := newTestConversation(t) + defer mr.Close() + + ctx := context.Background() + + // 空列表计数 + count := conv.CountHistory(ctx) + if count != 0 { + t.Errorf("expected 0, got %d", count) + } + + // 添加消息后计数 + items := []*aigc.HistoryItem{ + {Time: 1, ChatItem: &aigc.HistoryChatItem{User: "First"}}, + {Time: 2, ChatItem: &aigc.HistoryChatItem{User: "Second"}}, + {Time: 3, ChatItem: &aigc.HistoryChatItem{User: "Third"}}, + } + for _, item := range items { + if err := conv.AddHistory(ctx, item); err != nil { + t.Fatalf("AddHistory failed: %v", err) + } + } + + count = conv.CountHistory(ctx) + if count != 3 { + t.Errorf("expected 3, got %d", count) + } + + // 清除后计数 + err := conv.ClearHistory(ctx) + if err != nil { + t.Fatalf("ClearHistory failed: %v", err) + } + count = conv.CountHistory(ctx) + if count != 0 { + t.Errorf("expected 0 after clear, got %d", count) + } +} diff --git a/pkg/services/stores/convo_x.go b/pkg/services/stores/convo_x.go index 17eb2b4..4efc0d7 100644 --- a/pkg/services/stores/convo_x.go +++ b/pkg/services/stores/convo_x.go @@ -8,9 +8,26 @@ import ( ) type ConvoStoreX interface { + SaveSession(ctx context.Context, sess *convo.Session) error SaveUser(ctx context.Context, user *ConvoUser) error } +func (s convoStore) SaveSession(ctx context.Context, obj *convo.Session) error { + if !obj.IsZeroID() { + exist := new(convo.Session) + if err := dbGetWithPKID(ctx, s.w.db, exist, obj.ID); err == nil { + exist.SetIsUpdate(true) + exist.SetWith(convo.SessionSet{ + MessageCount: &obj.MessageCount, + }) + dbMetaUp(ctx, s.w.db, exist) + return dbUpdate(ctx, s.w.db, obj) + } + } + dbMetaUp(ctx, s.w.db, obj) + return dbInsert(ctx, s.w.db, obj) +} + func (s *convoStore) SaveUser(ctx context.Context, user *convo.User) error { // 根据 username 查询用户是否存在 existing := new(convo.User) diff --git a/pkg/services/stores/llm.go b/pkg/services/stores/llm.go index b490b3a..0d8314b 100644 --- a/pkg/services/stores/llm.go +++ b/pkg/services/stores/llm.go @@ -1,6 +1,10 @@ package stores import ( + "context" + "fmt" + "strings" + "github.com/liut/morign/pkg/services/llm" "github.com/liut/morign/pkg/settings" ) @@ -67,3 +71,41 @@ func GetLLMEmbeddingClient() llm.Client { func GetLLMSummarizeClient() llm.Client { return llmSu } + +func GetKeywords(ctx context.Context, text string) (kw string, err error) { + if len(text) == 0 { + err = ErrEmptyParam + return + } + prompt := fmt.Sprintf(tplKeyword, text) + result, _, err := llmSu.Generate(ctx, prompt) + if err != nil { + logger().Infow("summarize fail", "text", text, "err", err) + return + } + kw = strings.TrimSpace(result) + logger().Infow("summarize ok", "text", text, "kw", kw) + return +} + +// GetSummary 生成聊天记录的简短标题 +// text 参数为聊天记录文本,tips 参数为自定义提示内容(可选) +func GetSummary(ctx context.Context, text, tips string) (summary string, err error) { + if len(text) == 0 { + err = ErrEmptyParam + return + } + // 使用自定义提示或默认提示 + if tips == "" { + tips = "请根据以下聊天记录生成一个简短的标题(不超过10个字),这个标题只针对聊天的主题,且只返回标题,不要其他内容:" + } + prompt := fmt.Sprintf("%s\n\n%s\n\n标题:", tips, text) + result, _, err := llmSu.Generate(ctx, prompt) + if err != nil { + logger().Infow("summary fail", "text", text, "err", err) + return + } + summary = strings.TrimSpace(result) + logger().Infow("summary ok", "text", text, "summary", summary) + return +} diff --git a/pkg/services/stores/wrap.go b/pkg/services/stores/wrap.go index 14d0730..1aba87e 100644 --- a/pkg/services/stores/wrap.go +++ b/pkg/services/stores/wrap.go @@ -39,6 +39,8 @@ var ( dbGet = pgx.Get dbFirst = pgx.First dbLast = pgx.Last + dbEnsureID = pgx.EnsureID + dbExists = pgx.Exists queryOne = pgx.QueryOne queryList = pgx.QueryList queryPager = pgx.QueryPager diff --git a/pkg/settings/config.go b/pkg/settings/config.go index fbd9a69..8ef0689 100644 --- a/pkg/settings/config.go +++ b/pkg/settings/config.go @@ -38,7 +38,7 @@ type Config struct { QAEmbedding bool `envconfig:"QA_Embedding" desc:"enable embed QA into prompt"` QAChatLog bool `envconfig:"QA_chat_log"` - AskRateLimit string `envconfig:"Ask_Rate_Limit" default:"60-H"` + AskRateLimit string `envconfig:"Ask_Rate_Limit" default:"20-H"` DateInContext bool `envconfig:"date_in_context"` diff --git a/pkg/web/api/api.go b/pkg/web/api/api.go index 22c140b..418bb42 100644 --- a/pkg/web/api/api.go +++ b/pkg/web/api/api.go @@ -3,6 +3,7 @@ package api import ( "fmt" "net/http" + "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/render" @@ -105,11 +106,20 @@ func (a *api) Strap(router chi.Router) { if err != nil { logger().Fatalw("limiter with redis option failed", "err", err) } + replLK := strings.NewReplacer("/", "") instance := limiter.New(store, rate) middleware := stdlib.NewMiddleware(instance, stdlib.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { logger().Warnw("failed on", "uri", r.RequestURI, "err", err) - })) + }), + stdlib.WithKeyGetter(func(r *http.Request) string { + return fmt.Sprintf("%s:%s", + limiter.GetIPWithMask(r, limiter.Options{ + TrustForwardHeader: true, + }).String(), + replLK.Replace(r.RequestURI)) + }), + ) router.Route(settings.Current.APIPrefix, func(r chi.Router) { r.Use(middleware.Handler) // 限流 diff --git a/pkg/web/api/handle_convo.go b/pkg/web/api/handle_convo.go index 1344c8a..7743dc1 100644 --- a/pkg/web/api/handle_convo.go +++ b/pkg/web/api/handle_convo.go @@ -15,6 +15,7 @@ import ( "github.com/marcsv/go-binder/binder" "github.com/liut/morign/pkg/models/aigc" + "github.com/liut/morign/pkg/models/convo" "github.com/liut/morign/pkg/models/corpus" "github.com/liut/morign/pkg/models/mcps" "github.com/liut/morign/pkg/services/llm" @@ -39,16 +40,22 @@ func init() { regHI(true, "GET", "/tools", "", func(a *api) http.HandlerFunc { return a.getTools }) + regHI(true, "POST", "/summary", "", func(a *api) http.HandlerFunc { + return a.postSummary + }) + regHI(true, "PATCH", "/conversation/{csid}/title", "", func(a *api) http.HandlerFunc { + return a.patchConversationTitle + }) } // chatRequest 内部聊天请求结构 type chatRequest struct { - messages []llm.Message - tools []llm.ToolDefinition - isSSE bool - cs stores.Conversation - hi *aigc.HistoryItem - chunkIdx int // 全局 chunk 计数器,用于 SSE 事件序号 + messages []llm.Message + tools []llm.ToolDefinition + isSSE bool + cs stores.Conversation + hi *aigc.HistoryItem + chunkIdx int // 全局 chunk 计数器,用于 SSE 事件序号 } // convertMCPToolsToLLMTools 将 MCP 工具描述转换为 LLM 工具定义 @@ -95,6 +102,7 @@ func (a *api) prepareChatRequest(ctx context.Context, param *ChatRequest) *chatR Role: llm.RoleSystem, Content: toolsPrompt, }) + cs.SetTools(llm.Tools(tools).Names()...) } else { // 没有工具,使用问答 docs, err := a.sto.Cob().MatchDocments(ctx, stores.MatchSpec{ Question: param.Prompt, @@ -456,6 +464,94 @@ func (a *api) getWelcome(w http.ResponseWriter, r *http.Request) { apiOk(w, r, msg) } +// SummaryRequest 摘要请求 +type SummaryRequest struct { + Tips string `json:"tips,omitempty"` + Text string `json:"text"` +} + +// @Summary 生成聊天记录摘要 +// @Description 根据聊天记录生成简短标题 +// @Accept json +// @Produce json +// @Param request body SummaryRequest true "请求参数" +// @Success 200 {object} resp.Done +// @Router /api/summary [post] +func (a *api) postSummary(w http.ResponseWriter, r *http.Request) { + var req SummaryRequest + if err := render.DecodeJSON(r.Body, &req); err != nil { + fail(w, r, 400, "invalid request body") + return + } + if req.Text == "" { + fail(w, r, 400, "text is required") + return + } + + summary, err := stores.GetSummary(r.Context(), req.Text, req.Tips) + if err != nil { + fail(w, r, 500, err) + return + } + + apiOk(w, r, summary) +} + +// @Tags 聊天 +// @Summary 生成会话标题 +// @Accept json +// @Produce json +// @Param token header string false "登录票据凭证" +// @Param csid path string true "会话ID" +// @Success 200 {object} Done{result=string} +// @Failure 400 {object} Failure "请求错误" +// @Failure 500 {object} Failure "服务端错误" +// @Router /api/conversation/{csid}/title [patch] +func (a *api) patchConversationTitle(w http.ResponseWriter, r *http.Request) { + csid := chi.URLParam(r, "csid") + if csid == "" { + fail(w, r, 400, "csid is required") + return + } + + // 获取会话历史 + cs := stores.NewConversation(r.Context(), csid) + history, err := cs.ListHistory(r.Context()) + if err != nil { + fail(w, r, 500, err) + return + } + + if len(history) == 0 { + fail(w, r, 400, "no history found") + return + } + + // 将历史记录转换为文本 + text := history.ToText() + if text == "" { + fail(w, r, 400, "no valid chat content") + return + } + + // 调用 GetSummary 生成标题 + summary, err := stores.GetSummary(r.Context(), text, "") + if err != nil { + fail(w, r, 500, err) + return + } + + // 更新会话标题 + title := summary + err = a.sto.Convo().UpdateSession(r.Context(), csid, convo.SessionSet{Title: &title}) + if err != nil { + fail(w, r, 500, err) + return + } + + apiOk(w, r, M{"title": summary}) +} + // @Tags 聊天 // @Summary 获取会话历史 // @Accept json