Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/convo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,26 @@ models:
- comment: 标题
name: Title
type: string
tags: {bson: 'title', json: 'title', pg: ',notnull', binding: 'required'}
tags: {bson: 'title', json: 'title', pg: ',notnull'}
isset: true
query: 'match'
- comment: 消息数
name: MessageCount
type: int
tags: {bson: 'msgCount', json: 'msgCount', pg: 'msg_count,notnull,type:smallint'}
isset: true
- comment: '状态'
name: Status
type: SessionStatus
tags: {bson: 'status', json: 'status', pg: ',notnull,type:smallint'}
isset: true
query: 'equal'
- comment: 工具
name: Tools
type: '[]string'
tags: {bson: 'tools', json: 'tools', pg: ",notnull,default:'[]'"}
isset: true

- type: comm.MetaField
- type: comm.OwnerField
oidcat: event
Expand Down
20 changes: 20 additions & 0 deletions pkg/models/aigc/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@ func (z *HistoryItem) calcTokens() (c int) {

type HistoryItems []HistoryItem

// ToText 将历史记录转换为纯文本格式
func (z HistoryItems) ToText() string {
var sb strings.Builder
for _, item := range z {
if item.ChatItem != nil {
if item.ChatItem.User != "" {
sb.WriteString("用户: ")
sb.WriteString(item.ChatItem.User)
sb.WriteString("\n")
}
if item.ChatItem.Assistant != "" {
sb.WriteString("助手: ")
sb.WriteString(item.ChatItem.Assistant)
sb.WriteString("\n")
}
}
}
return sb.String()
}

// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (z *HistoryItem) MarshalBinary() (data []byte, err error) {
data, err = json.Marshal(z)
Expand Down
67 changes: 67 additions & 0 deletions pkg/models/aigc/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,70 @@ func TestHiAscend(t *testing.T) {
t.Errorf("Len got %d, want 3", ha.Len())
}
}

