diff --git a/internal/core/session_manager/session.go b/internal/core/session_manager/session.go index 48d2a95b5..76f6ee47c 100644 --- a/internal/core/session_manager/session.go +++ b/internal/core/session_manager/session.go @@ -2,6 +2,7 @@ package session_manager import ( "context" + "encoding/json" "errors" "fmt" "sync" @@ -13,7 +14,6 @@ import ( "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/pkg/utils/cache" "github.com/langgenius/dify-plugin-daemon/pkg/utils/log" - "github.com/langgenius/dify-plugin-daemon/pkg/utils/parser" ) var ( @@ -230,17 +230,29 @@ const ( PLUGIN_IN_STREAM_EVENT_RESPONSE PLUGIN_IN_STREAM_EVENT = "backwards_response" ) +type pluginInStreamMessage struct { + SessionID string `json:"session_id"` + ConversationID *string `json:"conversation_id"` + MessageID *string `json:"message_id"` + AppID *string `json:"app_id"` + EndpointID *string `json:"endpoint_id"` + Context map[string]interface{} `json:"context"` + Event PLUGIN_IN_STREAM_EVENT `json:"event"` + Data any `json:"data"` +} + func (s *Session) Message(event PLUGIN_IN_STREAM_EVENT, data any) []byte { - return parser.MarshalJsonBytes(map[string]any{ - "session_id": s.ID, - "conversation_id": s.ConversationID, - "message_id": s.MessageID, - "app_id": s.AppID, - "endpoint_id": s.EndpointID, - "context": s.Context, - "event": event, - "data": data, + b, _ := json.Marshal(pluginInStreamMessage{ + SessionID: s.ID, + ConversationID: s.ConversationID, + MessageID: s.MessageID, + AppID: s.AppID, + EndpointID: s.EndpointID, + Context: s.Context, + Event: event, + Data: data, }) + return b } func (s *Session) Write(event PLUGIN_IN_STREAM_EVENT, action access_types.PluginAccessAction, data any) error { diff --git a/internal/core/session_manager/session_message_test.go b/internal/core/session_manager/session_message_test.go new file mode 100644 index 000000000..c03b07479 --- /dev/null +++ b/internal/core/session_manager/session_message_test.go @@ -0,0 +1,44 @@ +package session_manager + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSessionMessagePayload(t *testing.T) { + conversationID := "conversation-id" + messageID := "message-id" + appID := "app-id" + endpointID := "endpoint-id" + + session := NewSession(NewSessionPayload{ + ConversationID: &conversationID, + MessageID: &messageID, + AppID: &appID, + EndpointID: &endpointID, + Context: map[string]any{ + "trace": "trace-value", + }, + IgnoreCache: true, + }) + t.Cleanup(func() { + DeleteSession(DeleteSessionPayload{ID: session.ID, IgnoreCache: true}) + }) + + payload := session.Message(PLUGIN_IN_STREAM_EVENT_REQUEST, map[string]any{ + "input": "hello", + }) + + var got map[string]any + require.NoError(t, json.Unmarshal(payload, &got)) + require.Equal(t, session.ID, got["session_id"]) + require.Equal(t, conversationID, got["conversation_id"]) + require.Equal(t, messageID, got["message_id"]) + require.Equal(t, appID, got["app_id"]) + require.Equal(t, endpointID, got["endpoint_id"]) + require.Equal(t, string(PLUGIN_IN_STREAM_EVENT_REQUEST), got["event"]) + require.Equal(t, map[string]any{"trace": "trace-value"}, got["context"]) + require.Equal(t, map[string]any{"input": "hello"}, got["data"]) +} diff --git a/pkg/entities/plugin_entities/model_declaration.go b/pkg/entities/plugin_entities/model_declaration.go index 2d78d35bf..5b60a1d33 100644 --- a/pkg/entities/plugin_entities/model_declaration.go +++ b/pkg/entities/plugin_entities/model_declaration.go @@ -8,7 +8,6 @@ import ( ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" en_translations "github.com/go-playground/validator/v10/translations/en" - "github.com/langgenius/dify-plugin-daemon/pkg/utils/log" "github.com/langgenius/dify-plugin-daemon/pkg/utils/mapping" "github.com/langgenius/dify-plugin-daemon/pkg/utils/parser" "github.com/langgenius/dify-plugin-daemon/pkg/validators" @@ -389,6 +388,26 @@ type ModelDeclaration struct { PriceConfig *ModelPriceConfig `json:"pricing" yaml:"pricing" validate:"omitempty"` } +func (m *ModelDeclaration) normalizeModelProperties() { + if m.ModelProperties == nil { + return + } + + if result, ok := mapping.ConvertAnyMap(m.ModelProperties).(map[string]any); ok { + m.ModelProperties = result + } +} + +func (m *ModelProviderDeclaration) NormalizeModelProperties() { + if m == nil { + return + } + + for i := range m.Models { + m.Models[i].normalizeModelProperties() + } +} + func (m *ModelDeclaration) UnmarshalJSON(data []byte) error { type alias ModelDeclaration @@ -410,6 +429,8 @@ func (m *ModelDeclaration) UnmarshalJSON(data []byte) error { m.ParameterRules = []ModelParameterRule{} } + m.normalizeModelProperties() + return nil } @@ -426,17 +447,6 @@ func (m ModelDeclaration) MarshalJSON() ([]byte, error) { temp.Label.EnUS = temp.Model } - // to avoid ModelProperties not serializable, we need to convert all the keys to string - // includes inner map and slice - if temp.ModelProperties != nil { - result, ok := mapping.ConvertAnyMap(temp.ModelProperties).(map[string]any) - if !ok { - log.Error("ModelProperties is not a map[string]any:", "model_properties", temp.ModelProperties) - } else { - temp.ModelProperties = result - } - } - return json.Marshal(temp) } @@ -461,6 +471,8 @@ func (m *ModelDeclaration) UnmarshalYAML(value *yaml.Node) error { m.ParameterRules = []ModelParameterRule{} } + m.normalizeModelProperties() + return nil } @@ -676,6 +688,7 @@ func (m *ModelProviderDeclaration) UnmarshalJSON(data []byte) error { return err } + m.NormalizeModelProperties() return nil } @@ -712,6 +725,8 @@ func (m *ModelProviderDeclaration) UnmarshalJSON(data []byte) error { m.Models = []ModelDeclaration{} } + m.NormalizeModelProperties() + return nil } @@ -799,6 +814,8 @@ func (m *ModelProviderDeclaration) UnmarshalYAML(value *yaml.Node) error { m.Models = []ModelDeclaration{} } + m.NormalizeModelProperties() + return nil } diff --git a/pkg/entities/plugin_entities/model_declaration_test.go b/pkg/entities/plugin_entities/model_declaration_test.go index 37ec71f31..698f658a9 100644 --- a/pkg/entities/plugin_entities/model_declaration_test.go +++ b/pkg/entities/plugin_entities/model_declaration_test.go @@ -2,6 +2,7 @@ package plugin_entities import ( "encoding/json" + "reflect" "testing" "github.com/langgenius/dify-plugin-daemon/pkg/utils/parser" @@ -23,6 +24,41 @@ func parse_yaml_to_json(data []byte) ([]byte, error) { return jsonData, nil } +func TestModelDeclarationNormalizeModelProperties(t *testing.T) { + model := ModelDeclaration{ + ModelProperties: map[string]any{ + "nested": map[any]any{ + "int_key": map[any]any{ + 1: "one", + }, + "slice": []any{ + map[any]any{ + 2: "two", + }, + }, + }, + }, + } + + model.normalizeModelProperties() + + expected := map[string]any{ + "nested": map[string]any{ + "int_key": map[string]any{ + "1": "one", + }, + "slice": []any{ + map[string]any{ + "2": "two", + }, + }, + }, + } + if !reflect.DeepEqual(model.ModelProperties, expected) { + t.Fatalf("unexpected normalized model properties: %#v", model.ModelProperties) + } +} + func TestFullFunctionModelProvider_Validate(t *testing.T) { const ( model_provider_template = ` diff --git a/pkg/entities/plugin_entities/plugin_declaration.go b/pkg/entities/plugin_entities/plugin_declaration.go index e93dd3525..d0eab0078 100644 --- a/pkg/entities/plugin_entities/plugin_declaration.go +++ b/pkg/entities/plugin_entities/plugin_declaration.go @@ -191,6 +191,14 @@ type PluginDeclaration struct { Trigger *TriggerProviderDeclaration `json:"trigger,omitempty" yaml:"trigger,omitempty" validate:"omitempty"` } +func (p *PluginDeclaration) NormalizeModelProperties() { + if p == nil || p.Model == nil { + return + } + + p.Model.NormalizeModelProperties() +} + func (p *PluginDeclaration) Category() PluginCategory { if p.Tool != nil || len(p.Plugins.Tools) != 0 { return PLUGIN_CATEGORY_TOOL @@ -240,6 +248,8 @@ func (p *PluginDeclaration) UnmarshalJSON(data []byte) error { p.Datasource = extra.Datasource p.Trigger = extra.Trigger + p.NormalizeModelProperties() + return nil } diff --git a/pkg/utils/cache/helper/combined.go b/pkg/utils/cache/helper/combined.go index d88f9e00b..078de3421 100644 --- a/pkg/utils/cache/helper/combined.go +++ b/pkg/utils/cache/helper/combined.go @@ -184,7 +184,8 @@ func CombinedGetPluginDeclaration( }, ) - if err == nil { + if err == nil && declaration != nil { + declaration.NormalizeModelProperties() // Store successful result in memory cache pluginCache.set(cacheKey, declaration) }