From 9595e2d7c5ace44366a6f80530b7e9c784fb0acc Mon Sep 17 00:00:00 2001 From: liut Date: Thu, 12 Mar 2026 16:11:13 +0800 Subject: [PATCH] refactor: move KB tool invokers to stores - Add InvokerForSearch and InvokerForCreate methods to CobStoreX - Move tool logic from registry to cobStore for better separation - Migrate ToolDescriptor definitions to defines.go - Remove duplicate callKBSearch/callKBCreate from invokers.go - Add connection.go for MCP connection management - Rename TABLEs and db instance name of Corpus --- data/schemas/20260312010200_corpus.up.sql | 3 + docs/{cob.yaml => corpus.yaml} | 8 +- docs/example.md | 25 ++ docs/mcps.yaml | 33 +- main.go | 2 +- .../corpus/{cob_gen.go => corpus_gen.go} | 10 +- .../corpus/{cob_test.go => corpus_test.go} | 0 pkg/models/corpus/{cob_x.go => corpus_x.go} | 0 pkg/models/mcps/mcps_gen.go | 63 ++-- pkg/models/mcps/mcps_x.go | 6 + pkg/models/mcps/tool.go | 3 + .../stores/{cob_gen.go => corpus_gen.go} | 26 +- pkg/services/stores/{cob_x.go => corpus_x.go} | 90 ++++- pkg/services/stores/mcps_gen.go | 12 +- pkg/services/stores/wrap.go | 8 +- pkg/services/tools/connection.go | 22 ++ pkg/services/tools/defines.go | 79 ++++ pkg/services/tools/invokers.go | 68 ---- pkg/services/tools/registry.go | 357 ++++++++++++++---- pkg/web/api/api.go | 11 + pkg/web/api/handle_convo.go | 7 +- pkg/web/api/handle_mcps_gen.go | 168 +++++++++ pkg/web/api/handle_mcps_x.go | 125 ++++++ 23 files changed, 906 insertions(+), 220 deletions(-) create mode 100644 data/schemas/20260312010200_corpus.up.sql rename docs/{cob.yaml => corpus.yaml} (95%) create mode 100644 docs/example.md rename pkg/models/corpus/{cob_gen.go => corpus_gen.go} (97%) rename pkg/models/corpus/{cob_test.go => corpus_test.go} (100%) rename pkg/models/corpus/{cob_x.go => corpus_x.go} (100%) create mode 100644 pkg/models/mcps/mcps_x.go rename pkg/services/stores/{cob_gen.go => corpus_gen.go} (75%) rename pkg/services/stores/{cob_x.go => corpus_x.go} (73%) create mode 100644 pkg/services/tools/connection.go create mode 100644 pkg/web/api/handle_mcps_gen.go create mode 100644 pkg/web/api/handle_mcps_x.go diff --git a/data/schemas/20260312010200_corpus.up.sql b/data/schemas/20260312010200_corpus.up.sql new file mode 100644 index 0000000..7991caf --- /dev/null +++ b/data/schemas/20260312010200_corpus.up.sql @@ -0,0 +1,3 @@ + +ALTER TABLE qa_corpus_document RENAME TO corpus_document; +ALTER TABLE qa_corpus_vector_400 RENAME TO corpus_vector_400; diff --git a/docs/cob.yaml b/docs/corpus.yaml similarity index 95% rename from docs/cob.yaml rename to docs/corpus.yaml index 5a64604..e94664d 100644 --- a/docs/cob.yaml +++ b/docs/corpus.yaml @@ -2,13 +2,13 @@ depends: comm: github.com/cupogo/andvari/models/comm oid: github.com/cupogo/andvari/models/oid -gename: cob +gename: corpus modelpkg: corpus models: - name: Document comment: '文档 语料库' - tableTag: 'qa_corpus_document,alias:cd' + tableTag: 'corpus_document,alias:cd' fields: - name: comm.DefaultModel - comment: 主标题 名称 @@ -41,7 +41,7 @@ models: - name: DocVector comment: '文档向量 400=1024, 600=1536' - tableTag: 'qa_corpus_vector_400,alias:cv' + tableTag: 'corpus_vector_400,alias:cv' fields: - name: comm.DefaultModel - comment: 文档编号 @@ -93,7 +93,7 @@ models: # export2: true - name: ChatLog - comment: '聊天日志' + comment: '聊天日志 Deprecated' tableTag: 'qa_chat_log,alias:cl' fields: - name: comm.DefaultModel diff --git a/docs/example.md b/docs/example.md new file mode 100644 index 0000000..b5bd622 --- /dev/null +++ b/docs/example.md @@ -0,0 +1,25 @@ + +```json +fetch('/api/m/mcp/servers', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + name: 'webpawm', + transType: 'streamable', + url: 'http://localhost:8087/mcp', + remark: '用于搜索' + }) +}) +.then(r => r.json()) +.then(console.log); + + +fetch('/api/m/mcp/servers/fi-54ionou2hq0w/activate', { + method: 'PUT' +}) +.then(r => r.json()) +.then(console.log); + +```json diff --git a/docs/mcps.yaml b/docs/mcps.yaml index 0862e9f..e9dd9f2 100644 --- a/docs/mcps.yaml +++ b/docs/mcps.yaml @@ -24,15 +24,17 @@ enums: textMarshaler: true textUnmarshaler: true - - comment: 状态 + - comment: 状态 用于表示连接 name: Status start: 0 type: int8 values: - - label: 已停止 - suffix: Stopped - - label: 运行中 - suffix: Running + - label: 断开 初始默认 + suffix: Disconnected + - label: 连接中 + suffix: Connecting + - label: 已连接 + suffix: Connected stringer: true decodable: true textMarshaler: true @@ -46,7 +48,7 @@ models: - name: Server comment: '服务器' - tableTag: 'qa_mcp_server,alias:s' + tableTag: 'mcp_server,alias:ms' fields: - type: comm.DefaultModel - comment: 名称 @@ -71,7 +73,13 @@ models: type: string tags: {bson: 'url', json: 'url', pg: ',notnull'} isset: true - - comment: '状态' + - comment: '是否激活' + name: IsActive + type: bool + tags: {bson: 'isActive', json: 'isActive', pg: ',notnull'} + isset: true + query: 'equal' + - comment: '连接状态' name: Status type: Status tags: {bson: 'status', json: 'status', pg: ',notnull,type:smallint'} @@ -92,3 +100,14 @@ stores: siname: MCP hods: - { name: Server, type: LGCUD } + + +webcode: chi +webapi: + formTag: form + pkg: api + needAuth: true + needPerm: true + uris: + - model: Server + prefix: '/api/mcp' diff --git a/main.go b/main.go index 0212bb7..e69e077 100644 --- a/main.go +++ b/main.go @@ -85,7 +85,7 @@ func embeddingDocVector(cc *cli.Context) error { spec := &stores.CobDocumentSpec{} spec.Limit = 90 spec.Sort = "id" - return stores.Sgt().Qa().EmbeddingDocVector(ctx, spec) + return stores.Sgt().KB().EmbeddingDocVector(ctx, spec) // return nil } diff --git a/pkg/models/corpus/cob_gen.go b/pkg/models/corpus/corpus_gen.go similarity index 97% rename from pkg/models/corpus/cob_gen.go rename to pkg/models/corpus/corpus_gen.go index 93784b8..f774258 100644 --- a/pkg/models/corpus/cob_gen.go +++ b/pkg/models/corpus/corpus_gen.go @@ -9,7 +9,7 @@ import ( // consts of Document 文档 const ( - DocumentTable = "qa_corpus_document" + DocumentTable = "corpus_document" DocumentAlias = "cd" DocumentLabel = "document" DocumentTypID = "corpusDocument" @@ -17,7 +17,7 @@ const ( // Document 文档 语料库 type Document struct { - comm.BaseModel `bun:"table:qa_corpus_document,alias:cd" json:"-"` + comm.BaseModel `bun:"table:corpus_document,alias:cd" json:"-"` comm.DefaultModel @@ -103,7 +103,7 @@ func (in *DocumentSet) MetaAddKVs(args ...any) *DocumentSet { // consts of DocVector 文档向量 const ( - DocVectorTable = "qa_corpus_vector_400" + DocVectorTable = "corpus_vector_400" DocVectorAlias = "cv" DocVectorLabel = "docVector" DocVectorTypID = "corpusDocVector" @@ -111,7 +111,7 @@ const ( // DocVector 文档向量 400=1024, 600=1536 type DocVector struct { - comm.BaseModel `bun:"table:qa_corpus_vector_400,alias:cv" json:"-"` + comm.BaseModel `bun:"table:corpus_vector_400,alias:cv" json:"-"` comm.DefaultModel @@ -218,7 +218,7 @@ const ( ChatLogTypID = "corpusChatLog" ) -// ChatLog 聊天日志 +// ChatLog 聊天日志 Deprecated type ChatLog struct { comm.BaseModel `bun:"table:qa_chat_log,alias:cl" json:"-"` diff --git a/pkg/models/corpus/cob_test.go b/pkg/models/corpus/corpus_test.go similarity index 100% rename from pkg/models/corpus/cob_test.go rename to pkg/models/corpus/corpus_test.go diff --git a/pkg/models/corpus/cob_x.go b/pkg/models/corpus/corpus_x.go similarity index 100% rename from pkg/models/corpus/cob_x.go rename to pkg/models/corpus/corpus_x.go diff --git a/pkg/models/mcps/mcps_gen.go b/pkg/models/mcps/mcps_gen.go index eb566e2..f7f9a47 100644 --- a/pkg/models/mcps/mcps_gen.go +++ b/pkg/models/mcps/mcps_gen.go @@ -55,20 +55,23 @@ func (z TransType) MarshalText() ([]byte, error) { return []byte(z.String()), nil } -// 状态 +// 状态 用于表示连接 type Status int8 const ( - StatusStopped Status = 0 + iota // 0 已停止 - StatusRunning // 1 运行中 + StatusDisconnected Status = 0 + iota // 0 断开 初始默认 + StatusConnecting // 1 连接中 + StatusConnected // 2 已连接 ) func (z *Status) Decode(s string) error { switch s { - case "0", "stopped", "Stopped": - *z = StatusStopped - case "1", "running", "Running": - *z = StatusRunning + case "0", "disconnected", "Disconnected": + *z = StatusDisconnected + case "1", "connecting", "Connecting": + *z = StatusConnecting + case "2", "connected", "Connected": + *z = StatusConnected default: return fmt.Errorf("invalid status: %q", s) } @@ -79,10 +82,12 @@ func (z *Status) UnmarshalText(b []byte) error { } func (z Status) String() string { switch z { - case StatusStopped: - return "stopped" - case StatusRunning: - return "running" + case StatusDisconnected: + return "disconnected" + case StatusConnecting: + return "connecting" + case StatusConnected: + return "connected" default: return fmt.Sprintf("status %d", int8(z)) } @@ -93,15 +98,15 @@ func (z Status) MarshalText() ([]byte, error) { // consts of Server 服务器 const ( - ServerTable = "qa_mcp_server" - ServerAlias = "s" + ServerTable = "mcp_server" + ServerAlias = "ms" ServerLabel = "server" ServerTypID = "mcpsServer" ) // Server 服务器 type Server struct { - comm.BaseModel `bun:"table:qa_mcp_server,alias:s" json:"-"` + comm.BaseModel `bun:"table:mcp_server,alias:ms" json:"-"` comm.DefaultModel @@ -123,12 +128,15 @@ type ServerBasic struct { Command string `bson:"command" bun:",notnull" extensions:"x-order=C" form:"command" json:"command" pg:",notnull"` // 完整网址 仅对 TransType 为 SSE 或 HTTP 时有效 URL string `bson:"url" bun:",notnull" extensions:"x-order=D" form:"url" json:"url" pg:",notnull"` - // 状态 - // * `stopped` - 已停止 - // * `running` - 运行中 - Status Status `bson:"status" bun:",notnull,type:smallint" enums:"stopped,running" extensions:"x-order=E" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"` + // 是否激活 + IsActive bool `bson:"isActive" bun:",notnull" extensions:"x-order=E" form:"isActive" json:"isActive" pg:",notnull"` + // 连接状态 + // * `disconnected` - 断开 + // * `connecting` - 连接中 + // * `connected` - 已连接 + Status Status `bson:"status" bun:",notnull,type:smallint" enums:"disconnected,connecting,connected" extensions:"x-order=F" form:"status" json:"status" pg:",notnull,type:smallint" swaggertype:"string"` // 备注 - Remark string `bson:"remark" bun:",notnull" extensions:"x-order=F" form:"remark" json:"remark" pg:",notnull"` + Remark string `bson:"remark" bun:",notnull" extensions:"x-order=G" form:"remark" json:"remark" pg:",notnull"` // for meta update MetaDiff *comm.MetaDiff `bson:"-" bun:"-" json:"metaUp,omitempty" pg:"-" swaggerignore:"true"` } // @name mcpsServerBasic @@ -173,12 +181,15 @@ type ServerSet struct { Command *string `extensions:"x-order=C" json:"command"` // 完整网址 仅对 TransType 为 SSE 或 HTTP 时有效 URL *string `extensions:"x-order=D" json:"url"` - // 状态 - // * `stopped` - 已停止 - // * `running` - 运行中 - Status *Status `enums:"stopped,running" extensions:"x-order=E" json:"status" swaggertype:"string"` + // 是否激活 + IsActive *bool `extensions:"x-order=E" json:"isActive"` + // 连接状态 + // * `disconnected` - 断开 + // * `connecting` - 连接中 + // * `connected` - 已连接 + Status *Status `enums:"disconnected,connecting,connected" extensions:"x-order=F" json:"status" swaggertype:"string"` // 备注 - Remark *string `extensions:"x-order=F" json:"remark"` + Remark *string `extensions:"x-order=G" json:"remark"` // for meta update MetaDiff *comm.MetaDiff `json:"metaUp,omitempty" swaggerignore:"true"` } // @name mcpsServerSet @@ -200,6 +211,10 @@ func (z *Server) SetWith(o ServerSet) { z.LogChangeValue("url", z.URL, o.URL) z.URL = *o.URL } + if o.IsActive != nil && z.IsActive != *o.IsActive { + z.LogChangeValue("is_active", z.IsActive, o.IsActive) + z.IsActive = *o.IsActive + } if o.Status != nil && z.Status != *o.Status { z.LogChangeValue("status", z.Status, o.Status) z.Status = *o.Status diff --git a/pkg/models/mcps/mcps_x.go b/pkg/models/mcps/mcps_x.go new file mode 100644 index 0000000..33a5d32 --- /dev/null +++ b/pkg/models/mcps/mcps_x.go @@ -0,0 +1,6 @@ +package mcps + +// IsRemote 判断是否为远程传输类型(SSE 或 Streamable) +func (t TransType) IsRemote() bool { + return t == TransTypeSSE || t == TransTypeStreamable +} diff --git a/pkg/models/mcps/tool.go b/pkg/models/mcps/tool.go index 5719896..0c5a346 100644 --- a/pkg/models/mcps/tool.go +++ b/pkg/models/mcps/tool.go @@ -1,12 +1,15 @@ package mcps import ( + "context" "encoding/json" "fmt" "math" "strings" ) +type Invoker func(ctx context.Context, params map[string]any) (map[string]any, error) + // ToolDescriptor 是工具的描述符,用于 MCP 工具列表 type ToolDescriptor struct { Name string `json:"name"` diff --git a/pkg/services/stores/cob_gen.go b/pkg/services/stores/corpus_gen.go similarity index 75% rename from pkg/services/stores/cob_gen.go rename to pkg/services/stores/corpus_gen.go index 4e5f303..2507563 100644 --- a/pkg/services/stores/cob_gen.go +++ b/pkg/services/stores/corpus_gen.go @@ -81,21 +81,21 @@ func (spec *ChatLogSpec) Sift(q *ormQuery) *ormQuery { return q } -type cobStore struct { +type corpuStore struct { w *Wrap } -func (s *cobStore) ListDocument(ctx context.Context, spec *CobDocumentSpec) (data corpus.Documents, total int, err error) { +func (s *corpuStore) ListDocument(ctx context.Context, spec *CobDocumentSpec) (data corpus.Documents, total int, err error) { total, err = s.w.db.ListModel(ctx, spec, &data) return } -func (s *cobStore) GetDocument(ctx context.Context, id string) (obj *corpus.Document, err error) { +func (s *corpuStore) GetDocument(ctx context.Context, id string) (obj *corpus.Document, err error) { obj = new(corpus.Document) err = dbGetWithPKID(ctx, s.w.db, obj, id) return } -func (s *cobStore) CreateDocument(ctx context.Context, in corpus.DocumentBasic) (obj *corpus.Document, err error) { +func (s *corpuStore) CreateDocument(ctx context.Context, in corpus.DocumentBasic) (obj *corpus.Document, err error) { obj = corpus.NewDocumentWithBasic(in) dbMetaUp(ctx, s.w.db, obj) err = dbInsert(ctx, s.w.db, obj) @@ -104,7 +104,7 @@ func (s *cobStore) CreateDocument(ctx context.Context, in corpus.DocumentBasic) } return } -func (s *cobStore) UpdateDocument(ctx context.Context, id string, in corpus.DocumentSet) error { +func (s *corpuStore) UpdateDocument(ctx context.Context, id string, in corpus.DocumentSet) error { exist := new(corpus.Document) if err := dbGetWithPKID(ctx, s.w.db, exist, id); err != nil { return err @@ -114,7 +114,7 @@ func (s *cobStore) UpdateDocument(ctx context.Context, id string, in corpus.Docu dbMetaUp(ctx, s.w.db, exist) return dbUpdate(ctx, s.w.db, exist) } -func (s *cobStore) DeleteDocument(ctx context.Context, id string) error { +func (s *corpuStore) DeleteDocument(ctx context.Context, id string) error { obj := new(corpus.Document) if err := dbGetWithPKID(ctx, s.w.db, obj, id); err != nil { return err @@ -128,40 +128,40 @@ func (s *cobStore) DeleteDocument(ctx context.Context, id string) error { }) } -func (s *cobStore) GetDocVector(ctx context.Context, id string) (obj *corpus.DocVector, err error) { +func (s *corpuStore) GetDocVector(ctx context.Context, id string) (obj *corpus.DocVector, err error) { obj = new(corpus.DocVector) err = dbGetWithPKID(ctx, s.w.db, obj, id) return } -func (s *cobStore) CreateDocVector(ctx context.Context, in corpus.DocVectorBasic) (obj *corpus.DocVector, err error) { +func (s *corpuStore) CreateDocVector(ctx context.Context, in corpus.DocVectorBasic) (obj *corpus.DocVector, err error) { obj = corpus.NewDocVectorWithBasic(in) dbMetaUp(ctx, s.w.db, obj) err = dbInsert(ctx, s.w.db, obj) return } -func (s *cobStore) DeleteDocVector(ctx context.Context, id string) error { +func (s *corpuStore) DeleteDocVector(ctx context.Context, id string) error { obj := new(corpus.DocVector) return s.w.db.DeleteModel(ctx, obj, id) } -func (s *cobStore) CreateChatLog(ctx context.Context, in corpus.ChatLogBasic) (obj *corpus.ChatLog, err error) { +func (s *corpuStore) CreateChatLog(ctx context.Context, in corpus.ChatLogBasic) (obj *corpus.ChatLog, err error) { obj = corpus.NewChatLogWithBasic(in) dbMetaUp(ctx, s.w.db, obj) err = dbInsert(ctx, s.w.db, obj) return } -func (s *cobStore) GetChatLog(ctx context.Context, id string) (obj *corpus.ChatLog, err error) { +func (s *corpuStore) GetChatLog(ctx context.Context, id string) (obj *corpus.ChatLog, err error) { obj = new(corpus.ChatLog) err = dbGetWithPKID(ctx, s.w.db, obj, id) return } -func (s *cobStore) ListChatLog(ctx context.Context, spec *ChatLogSpec) (data corpus.ChatLogs, total int, err error) { +func (s *corpuStore) ListChatLog(ctx context.Context, spec *ChatLogSpec) (data corpus.ChatLogs, total int, err error) { total, err = s.w.db.ListModel(ctx, spec, &data) return } -func (s *cobStore) DeleteChatLog(ctx context.Context, id string) error { +func (s *corpuStore) DeleteChatLog(ctx context.Context, id string) error { obj := new(corpus.ChatLog) return s.w.db.DeleteModel(ctx, obj, id) } diff --git a/pkg/services/stores/cob_x.go b/pkg/services/stores/corpus_x.go similarity index 73% rename from pkg/services/stores/cob_x.go rename to pkg/services/stores/corpus_x.go index be63002..a7fb579 100644 --- a/pkg/services/stores/cob_x.go +++ b/pkg/services/stores/corpus_x.go @@ -11,6 +11,7 @@ import ( "github.com/pmezard/go-difflib/difflib" "github.com/liut/morign/pkg/models/corpus" + "github.com/liut/morign/pkg/models/mcps" "github.com/liut/morign/pkg/settings" ) @@ -72,9 +73,11 @@ type CobStoreX interface { ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt string, err error) MatchDocments(ctx context.Context, ms MatchSpec) (data corpus.Documents, err error) MatchVectorWith(ctx context.Context, vec corpus.Vector, threshold float32, limit int) (data corpus.DocMatches, err error) + InvokerForSearch() mcps.Invoker + InvokerForCreate() mcps.Invoker } -func (s *cobStore) ImportDocs(ctx context.Context, r io.Reader, lw io.Writer) error { +func (s *corpuStore) ImportDocs(ctx context.Context, r io.Reader, lw io.Writer) error { rd := csv.NewReader(r) rec, err := rd.Read() if err != nil { @@ -114,7 +117,7 @@ func (s *cobStore) ImportDocs(ctx context.Context, r io.Reader, lw io.Writer) er } } -func (s *cobStore) importLine(ctx context.Context, basic corpus.DocumentBasic, lw io.Writer) error { +func (s *corpuStore) importLine(ctx context.Context, basic corpus.DocumentBasic, lw io.Writer) error { doc := new(corpus.Document) basic.Content = replText.Replace(basic.Content) err := dbGet(ctx, s.w.db, doc, "title = ? AND heading = ?", basic.Title, basic.Heading) @@ -155,7 +158,7 @@ func diff2(text1, text2 string) string { return text } -func (s *cobStore) afterCreatedCobDocument(ctx context.Context, obj *corpus.Document) error { +func (s *corpuStore) afterCreatedCobDocument(ctx context.Context, obj *corpus.Document) error { dvb := corpus.DocVectorBasic{ DocID: obj.ID, Subject: obj.GetSubject(), @@ -201,7 +204,7 @@ func GetEmbedding(ctx context.Context, text string) (vec corpus.Vector, err erro return } -func (s *cobStore) ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt string, err error) { +func (s *corpuStore) ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt string, err error) { var docs corpus.Documents docs, err = s.MatchDocments(ctx, ms) if err != nil { @@ -218,7 +221,7 @@ func (s *cobStore) ConstructPrompt(ctx context.Context, ms MatchSpec) (prompt st return } -func (s *cobStore) MatchDocments(ctx context.Context, ms MatchSpec) (data corpus.Documents, err error) { +func (s *corpuStore) MatchDocments(ctx context.Context, ms MatchSpec) (data corpus.Documents, err error) { ms.setDefaults() var subject string if ms.SkipKeywords { @@ -258,7 +261,7 @@ func (s *cobStore) MatchDocments(ctx context.Context, ms MatchSpec) (data corpus return } -func (s *cobStore) MatchVectorWith(ctx context.Context, vec corpus.Vector, threshold float32, limit int) (data corpus.DocMatches, err error) { +func (s *corpuStore) MatchVectorWith(ctx context.Context, vec corpus.Vector, threshold float32, limit int) (data corpus.DocMatches, err error) { if len(vec) != corpus.VectorLen { logger().Infow("mismatch length of vector", "a", len(vec), "b", corpus.VectorLen) return @@ -281,7 +284,7 @@ func (s *cobStore) MatchVectorWith(ctx context.Context, vec corpus.Vector, thres return } -func (s *cobStore) ExportDocs(ctx context.Context, ea ExportArg) error { +func (s *corpuStore) ExportDocs(ctx context.Context, ea ExportArg) error { data, _, err := s.ListDocument(ctx, ea.Spec) if err != nil { return err @@ -315,7 +318,7 @@ func documentsToCSV(data corpus.Documents, w io.Writer) error { return cw.Error() } -func (s *cobStore) EmbeddingDocVector(ctx context.Context, spec *CobDocumentSpec) error { +func (s *corpuStore) EmbeddingDocVector(ctx context.Context, spec *CobDocumentSpec) error { data, _, err := s.ListDocument(ctx, spec) if err != nil { return err @@ -355,3 +358,74 @@ func dbAfterDeleteCobDocument(ctx context.Context, db ormDB, obj *corpus.Documen _, err := dbBatchDeleteWithKeyID(ctx, db, corpus.DocVectorTable, "doc_id", obj.ID) return err } + +// InvokerForSearch 返回一个搜索知识库文档的 invoker +func (s *corpuStore) InvokerForSearch() mcps.Invoker { + return func(ctx context.Context, args map[string]any) (map[string]any, error) { + keywordArg, ok := args["keyword"] + if !ok { + return mcps.BuildToolErrorResult("missing required argument: keyword"), nil + } + keyword, ok := keywordArg.(string) + if !ok { + return mcps.BuildToolErrorResult("keyword argument must be a string"), nil + } + + // 参考 callKBSearch,使用 MatchSpec + docs, err := s.MatchDocments(ctx, MatchSpec{ + Question: keyword, + Limit: 5, + SkipKeywords: true, + }) + if err != nil { + return mcps.BuildToolErrorResult(err.Error()), nil + } + + if len(docs) == 0 { + return mcps.BuildToolSuccessResult("No relevant information found"), nil + } + + return mcps.BuildToolSuccessResult(docs.MarkdownText()), nil + } +} + +// InvokerForCreate 返回一个创建知识库文档的 invoker +func (s *corpuStore) InvokerForCreate() mcps.Invoker { + return func(ctx context.Context, args map[string]any) (map[string]any, error) { + if !IsKeeper(ctx) { + return mcps.BuildToolErrorResult("permission denied: keeper role required"), nil + } + + user, _ := UserFromContext(ctx) + logger().Infow("mcp call qa create", "args", args, "user", user) + + titleArg, ok := args["title"] + if !ok { + return mcps.BuildToolErrorResult("missing required argument: title"), nil + } + headingArg, ok := args["heading"] + if !ok { + return mcps.BuildToolErrorResult("missing required argument: heading"), nil + } + contentArg, ok := args["content"] + if !ok { + return mcps.BuildToolErrorResult("missing required argument: content"), nil + } + + docBasic := corpus.DocumentBasic{ + Title: titleArg.(string), + Heading: headingArg.(string), + Content: contentArg.(string), + } + docBasic.MetaAddKVs("creator", user.Name) + + obj, err := s.CreateDocument(ctx, docBasic) + if err != nil { + logger().Infow("create document fail", "title", docBasic.Title, "heading", docBasic.Heading, + "content", len(docBasic.Content), "err", err) + return mcps.BuildToolSuccessResult("Create KB document failed: " + err.Error()), nil + } + + return mcps.BuildToolSuccessResult("Created KB document with ID " + obj.StringID()), nil + } +} diff --git a/pkg/services/stores/mcps_gen.go b/pkg/services/stores/mcps_gen.go index b14b971..c473cd5 100644 --- a/pkg/services/stores/mcps_gen.go +++ b/pkg/services/stores/mcps_gen.go @@ -34,10 +34,13 @@ type MCPServerSpec struct { // * `streamable` // * `inMemory` - 内部运行 TransType string `extensions:"x-order=B" form:"transType" json:"transType" swaggertype:"string"` - // 状态 - // * `stopped` - 已停止 - // * `running` - 运行中 - Status mcps.Status `extensions:"x-order=C" form:"status" json:"status" swaggertype:"string"` + // 是否激活 + IsActive string `extensions:"x-order=C" form:"isActive" json:"isActive"` + // 连接状态 + // * `disconnected` - 断开 + // * `connecting` - 连接中 + // * `connected` - 已连接 + Status mcps.Status `extensions:"x-order=D" form:"status" json:"status" swaggertype:"string"` } func (spec *MCPServerSpec) Sift(q *ormQuery) *ormQuery { @@ -49,6 +52,7 @@ func (spec *MCPServerSpec) Sift(q *ormQuery) *ormQuery { q = q.Where("?TableAlias.trans_type = ?", v) } } + q, _ = siftEqual(q, "is_active", spec.IsActive, false) q, _ = siftEqual(q, "status", spec.Status, false) return q diff --git a/pkg/services/stores/wrap.go b/pkg/services/stores/wrap.go index 1aba87e..cb83b3c 100644 --- a/pkg/services/stores/wrap.go +++ b/pkg/services/stores/wrap.go @@ -96,7 +96,7 @@ var ( type Wrap struct { db *pgx.DB - cobStore *cobStore // gened + corpuStore *corpuStore // gened mcpStore *mcpStore // gened convoStore *convoStore // gened } @@ -107,7 +107,7 @@ func NewWithDB(db *pgx.DB) *Wrap { db: db, } - w.cobStore = &cobStore{w: w} // gened + w.corpuStore = &corpuStore{w: w} // gened w.mcpStore = &mcpStore{w: w} // gened w.convoStore = &convoStore{w: w} // gened @@ -165,7 +165,7 @@ func InitDB(ctx context.Context) error { return nil } -func (w *Wrap) Cob() CobStore { return w.cobStore } // Cob gened -func (w *Wrap) Qa() CobStore { return w.cobStore } // Deprecated: by Cob +func (w *Wrap) Cob() CobStore { return w.corpuStore } // Cob gened +func (w *Wrap) KB() CobStore { return w.corpuStore } // Deprecated: by Cob func (w *Wrap) MCP() MCPStore { return w.mcpStore } // MCP gened func (w *Wrap) Convo() ConvoStore { return w.convoStore } // Convo gened diff --git a/pkg/services/tools/connection.go b/pkg/services/tools/connection.go new file mode 100644 index 0000000..76fced6 --- /dev/null +++ b/pkg/services/tools/connection.go @@ -0,0 +1,22 @@ +package tools + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/client" + + "github.com/liut/morign/pkg/models/mcps" +) + +// MCPConnection 表示一个到 MCP 服务器的连接 +type MCPConnection struct { + Name string + URL string + TransType mcps.TransType + client *client.Client + toolNames []string // 注册的工具名列表 +} + +func (mcpc *MCPConnection) getToolKey(name string) string { + return fmt.Sprintf("%s:%s", mcpc.Name, name) +} diff --git a/pkg/services/tools/defines.go b/pkg/services/tools/defines.go index 6fbe37a..33d9180 100644 --- a/pkg/services/tools/defines.go +++ b/pkg/services/tools/defines.go @@ -3,6 +3,8 @@ package tools import ( "fmt" "strings" + + "github.com/liut/morign/pkg/models/mcps" ) const ( @@ -11,6 +13,83 @@ const ( ToolNameFetch = "fetch" ) +// ToolDescriptor 变量定义 +var ( + // kbSearchDescriptor 知识库搜索工具描述 + kbSearchDescriptor = mcps.ToolDescriptor{ + Name: ToolNameKBSearch, + Description: "Search documents in knowledge base with keywords or subject", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "subject": map[string]any{ + "type": "string", + "description": "text of keywords or subject", + }, + }, + "required": []string{"subject"}, + }, + } + + // kbCreateDescriptor 知识库创建工具描述(需要 keeper 角色) + kbCreateDescriptor = mcps.ToolDescriptor{ + Name: ToolNameKBCreate, + Description: "Create new document of knowledge base, all parameters are required. Note: Unless the user explicitly requests supplementary content, do not invoke it. Before invoking, always perform a kb_search to confirm there is no corresponding subject or content. If similar content already exists, do not invoke even if requested by the user!", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "title": map[string]any{ + "type": "string", + "description": "title of document, like a main name or topic", + }, + "heading": map[string]any{ + "type": "string", + "description": "heading of document, like a sub name or property", + }, + "content": map[string]any{ + "type": "string", + "description": "long text of content of document", + }, + }, + "required": []string{"title", "heading", "content"}, + }, + } + + // fetchDescriptor 网页抓取工具描述 + fetchDescriptor = mcps.ToolDescriptor{ + Name: ToolNameFetch, + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "url": map[string]any{ + "type": "string", + "description": "URL to fetch", + }, + "max_length": map[string]any{ + "type": "number", + "description": "Maximum number of characters to return, default 5000", + "default": 5000, + "minimum": 0, + "maximum": 1000000, + }, + "start_index": map[string]any{ + "type": "number", + "description": "On return output starting at this character index, default 0", + "default": 0, + "minimum": 0, + }, + "raw": map[string]any{ + "type": "boolean", + "description": "Get the actual HTML content without simplification, default false", + "default": false, + }, + }, + "required": []string{"url"}, + }, + } +) + // ResultLogs 是工具调用结果的日志类型别名 type ResultLogs map[string]any diff --git a/pkg/services/tools/invokers.go b/pkg/services/tools/invokers.go index 6b96704..cf068dd 100644 --- a/pkg/services/tools/invokers.go +++ b/pkg/services/tools/invokers.go @@ -13,81 +13,13 @@ import ( htmd "github.com/JohannesKaufmann/html-to-markdown" readeck "codeberg.org/readeck/go-readability/v2" - "github.com/liut/morign/pkg/models/corpus" "github.com/liut/morign/pkg/models/mcps" - "github.com/liut/morign/pkg/services/stores" ) const ( DEFAULT_USER_AGENT_AUTONOMOUS = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) -// KB Search implementation -func (r *Registry) callKBSearch(ctx context.Context, args map[string]any) (map[string]any, error) { - logger().Infow("mcp call qa search", "args", args) - subjectArg, ok := args["subject"] - if !ok { - return mcps.BuildToolErrorResult("missing required argument: subject"), nil - } - subject, ok := subjectArg.(string) - if !ok { - return mcps.BuildToolErrorResult("subject argument must be a string"), nil - } - - docs, err := r.sto.Cob().MatchDocments(ctx, stores.MatchSpec{ - Question: subject, - Limit: 5, - SkipKeywords: true, - }) - if err != nil { - return mcps.BuildToolErrorResult(err.Error()), nil - } - logger().Infow("matched", "docs", len(docs)) - if len(docs) == 0 { - return mcps.BuildToolSuccessResult("No relevant information found"), nil - } - - return mcps.BuildToolSuccessResult(docs.MarkdownText()), nil -} - -// KB Create implementation -func (r *Registry) callKBCreate(ctx context.Context, args map[string]any) (map[string]any, error) { - if !stores.IsKeeper(ctx) { - return mcps.BuildToolErrorResult("permission denied: keeper role required"), nil - } - - user, _ := stores.UserFromContext(ctx) - logger().Infow("mcp call qa create", "args", args, "user", user) - - titleArg, ok := args["title"] - if !ok { - return mcps.BuildToolErrorResult("missing required argument: title"), nil - } - headingArg, ok := args["heading"] - if !ok { - return mcps.BuildToolErrorResult("missing required argument: heading"), nil - } - contentArg, ok := args["content"] - if !ok { - return mcps.BuildToolErrorResult("missing required argument: content"), nil - } - - docBasic := corpus.DocumentBasic{ - Title: titleArg.(string), - Heading: headingArg.(string), - Content: contentArg.(string), - } - docBasic.MetaAddKVs("creator", user.Name) - obj, err := r.sto.Cob().CreateDocument(ctx, docBasic) - if err != nil { - logger().Infow("create document fail", "title", docBasic.Title, "heading", docBasic.Heading, - "content", len(docBasic.Content), "err", err) - return mcps.BuildToolSuccessResult(fmt.Sprintf( - "Create KB document with title %q and heading %q is failed, %s", docBasic.Title, docBasic.Heading, err)), nil - } - return mcps.BuildToolSuccessResult(fmt.Sprintf("Created KB document with ID %s", obj.StringID())), nil -} - // Fetch implementation func (r *Registry) callFetch(ctx context.Context, args map[string]any) (map[string]any, error) { var ( diff --git a/pkg/services/tools/registry.go b/pkg/services/tools/registry.go index 0267c2f..342fd23 100644 --- a/pkg/services/tools/registry.go +++ b/pkg/services/tools/registry.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "slices" "strings" "sync" @@ -14,10 +15,9 @@ import ( "github.com/liut/morign/pkg/services/stores" ) -type Invoker func(ctx context.Context, params map[string]any) (map[string]any, error) +type Invoker = mcps.Invoker type Registry struct { - sto stores.Storage tools []mcps.ToolDescriptor invokers map[string]Invoker @@ -32,6 +32,10 @@ type Registry struct { oauthClients map[string]*client.Client // token -> client 缓存 oauthClientsMu sync.Mutex clientInfo mcp.Implementation // MCP 客户端信息 + + // MCP Servers 连接容器(name -> connection) + servers map[string]*MCPConnection + serversMu sync.RWMutex } // RegistryOption 用于配置 Registry 的可选参数 @@ -59,12 +63,12 @@ func WithOAuthMCP(endpoint string, getToken func(ctx context.Context) string) Re // NewRegistry 创建工具注册表 func NewRegistry(sto stores.Storage, opts ...RegistryOption) *Registry { r := &Registry{ - sto: sto, tools: make([]mcps.ToolDescriptor, 0), invokers: make(map[string]Invoker), oauthClients: make(map[string]*client.Client), + servers: make(map[string]*MCPConnection), } - r.initTools() + r.initTools(sto) for _, opt := range opts { opt(r) @@ -73,6 +77,31 @@ func NewRegistry(sto stores.Storage, opts ...RegistryOption) *Registry { return r } +// AddInvoker 添加自定义工具 invoker +// name: 工具名称 +// fn: 工具调用函数 +// desc: 工具描述 +// inputSchema: 输入参数 schema +func (r *Registry) AddInvoker(name string, fn Invoker, desc string, inputSchema map[string]any) error { + // 检查工具名是否冲突 + if err := r.checkToolNameConflict(name); err != nil { + return err + } + + // 注册 invoker + r.invokers[name] = fn + + // 注册 ToolDescriptor + r.tools = append(r.tools, mcps.ToolDescriptor{ + Name: name, + Description: desc, + InputSchema: inputSchema, + }) + + logger().Infow("custom invoker added", "name", name) + return nil +} + func (r *Registry) Invoke(ctx context.Context, name string, params map[string]any) (map[string]any, error) { if name == "" { return mcps.BuildToolErrorResult("tool name is empty"), nil @@ -94,85 +123,20 @@ func (r *Registry) Invoke(ctx context.Context, name string, params map[string]an return mcps.BuildToolErrorResult("tool not found"), nil } -func (r *Registry) initTools() { +func (r *Registry) initTools(sto stores.Storage) { // Add KB tools - if r.sto != nil { + if sto != nil { // 公开工具:KBSearch - r.tools = append(r.tools, mcps.ToolDescriptor{ - Name: ToolNameKBSearch, - Description: "Search documents in knowledge base with keywords or subject", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "subject": map[string]any{ - "type": "string", - "description": "text of keywords or subject", - }, - }, - "required": []string{"subject"}, - }, - }) - r.invokers[ToolNameKBSearch] = r.callKBSearch + r.tools = append(r.tools, kbSearchDescriptor) + r.invokers[ToolNameKBSearch] = sto.Cob().InvokerForSearch() // 受限工具:KBCreate (需要 keeper 角色) - r.privTools = append(r.privTools, mcps.ToolDescriptor{ - Name: ToolNameKBCreate, - Description: "Create new document of knowledge base, all parameters are required. Note: Unless the user explicitly requests supplementary content, do not invoke it. Before invoking, always perform a kb_search to confirm there is no corresponding subject or content. If similar content already exists, do not invoke even if requested by the user!", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "title": map[string]any{ - "type": "string", - "description": "title of document, like a main name or topic", - }, - "heading": map[string]any{ - "type": "string", - "description": "heading of document, like a sub name or property", - }, - "content": map[string]any{ - "type": "string", - "description": "long text of content of document", - }, - }, - "required": []string{"title", "heading", "content"}, - }, - }) - r.invokers[ToolNameKBCreate] = r.callKBCreate + r.privTools = append(r.privTools, kbCreateDescriptor) + r.invokers[ToolNameKBCreate] = sto.Cob().InvokerForCreate() } // 公开工具:Fetch - r.tools = append(r.tools, mcps.ToolDescriptor{ - Name: ToolNameFetch, - Description: "Fetches a URL from the internet and optionally extracts its contents as markdown", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "url": map[string]any{ - "type": "string", - "description": "URL to fetch", - }, - "max_length": map[string]any{ - "type": "number", - "description": "Maximum number of characters to return, default 5000", - "default": 5000, - "minimum": 0, - "maximum": 1000000, - }, - "start_index": map[string]any{ - "type": "number", - "description": "On return output starting at this character index, default 0", - "default": 0, - "minimum": 0, - }, - "raw": map[string]any{ - "type": "boolean", - "description": "Get the actual HTML content without simplification, default false", - "default": false, - }, - }, - "required": []string{"url"}, - }, - }) + r.tools = append(r.tools, fetchDescriptor) r.invokers[ToolNameFetch] = r.callFetch logger().Debugw("init tools", "tools", r.tools, "priv", len(r.privTools)) @@ -243,14 +207,15 @@ func (r *Registry) initOAuthMCPTools(ctx context.Context) error { // 转换为本地 ToolDescriptor 并注册 invoker for _, tool := range result.Tools { + toolKey := fmt.Sprintf("oauth:%s", tool.Name) // InputSchema 是 ToolInputSchema 类型,需要转换 inputSchema := convertInputSchema(tool.InputSchema) r.tools = append(r.tools, mcps.ToolDescriptor{ - Name: tool.Name, + Name: toolKey, Description: tool.Description, InputSchema: inputSchema, }) - r.invokers[tool.Name] = func(ctx context.Context, params map[string]any) (map[string]any, error) { + r.invokers[toolKey] = func(ctx context.Context, params map[string]any) (map[string]any, error) { return r.callOAuthTool(ctx, tool.Name, params) } } @@ -344,10 +309,18 @@ func (r *Registry) callOAuthTool(ctx context.Context, name string, params map[st // convertInputSchema 将 ToolInputSchema 转换为 map[string]any func convertInputSchema(schema mcp.ToolInputSchema) map[string]any { + properties := schema.Properties + if properties == nil { + properties = make(map[string]any) + } + required := schema.Required + if required == nil { + required = make([]string, 0) + } return map[string]any{ "type": schema.Type, - "properties": schema.Properties, - "required": schema.Required, + "properties": properties, + "required": required, } } @@ -373,3 +346,227 @@ func convertMCPToolResult(result *mcp.CallToolResult) map[string]any { "content": content, }) } + +// AddServer 添加一个 MCP Server 并初始化连接 +// 仅支持远程传输类型(SSE 或 Streamable) +func (r *Registry) AddServer(ctx context.Context, server *mcps.Server) error { + // 验证传输类型 + if !server.TransType.IsRemote() { + return fmt.Errorf("unsupported transport type: %v (only SSE and Streamable are supported)", server.TransType) + } + + // 验证 URL + if server.URL == "" { + return fmt.Errorf("URL is required") + } + + // 检查名称冲突 + if err := r.checkToolNameConflict(server.Name); err != nil { + return err + } + + // 创建 transport(使用接口类型) + var tp transport.Interface + var err error + switch server.TransType { + case mcps.TransTypeSSE: + tp, err = transport.NewSSE(server.URL) + case mcps.TransTypeStreamable: + tp, err = transport.NewStreamableHTTP(server.URL) + default: + return fmt.Errorf("unsupported transport type: %v", server.TransType) + } + if err != nil { + return fmt.Errorf("failed to create transport: %w", err) + } + + // 创建并启动 client + c := client.NewClient(tp) + if err := c.Start(ctx); err != nil { + return fmt.Errorf("failed to start MCP client: %w", err) + } + + // 初始化 MCP 协议 + if _, err := c.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: r.clientInfo, + }, + }); err != nil { + _ = c.Close() + return fmt.Errorf("failed to initialize MCP: %w", err) + } + + // 获取工具列表 + result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + _ = c.Close() + return fmt.Errorf("failed to list tools: %w", err) + } + + // 检查新工具名是否冲突 + for _, tool := range result.Tools { + if err := r.checkToolNameConflict(tool.Name); err != nil { + _ = c.Close() + return err + } + } + + // 注册工具 + r.serversMu.Lock() + mcpc := &MCPConnection{ + Name: server.Name, + URL: server.URL, + TransType: server.TransType, + client: c, + } + toolNames := make([]string, 0, len(result.Tools)) + for _, tool := range result.Tools { + toolKey := mcpc.getToolKey(tool.Name) + inputSchema := convertInputSchema(tool.InputSchema) + r.tools = append(r.tools, mcps.ToolDescriptor{ + Name: toolKey, + Description: tool.Description, + InputSchema: inputSchema, + }) + r.invokers[toolKey] = func(ctx context.Context, params map[string]any) (map[string]any, error) { + return r.callServerTool(ctx, server.Name, tool.Name, params) + } + toolNames = append(toolNames, toolKey) + logger().Infow("registered MCP tool", "server", server.Name, "tool", tool.Name) + } + mcpc.toolNames = toolNames + r.servers[server.Name] = mcpc + r.serversMu.Unlock() + + logger().Infow("MCP server added", "name", server.Name, "url", server.URL, "tools", len(result.Tools)) + return nil +} + +// checkToolNameConflict 检查工具名是否冲突 +func (r *Registry) checkToolNameConflict(name string) error { + // 检查是否与内置工具冲突 + switch name { + case ToolNameKBSearch, ToolNameKBCreate, ToolNameFetch: + return fmt.Errorf("tool name %q conflicts with built-in tool", name) + } + + // 检查是否与已注册的工具冲突 + for _, t := range r.tools { + if t.Name == name { + return fmt.Errorf("tool name %q already exists", name) + } + } + for _, t := range r.privTools { + if t.Name == name { + return fmt.Errorf("tool name %q already exists", name) + } + } + + // 检查是否与已注册的 server 冲突 + r.serversMu.RLock() + for _, s := range r.servers { + if s.Name == name { + r.serversMu.RUnlock() + return fmt.Errorf("server %q already exists", name) + } + } + r.serversMu.RUnlock() + + return nil +} + +// callServerTool 调用 MCP Server 工具 +func (r *Registry) callServerTool(ctx context.Context, serverName, toolName string, params map[string]any) (map[string]any, error) { + r.serversMu.RLock() + server, ok := r.servers[serverName] + r.serversMu.RUnlock() + + if !ok { + return mcps.BuildToolErrorResult("server not found"), nil + } + + // 确保 params 不为空 + if params == nil { + params = make(map[string]any) + } + + result, err := server.client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: params, + }, + }) + if err != nil { + logger().Errorw("MCP server tool call failed", "server", serverName, "tool", toolName, "err", err) + return mcps.BuildToolErrorResult(err.Error()), nil + } + + return convertMCPToolResult(result), nil +} + +// LoadServers 加载所有 Running 状态的 MCP Server +func (r *Registry) LoadServers(ctx context.Context, sto stores.Storage) error { + if sto == nil { + logger().Warnw("no storage configured, skipping MCP server load") + return nil + } + + spec := &stores.MCPServerSpec{ + IsActive: "true", + } + spec.Limit = 2 + spec.Sort = "created DESC" + servers, _, err := sto.MCP().ListServer(ctx, spec) + if err != nil { + return fmt.Errorf("failed to list MCP servers: %w", err) + } + + for _, server := range servers { + if !server.TransType.IsRemote() { + logger().Infow("skipping non-remote MCP server", "name", server.Name, "type", server.TransType) + continue + } + if err := r.AddServer(ctx, &server); err != nil { + logger().Warnw("failed to load MCP server", "name", server.Name, "err", err) + continue + } + logger().Infow("loaded MCP server", "name", server.Name) + } + + logger().Info("MCP servers loaded", "count", len(servers)) + return nil +} + +// RemoveServer 移除 MCP Server 连接 +func (r *Registry) RemoveServer(name string) error { + r.serversMu.Lock() + defer r.serversMu.Unlock() + + conn, ok := r.servers[name] + if !ok { + return fmt.Errorf("server %q not found", name) + } + + // 关闭 client 连接 + if conn.client != nil { + _ = conn.client.Close() + } + + // 使用 toolNames 移除工具 + for _, toolName := range conn.toolNames { + delete(r.invokers, toolName) + } + + // 过滤掉该 server 的工具 + newTools := make([]mcps.ToolDescriptor, 0, len(r.tools)) + for _, tool := range r.tools { + if !slices.Contains(conn.toolNames, tool.Name) { + newTools = append(newTools, tool) + } + } + r.tools = newTools + delete(r.servers, name) + logger().Infow("MCP server removed", "name", name) + return nil +} diff --git a/pkg/web/api/api.go b/pkg/web/api/api.go index 418bb42..926f67d 100644 --- a/pkg/web/api/api.go +++ b/pkg/web/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "net/http" "strings" @@ -78,6 +79,11 @@ func newapi(sto stores.Storage) *api { } toolreg := tools.NewRegistry(sto, opts...) + // 加载已激活的 MCP Servers + if err := toolreg.LoadServers(context.Background(), sto); err != nil { + logger().Warnw("failed to load MCP servers", "err", err) + } + return &api{ sto: sto, llm: stores.GetLLMClient(), @@ -217,6 +223,11 @@ func apiOk(w http.ResponseWriter, r *http.Request, args ...any) { render.JSON(w, r, res) } +// nolint +func idResult(id any) *resp.ResultID { + return &resp.ResultID{ID: id} +} + type Done = resp.Done type Failure = resp.Failure type ResultData = resp.ResultData diff --git a/pkg/web/api/handle_convo.go b/pkg/web/api/handle_convo.go index 9362e88..efed365 100644 --- a/pkg/web/api/handle_convo.go +++ b/pkg/web/api/handle_convo.go @@ -295,10 +295,13 @@ func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r } w.Header().Add("Conversation-ID", ccr.cs.GetID()) + var iter int for { + iter++ // 调用流式响应处理 streamRes := a.doChatStream(ccr, w, r) - logger().Infow("stream round done", "answer_len", len(streamRes.answer), "toolCalls_len", len(streamRes.toolCalls)) + logger().Infow("stream round done", "iter", iter, "answer_len", len(streamRes.answer), + "toolCalls_len", len(streamRes.toolCalls)) // 累积答案 res.answer += streamRes.answer @@ -481,7 +484,7 @@ type SummaryRequest struct { // @Router /api/summary [post] func (a *api) postSummary(w http.ResponseWriter, r *http.Request) { var req SummaryRequest - if err := render.DecodeJSON(r.Body, &req); err != nil { + if err := binder.BindBody(r, &req); err != nil { fail(w, r, 400, "invalid request body") return } diff --git a/pkg/web/api/handle_mcps_gen.go b/pkg/web/api/handle_mcps_gen.go new file mode 100644 index 0000000..51bdeda --- /dev/null +++ b/pkg/web/api/handle_mcps_gen.go @@ -0,0 +1,168 @@ +// This file is generated - Do Not Edit. + +package api + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/liut/morign/pkg/models/mcps" + "github.com/liut/morign/pkg/services/stores" + binder "github.com/marcsv/go-binder/binder" +) + +func init() { + regHI(true, "GET", "/mcp/servers", "mcp-servers-get", func(a *api) http.HandlerFunc { + return a.getMCPServers + }) + regHI(true, "GET", "/mcp/servers/:id", "mcp-servers-id-get", func(a *api) http.HandlerFunc { + return a.getMCPServer + }) + regHI(true, "POST", "/mcp/servers", "mcp-servers-post", func(a *api) http.HandlerFunc { + return a.postMCPServer + }) + regHI(true, "PUT", "/mcp/servers/:id", "mcp-servers-id-put", func(a *api) http.HandlerFunc { + return a.putMCPServer + }) + regHI(true, "DELETE", "/mcp/servers/:id", "mcp-servers-id-delete", func(a *api) http.HandlerFunc { + return a.deleteMCPServer + }) +} + +// @Tags 默认 文档生成 +// @ID mcp-servers-get +// @Summary 列出服务器 🔑 +// @Accept json +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param query query stores.MCPServerSpec true "Object" +// @Success 200 {object} Done{result=ResultData{data=mcps.Servers}} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 404 {object} Failure "目标未找到" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers [get] +func (a *api) getMCPServers(w http.ResponseWriter, r *http.Request) { + var spec stores.MCPServerSpec + if err := queryBinder.Bind(&spec, r.URL); err != nil { + fail(w, r, 400, err) + return + } + + ctx := r.Context() + data, total, err := a.sto.MCP().ListServer(ctx, &spec) + if err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, dtResult(data, total)) +} + +// @Tags 默认 文档生成 +// @ID mcp-servers-id-get +// @Summary 获取服务器 🔑 +// @Accept json +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param id path string true "编号" +// @Success 200 {object} Done{result=mcps.Server} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 404 {object} Failure "目标未找到" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers/{id} [get] +func (a *api) getMCPServer(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + obj, err := a.sto.MCP().GetServer(r.Context(), id) + if err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, obj) +} + +// @Tags 默认 文档生成 +// @ID mcp-servers-post +// @Summary 录入服务器 🔑 +// @Accept json,mpfd +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param query body mcps.ServerBasic true "Object" +// @Success 200 {object} Done{result=ResultID} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 403 {object} Failure "无权限" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers [post] +func (a *api) postMCPServer(w http.ResponseWriter, r *http.Request) { + var in mcps.ServerBasic + if err := binder.BindBody(r, &in); err != nil { + fail(w, r, 400, err) + return + } + + obj, err := a.sto.MCP().CreateServer(r.Context(), in) + if err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, idResult(obj.ID)) +} + +// @Tags 默认 文档生成 +// @ID mcp-servers-id-put +// @Summary 更新服务器 🔑 +// @Accept json,mpfd +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param id path string true "编号" +// @Param query body mcps.ServerSet true "Object" +// @Success 200 {object} Done{result=string} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 403 {object} Failure "无权限" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers/{id} [put] +func (a *api) putMCPServer(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + var in mcps.ServerSet + if err := binder.BindBody(r, &in); err != nil { + fail(w, r, 400, err) + return + } + + err := a.sto.MCP().UpdateServer(r.Context(), id, in) + if err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, "ok") +} + +// @Tags 默认 文档生成 +// @ID mcp-servers-id-delete +// @Summary 删除服务器 🔑 +// @Accept json +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param id path string true "编号" +// @Success 200 {object} Done +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 403 {object} Failure "无权限" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers/{id} [delete] +func (a *api) deleteMCPServer(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + err := a.sto.MCP().DeleteServer(r.Context(), id) + if err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, "ok") +} diff --git a/pkg/web/api/handle_mcps_x.go b/pkg/web/api/handle_mcps_x.go new file mode 100644 index 0000000..e5e3737 --- /dev/null +++ b/pkg/web/api/handle_mcps_x.go @@ -0,0 +1,125 @@ +package api + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/liut/morign/pkg/models/mcps" +) + +func init() { + regHI(true, "PUT", "/mcp/servers/{id}/activate", "mcp-servers-id-activate", func(a *api) http.HandlerFunc { + return a.putMCPServerActivate + }) + regHI(true, "PUT", "/mcp/servers/{id}/deactivate", "mcp-servers-id-deactivate", func(a *api) http.HandlerFunc { + return a.putMCPServerDeactivate + }) +} + +// @Tags MCP +// @ID mcp-servers-id-activate +// @Summary 激活服务器 🔑 +// @Description 将 MCP Server 添加到工具注册表 +// @Accept json +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param id path string true "编号" +// @Success 200 {object} Done{result=string} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 403 {object} Failure "无权限" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers/{id}/activate [put] +func (a *api) putMCPServerActivate(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + + // 获取 Server 对象 + server, err := a.sto.MCP().GetServer(r.Context(), id) + if err != nil { + fail(w, r, 503, err) + return + } + if server == nil { + fail(w, r, 404, "server not found") + return + } + + // 先更新 IsActive 为 true + isActive := true + status := mcps.StatusConnecting + if err := a.sto.MCP().UpdateServer(r.Context(), id, mcps.ServerSet{ + IsActive: &isActive, + Status: &status, + }); err != nil { + fail(w, r, 503, err) + return + } + + // 再调用 AddServer 添加到工具注册表 + if err := a.toolreg.AddServer(r.Context(), server); err != nil { + fail(w, r, 503, err) + return + } + + // 更新 Server 状态为 connected + if server.Status != mcps.StatusConnected { + status := mcps.StatusConnected + if err := a.sto.MCP().UpdateServer(r.Context(), id, mcps.ServerSet{ + Status: &status, + }); err != nil { + fail(w, r, 503, err) + return + } + } + + success(w, r, "ok") +} + +// @Tags MCP +// @ID mcp-servers-id-deactivate +// @Summary 停用服务器 🔑 +// @Description 从工具注册表中移除 MCP Server +// @Accept json +// @Produce json +// @Param token header string true "登录票据凭证" +// @Param id path string true "编号" +// @Success 200 {object} Done{result=string} +// @Failure 400 {object} Failure "请求或参数错误" +// @Failure 401 {object} Failure "未登录" +// @Failure 403 {object} Failure "无权限" +// @Failure 503 {object} Failure "服务端错误" +// @Router /api/mcp/servers/{id}/deactivate [put] +func (a *api) putMCPServerDeactivate(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + + // 获取 Server 对象 + server, err := a.sto.MCP().GetServer(r.Context(), id) + if err != nil { + fail(w, r, 503, err) + return + } + if server == nil { + fail(w, r, 404, "server not found") + return + } + + // 从工具注册表中移除 + if err := a.toolreg.RemoveServer(server.Name); err != nil { + fail(w, r, 503, err) + return + } + + // 更新 IsActive 为 false + isActive := false + // 更新 Server 状态为 disconnected + status := mcps.StatusDisconnected + if err := a.sto.MCP().UpdateServer(r.Context(), id, mcps.ServerSet{ + IsActive: &isActive, + Status: &status, + }); err != nil { + fail(w, r, 503, err) + return + } + + success(w, r, "ok") +}