func TestHistoryItems_ToText(t *testing.T) {
tests := []struct {
name string
items HistoryItems
expected string
}{
{
name: "empty",
items: HistoryItems{},
expected: "",
},
{
name: "only user message",
items: HistoryItems{
{ChatItem: &HistoryChatItem{User: "Hello"}},
},
expected: "用户: Hello\n",
},
{
name: "only assistant message",
items: HistoryItems{
{ChatItem: &HistoryChatItem{Assistant: "Hi there"}},
},
expected: "助手: Hi there\n",
},
{
name: "user and assistant",
items: HistoryItems{
{ChatItem: &HistoryChatItem{User: "Hello", Assistant: "Hi there"}},
},
expected: "用户: Hello\n助手: Hi there\n",
},
{
name: "multiple messages",
items: HistoryItems{
{ChatItem: &HistoryChatItem{User: "First question"}},
{ChatItem: &HistoryChatItem{Assistant: "First answer"}},
{ChatItem: &HistoryChatItem{User: "Second question"}},
},
expected: "用户: First question\n助手: First answer\n用户: Second question\n",
},
{
name: "nil chat item",
items: HistoryItems{
{},
},
expected: "",
},
{
name: "empty chat item",
items: HistoryItems{
{ChatItem: &HistoryChatItem{}},
},
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.items.ToText()
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
24 changes: 20 additions & 4 deletions pkg/models/convo/convo_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,15 @@ type Session struct {

type SessionBasic struct {
// 标题
Title string `binding:"required" bson:"title" bun:",notnull" extensions:"x-order=A" form:"title" json:"title" pg:",notnull"`
Title string `bson:"title" bun:",notnull" extensions:"x-order=A" form:"title" json:"title" pg:",notnull"`
// 消息数
MessageCount int `bson:"msgCount" bun:"msg_count,notnull,type:smallint" extensions:"x-order=B" form:"msgCount" json:"msgCount" pg:"msg_count,notnull,type:smallint"`
// 状态
// * `open` - 开启
// * `closed` - 关闭
Status SessionStatus `bson:"status" bun:",notnull,type:smallint" enums:"open,closed" extensions:"x-order=B" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"`
Status SessionStatus `bson:"status" bun:",notnull,type:smallint" enums:"open,closed" extensions:"x-order=C" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"`
// 工具
Tools []string `bson:"tools" bun:",notnull,default:'[]'" extensions:"x-order=D" json:"tools" pg:",notnull,default:'[]'"`
// for meta update
MetaDiff *comm.MetaDiff `bson:"-" bun:"-" json:"metaUp,omitempty" pg:"-" swaggerignore:"true"`
} // @name convoSessionBasic
Expand Down Expand Up @@ -115,25 +119,37 @@ func (_ *Session) IdentityAlias() string { return SessionAlias }
type SessionSet struct {
// 标题
Title *string `extensions:"x-order=A" json:"title"`
// 消息数
MessageCount *int `extensions:"x-order=B" json:"msgCount"`
// 状态
// * `open` - 开启
// * `closed` - 关闭
Status *SessionStatus `enums:"open,closed" extensions:"x-order=B" json:"status" swaggertype:"string"`
Status *SessionStatus `enums:"open,closed" extensions:"x-order=C" json:"status" swaggertype:"string"`
// 工具
Tools *[]string `extensions:"x-order=D" json:"tools"`
// for meta update
MetaDiff *comm.MetaDiff `json:"metaUp,omitempty" swaggerignore:"true"`
// 仅用于更新所有者(负责人)
OwnerID *string `extensions:"x-order=C" json:"ownerID,omitempty"`
OwnerID *string `extensions:"x-order=E" json:"ownerID,omitempty"`
} // @name convoSessionSet

func (z *Session) SetWith(o SessionSet) {
if o.Title != nil && z.Title != *o.Title {
z.LogChangeValue("title", z.Title, o.Title)
z.Title = *o.Title
}
if o.MessageCount != nil && z.MessageCount != *o.MessageCount {
z.LogChangeValue("msg_count", z.MessageCount, o.MessageCount)
z.MessageCount = *o.MessageCount
}
if o.Status != nil && z.Status != *o.Status {
z.LogChangeValue("status", z.Status, o.Status)
z.Status = *o.Status
}
if o.Tools != nil {
z.LogChangeValue("tools", z.Tools, o.Tools)
z.Tools = *o.Tools
}
if o.MetaDiff != nil && z.MetaUp(o.MetaDiff) {
z.SetChange("meta")
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/services/llm/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ type FunctionDefinition struct {
Parameters any `json:"parameters,omitempty"`
}

type Tools []ToolDefinition

func (z Tools) Names() []string {
out := make([]string, len(z))
for i := range z {
out[i] = z[i].Function.Name
}
return out
}

// ChatResult 聊天结果
type ChatResult struct {
Content string
Expand Down
23 changes: 2 additions & 21 deletions pkg/services/stores/cob_x.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ import (
)

const (
Separator = "\n* "
AnswerStop = " END"
Separator = "\n* "

tplKeyword = "总结下面的文字内容,提炼出关键字句,如果是疑问句,则忽略问话的形式,只罗列出重点关键词,去除疑问形式,不考虑疑问表达,也不要返回多余内容,只关注最重要的词语,例如如果文字内容是问“什么”“为什么”“有什么”“怎么样”等等类似的语句,这些问话形式一律忽略,只返回关键字句,如果关键字句不成语句,则以关键字列表的形式返回,且用空格分隔,仅占一行,不要多行:\n\n%s\n\n"
tplQaCtx = "根据以下文本编写尽可能多一些的问题及回答: \n\n文本:\n%s\n\n"
maxQaTokens = 1024
tplKeyword = "总结下面的文字内容,提炼出关键字句,如果是疑问句,则忽略问话的形式,只罗列出重点关键词,去除疑问形式,不考虑疑问表达,也不要返回多余内容,只关注最重要的词语,例如如果文字内容是问“什么”“为什么”“有什么”“怎么样”等等类似的语句,这些问话形式一律忽略,只返回关键字句,如果关键字句不成语句,则以关键字列表的形式返回,且用空格分隔,仅占一行,不要多行:\n\n%s\n\n"
)

var (
Expand Down Expand Up @@ -193,22 +190,6 @@ func GetEmbedding(ctx context.Context, text string) (vec corpus.Vector, err erro
return
}

func GetKeywords(ctx context.Context, text string) (kw string, err error) {
if len(text) == 0 {
err = ErrEmptyParam
return
}
prompt := fmt.Sprintf(tplKeyword, text)
result, _, err := llmSu.Generate(ctx, prompt)
if err != nil {
logger().Infow("summarize fail", "text", text, "err", err)
return
}
kw = strings.TrimSpace(result)
logger().Infow("summarize ok", "text", text, "kw", kw)
return
}

func (s *cobStore) ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt string, err error) {
var docs corpus.Documents
docs, err = s.MatchDocments(ctx, ms)
Expand Down
32 changes: 30 additions & 2 deletions pkg/services/stores/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ const (
type Conversation interface {
GetID() string
GetOID() oid.OID
SetTools(names ...string)
Save(ctx context.Context) error
CountHistory(ctx context.Context) int
AddHistory(ctx context.Context, item *aigc.HistoryItem) error
ListHistory(ctx context.Context) (aigc.HistoryItems, error)
ClearHistory(ctx context.Context) error
Expand Down Expand Up @@ -71,7 +74,30 @@ func (s *conversation) GetOID() oid.OID {
return s.id
}

// TODO: AddMessages(), Summary()
func (s *conversation) SetTools(names ...string) {
if len(names) > 0 {
s.sess.Tools = names
}
}

// 保存聊天会话
func (s *conversation) Save(ctx context.Context) error {
count := s.CountHistory(ctx)
s.sess.MessageCount = count
return s.sto.Convo().SaveSession(ctx, s.sess)
}

func (s *conversation) CountHistory(ctx context.Context) int {
key := s.getKey()
n, err := s.rc.LLen(ctx, key).Result()
if err != nil {
logger().Infow("llen fail", "key", key, "err", err)
return 0
}
return int(n)
}

// TODO: AddMessages()

func (s *conversation) AddHistory(ctx context.Context, item *aigc.HistoryItem) error {
key := s.getKey()
Expand Down Expand Up @@ -107,8 +133,10 @@ func (s *conversation) AddHistory(ctx context.Context, item *aigc.HistoryItem) e
}
if err != nil {
logger().Infow("add history fail", "key", key, "err", err)
return err
}
return err

return s.Save(ctx)
}

// getLastUserMessage 获取列表中最后一条消息
Expand Down
40 changes: 40 additions & 0 deletions pkg/services/stores/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,43 @@ func TestHistoryMaxLength(t *testing.T) {
t.Errorf("expected %d items, got %d", historyMaxLength, len(history))
}
}

func TestCountHistory(t *testing.T) {
mr, conv := newTestConversation(t)
defer mr.Close()

ctx := context.Background()

// 空列表计数
count := conv.CountHistory(ctx)
if count != 0 {
t.Errorf("expected 0, got %d", count)
}

// 添加消息后计数
items := []*aigc.HistoryItem{
{Time: 1, ChatItem: &aigc.HistoryChatItem{User: "First"}},
{Time: 2, ChatItem: &aigc.HistoryChatItem{User: "Second"}},
{Time: 3, ChatItem: &aigc.HistoryChatItem{User: "Third"}},
}
for _, item := range items {
if err := conv.AddHistory(ctx, item); err != nil {
t.Fatalf("AddHistory failed: %v", err)
}
}

count = conv.CountHistory(ctx)
if count != 3 {
t.Errorf("expected 3, got %d", count)
}

// 清除后计数
err := conv.ClearHistory(ctx)
if err != nil {
t.Fatalf("ClearHistory failed: %v", err)
}
count = conv.CountHistory(ctx)
if count != 0 {
t.Errorf("expected 0 after clear, got %d", count)
}
}
17 changes: 17 additions & 0 deletions pkg/services/stores/convo_x.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,26 @@ import (
)

type ConvoStoreX interface {
SaveSession(ctx context.Context, sess *convo.Session) error
SaveUser(ctx context.Context, user *ConvoUser) error
}

func (s convoStore) SaveSession(ctx context.Context, obj *convo.Session) error {
if !obj.IsZeroID() {
exist := new(convo.Session)
if err := dbGetWithPKID(ctx, s.w.db, exist, obj.ID); err == nil {
exist.SetIsUpdate(true)
exist.SetWith(convo.SessionSet{
MessageCount: &obj.MessageCount,
})
dbMetaUp(ctx, s.w.db, exist)
return dbUpdate(ctx, s.w.db, obj)
}
}
dbMetaUp(ctx, s.w.db, obj)
return dbInsert(ctx, s.w.db, obj)
}

func (s *convoStore) SaveUser(ctx context.Context, user *convo.User) error {
// 根据 username 查询用户是否存在
existing := new(convo.User)
Expand Down
Loading
Loading