From 4462479493bdf52343ed943476f5258fbcedb87a Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 30 Apr 2026 15:56:43 +0800 Subject: [PATCH] fix: return real http status code --- internal/server/controllers/agent.go | 6 +- internal/server/controllers/base.go | 23 +++ internal/server/controllers/base_test.go | 159 ++++++++++++++++++ internal/server/controllers/datasource.go | 6 +- internal/server/controllers/endpoint.go | 14 +- internal/server/controllers/model.go | 4 +- internal/server/controllers/plugins.go | 62 +++---- .../server/controllers/remote_debugging.go | 2 +- internal/server/controllers/tool.go | 8 +- internal/server/controllers/trigger.go | 6 +- internal/service/base_sse.go | 26 ++- 11 files changed, 254 insertions(+), 62 deletions(-) create mode 100644 internal/server/controllers/base_test.go diff --git a/internal/server/controllers/agent.go b/internal/server/controllers/agent.go index 5886d47b7..6b4da7d6a 100644 --- a/internal/server/controllers/agent.go +++ b/internal/server/controllers/agent.go @@ -1,8 +1,6 @@ package controllers import ( - "net/http" - "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" ) @@ -13,7 +11,7 @@ func ListAgentStrategies(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListAgentStrategies(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListAgentStrategies(request.TenantID, request.Page, request.PageSize)) }) } @@ -23,6 +21,6 @@ func GetAgentStrategy(c *gin.Context) { PluginID string `form:"plugin_id" validate:"required"` Provider string `form:"provider" validate:"required"` }) { - c.JSON(http.StatusOK, service.GetAgentStrategy(request.TenantID, request.PluginID, request.Provider)) + JSONResponse(c, service.GetAgentStrategy(request.TenantID, request.PluginID, request.Provider)) }) } diff --git a/internal/server/controllers/base.go b/internal/server/controllers/base.go index a8df6f36a..8ce388c7e 100644 --- a/internal/server/controllers/base.go +++ b/internal/server/controllers/base.go @@ -2,14 +2,37 @@ package controllers import ( "errors" + "net/http" "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/server/constants" "github.com/langgenius/dify-plugin-daemon/internal/types/exception" + "github.com/langgenius/dify-plugin-daemon/pkg/entities" "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/pkg/validators" ) +func statusCodeFromResponse(resp *entities.Response) int { + if resp == nil { + return http.StatusInternalServerError + } + + if resp.Code >= 0 { + return http.StatusOK + } + + status := -resp.Code + if status < 100 || status > 599 { + return http.StatusInternalServerError + } + + return status +} + +func JSONResponse(r *gin.Context, resp *entities.Response) { + r.JSON(statusCodeFromResponse(resp), resp) +} + func BindRequest[T any](r *gin.Context, success func(T)) { var request T diff --git a/internal/server/controllers/base_test.go b/internal/server/controllers/base_test.go new file mode 100644 index 000000000..57e65fdc8 --- /dev/null +++ b/internal/server/controllers/base_test.go @@ -0,0 +1,159 @@ +package controllers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/langgenius/dify-plugin-daemon/pkg/entities" +) + +func TestStatusCodeFromResponse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + resp *entities.Response + want int + }{ + { + name: "nil response", + resp: nil, + want: http.StatusInternalServerError, + }, + { + name: "success response", + resp: &entities.Response{Code: 0}, + want: http.StatusOK, + }, + { + name: "positive code response", + resp: &entities.Response{Code: 123}, + want: http.StatusOK, + }, + { + name: "bad request response", + resp: &entities.Response{Code: -400}, + want: http.StatusBadRequest, + }, + { + name: "not found response", + resp: &entities.Response{Code: -404}, + want: http.StatusNotFound, + }, + { + name: "internal server error response", + resp: &entities.Response{Code: -500}, + want: http.StatusInternalServerError, + }, + { + name: "invalid low status code", + resp: &entities.Response{Code: -99}, + want: http.StatusInternalServerError, + }, + { + name: "invalid high status code", + resp: &entities.Response{Code: -600}, + want: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := statusCodeFromResponse(tt.resp) + if got != tt.want { + t.Fatalf("statusCodeFromResponse() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestJSONResponse(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + resp *entities.Response + wantStatus int + wantBody entities.Response + }{ + { + name: "success response", + resp: entities.NewSuccessResponse(map[string]any{"ok": true}), + wantStatus: http.StatusOK, + wantBody: entities.Response{ + Code: 0, + Message: "success", + Data: map[string]any{"ok": true}, + }, + }, + { + name: "bad request response", + resp: entities.NewDaemonErrorResponse(-400, "bad request"), + wantStatus: http.StatusBadRequest, + wantBody: entities.Response{ + Code: -400, + Message: "bad request", + Data: nil, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + + JSONResponse(ctx, tt.resp) + + if recorder.Code != tt.wantStatus { + t.Fatalf("recorder.Code = %d, want %d", recorder.Code, tt.wantStatus) + } + + var got entities.Response + if err := json.Unmarshal(recorder.Body.Bytes(), &got); err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } + + if got.Code != tt.wantBody.Code { + t.Fatalf("response code = %d, want %d", got.Code, tt.wantBody.Code) + } + if got.Message != tt.wantBody.Message { + t.Fatalf("response message = %q, want %q", got.Message, tt.wantBody.Message) + } + + if tt.wantBody.Data == nil { + if got.Data != nil { + t.Fatalf("response data = %#v, want nil", got.Data) + } + return + } + + gotMap, ok := got.Data.(map[string]any) + if !ok { + t.Fatalf("response data type = %T, want map[string]any", got.Data) + } + + wantMap := tt.wantBody.Data.(map[string]any) + if len(gotMap) != len(wantMap) { + t.Fatalf("response data length = %d, want %d", len(gotMap), len(wantMap)) + } + + for key, wantValue := range wantMap { + if gotMap[key] != wantValue { + t.Fatalf("response data[%q] = %#v, want %#v", key, gotMap[key], wantValue) + } + } + }) + } +} diff --git a/internal/server/controllers/datasource.go b/internal/server/controllers/datasource.go index 6b3af02ac..e571621a6 100644 --- a/internal/server/controllers/datasource.go +++ b/internal/server/controllers/datasource.go @@ -1,8 +1,6 @@ package controllers import ( - "net/http" - "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" ) @@ -13,7 +11,7 @@ func ListDatasources(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListDatasources(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListDatasources(request.TenantID, request.Page, request.PageSize)) }) } @@ -23,6 +21,6 @@ func GetDatasource(c *gin.Context) { PluginID string `form:"plugin_id" validate:"required"` Provider string `form:"provider" validate:"required"` }) { - c.JSON(http.StatusOK, service.GetDatasource(request.TenantID, request.PluginID, request.Provider)) + JSONResponse(c, service.GetDatasource(request.TenantID, request.PluginID, request.Provider)) }) } diff --git a/internal/server/controllers/endpoint.go b/internal/server/controllers/endpoint.go index ea0c02ffa..9da0bb591 100644 --- a/internal/server/controllers/endpoint.go +++ b/internal/server/controllers/endpoint.go @@ -22,7 +22,7 @@ func SetupEndpoint(ctx *gin.Context) { pluginUniqueIdentifier := request.PluginUniqueIdentifier name := request.Name - ctx.JSON(200, service.SetupEndpoint( + JSONResponse(ctx, service.SetupEndpoint( tenantId, userId, pluginUniqueIdentifier, name, settings, )) }) @@ -38,7 +38,7 @@ func ListEndpoints(ctx *gin.Context) { page := request.Page pageSize := request.PageSize - ctx.JSON(200, service.ListEndpoints(tenantId, page, pageSize)) + JSONResponse(ctx, service.ListEndpoints(tenantId, page, pageSize)) }) } @@ -54,7 +54,7 @@ func ListPluginEndpoints(ctx *gin.Context) { page := request.Page pageSize := request.PageSize - ctx.JSON(200, service.ListPluginEndpoints(tenantId, pluginId, page, pageSize)) + JSONResponse(ctx, service.ListPluginEndpoints(tenantId, pluginId, page, pageSize)) }) } @@ -66,7 +66,7 @@ func RemoveEndpoint(ctx *gin.Context) { endpointId := request.EndpointID tenantId := request.TenantID - ctx.JSON(200, service.RemoveEndpoint(endpointId, tenantId)) + JSONResponse(ctx, service.RemoveEndpoint(endpointId, tenantId)) }) } @@ -84,7 +84,7 @@ func UpdateEndpoint(ctx *gin.Context) { settings := request.Settings name := request.Name - ctx.JSON(200, service.UpdateEndpoint(endpointId, tenantId, userId, name, settings)) + JSONResponse(ctx, service.UpdateEndpoint(endpointId, tenantId, userId, name, settings)) }) } @@ -96,7 +96,7 @@ func EnableEndpoint(ctx *gin.Context) { tenantId := request.TenantID endpointId := request.EndpointID - ctx.JSON(200, service.EnableEndpoint(endpointId, tenantId)) + JSONResponse(ctx, service.EnableEndpoint(endpointId, tenantId)) }) } @@ -108,6 +108,6 @@ func DisableEndpoint(ctx *gin.Context) { tenantId := request.TenantID endpointId := request.EndpointID - ctx.JSON(200, service.DisableEndpoint(endpointId, tenantId)) + JSONResponse(ctx, service.DisableEndpoint(endpointId, tenantId)) }) } diff --git a/internal/server/controllers/model.go b/internal/server/controllers/model.go index 383d1ae7e..9bbe35aac 100644 --- a/internal/server/controllers/model.go +++ b/internal/server/controllers/model.go @@ -1,8 +1,6 @@ package controllers import ( - "net/http" - "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" ) @@ -13,6 +11,6 @@ func ListModels(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListModels(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListModels(request.TenantID, request.Page, request.PageSize)) }) } diff --git a/internal/server/controllers/plugins.go b/internal/server/controllers/plugins.go index 872fa6005..ab77b789d 100644 --- a/internal/server/controllers/plugins.go +++ b/internal/server/controllers/plugins.go @@ -32,18 +32,18 @@ func UploadPlugin(app *app.Config) gin.HandlerFunc { return func(c *gin.Context) { difyPkgFileHeader, err := c.FormFile("dify_pkg") if err != nil { - c.JSON(http.StatusOK, exception.BadRequestError(err).ToResponse()) + JSONResponse(c, exception.BadRequestError(err).ToResponse()) return } tenantId := c.Param("tenant_id") if tenantId == "" { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("tenant ID is required")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("tenant ID is required")).ToResponse()) return } if difyPkgFileHeader.Size > app.MaxPluginPackageSize { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("file size exceeds the maximum limit")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("file size exceeds the maximum limit")).ToResponse()) return } @@ -51,12 +51,12 @@ func UploadPlugin(app *app.Config) gin.HandlerFunc { difyPkgFile, err := difyPkgFileHeader.Open() if err != nil { - c.JSON(http.StatusOK, exception.BadRequestError(err).ToResponse()) + JSONResponse(c, exception.BadRequestError(err).ToResponse()) return } defer difyPkgFile.Close() - c.JSON(http.StatusOK, service.UploadPluginPkg(app, c, tenantId, difyPkgFile, verifySignature)) + JSONResponse(c, service.UploadPluginPkg(app, c, tenantId, difyPkgFile, verifySignature)) } } @@ -64,18 +64,18 @@ func UploadBundle(app *app.Config) gin.HandlerFunc { return func(c *gin.Context) { difyBundleFileHeader, err := c.FormFile("dify_bundle") if err != nil { - c.JSON(http.StatusOK, exception.BadRequestError(err).ToResponse()) + JSONResponse(c, exception.BadRequestError(err).ToResponse()) return } tenantId := c.Param("tenant_id") if tenantId == "" { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("tenant ID is required")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("tenant ID is required")).ToResponse()) return } if difyBundleFileHeader.Size > app.MaxBundlePackageSize { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("file size exceeds the maximum limit")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("file size exceeds the maximum limit")).ToResponse()) return } @@ -83,12 +83,12 @@ func UploadBundle(app *app.Config) gin.HandlerFunc { difyBundleFile, err := difyBundleFileHeader.Open() if err != nil { - c.JSON(http.StatusOK, exception.BadRequestError(err).ToResponse()) + JSONResponse(c, exception.BadRequestError(err).ToResponse()) return } defer difyBundleFile.Close() - c.JSON(http.StatusOK, service.UploadPluginBundle(app, c, tenantId, difyBundleFile, verifySignature)) + JSONResponse(c, service.UploadPluginBundle(app, c, tenantId, difyBundleFile, verifySignature)) } } @@ -102,16 +102,16 @@ func UpgradePlugin(app *app.Config) gin.HandlerFunc { Meta map[string]any `json:"meta" validate:"omitempty"` }) { if request.OriginalPluginUniqueIdentifier == request.NewPluginUniqueIdentifier { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("original and new plugin unique identifier are the same")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("original and new plugin unique identifier are the same")).ToResponse()) return } if request.OriginalPluginUniqueIdentifier.PluginID() != request.NewPluginUniqueIdentifier.PluginID() { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("original and new plugin id are different")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("original and new plugin id are different")).ToResponse()) return } - c.JSON(http.StatusOK, service.UpgradePlugin( + JSONResponse(c, service.UpgradePlugin( app, request.TenantID, request.Source, @@ -136,7 +136,7 @@ func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc { } if len(request.Metas) != len(request.PluginUniqueIdentifiers) { - c.JSON(http.StatusOK, exception.BadRequestError(errors.New("the number of metas must be equal to the number of plugin unique identifiers")).ToResponse()) + JSONResponse(c, exception.BadRequestError(errors.New("the number of metas must be equal to the number of plugin unique identifiers")).ToResponse()) return } @@ -146,7 +146,7 @@ func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc { } } - c.JSON(http.StatusOK, service.InstallMultiplePluginsToTenant( + JSONResponse(c, service.InstallMultiplePluginsToTenant( c.Request.Context(), app, request.TenantID, request.PluginUniqueIdentifiers, request.Source, request.Metas, )) }) @@ -168,7 +168,7 @@ func DecodePluginFromIdentifier(app *app.Config) gin.HandlerFunc { BindRequest(c, func(request struct { PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `form:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"` }) { - c.JSON(http.StatusOK, service.DecodePluginFromIdentifier(app, request.PluginUniqueIdentifier)) + JSONResponse(c, service.DecodePluginFromIdentifier(app, request.PluginUniqueIdentifier)) }) } } @@ -179,7 +179,7 @@ func FetchPluginInstallationTasks(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.FetchPluginInstallationTasks(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.FetchPluginInstallationTasks(request.TenantID, request.Page, request.PageSize)) }) } @@ -188,7 +188,7 @@ func FetchPluginInstallationTask(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` TaskID string `uri:"id" validate:"required"` }) { - c.JSON(http.StatusOK, service.FetchPluginInstallationTask(request.TenantID, request.TaskID)) + JSONResponse(c, service.FetchPluginInstallationTask(request.TenantID, request.TaskID)) }) } @@ -197,7 +197,7 @@ func DeletePluginInstallationTask(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` TaskID string `uri:"id" validate:"required"` }) { - c.JSON(http.StatusOK, service.DeletePluginInstallationTask(request.TenantID, request.TaskID)) + JSONResponse(c, service.DeletePluginInstallationTask(request.TenantID, request.TaskID)) }) } @@ -205,7 +205,7 @@ func DeleteAllPluginInstallationTasks(c *gin.Context) { BindRequest(c, func(request struct { TenantID string `uri:"tenant_id" validate:"required"` }) { - c.JSON(http.StatusOK, service.DeleteAllPluginInstallationTasks(request.TenantID)) + JSONResponse(c, service.DeleteAllPluginInstallationTasks(request.TenantID)) }) } @@ -218,11 +218,11 @@ func DeletePluginInstallationItemFromTask(c *gin.Context) { identifierString := strings.TrimLeft(request.Identifier, "/") identifier, err := plugin_entities.NewPluginUniqueIdentifier(identifierString) if err != nil { - c.JSON(http.StatusOK, exception.BadRequestError(err).ToResponse()) + JSONResponse(c, exception.BadRequestError(err).ToResponse()) return } - c.JSON(http.StatusOK, service.DeletePluginInstallationItemFromTask(request.TenantID, request.TaskID, identifier)) + JSONResponse(c, service.DeletePluginInstallationItemFromTask(request.TenantID, request.TaskID, identifier)) }) } @@ -231,7 +231,7 @@ func FetchPluginManifest(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `form:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"` }) { - c.JSON(http.StatusOK, service.FetchPluginManifest(request.PluginUniqueIdentifier)) + JSONResponse(c, service.FetchPluginManifest(request.PluginUniqueIdentifier)) }) } @@ -240,7 +240,7 @@ func FetchPluginReadme(c *gin.Context) { PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `form:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"` Language string `form:"language" validate:"omitempty"` }) { - c.JSON(http.StatusOK, service.FetchPluginReadme(request.PluginUniqueIdentifier, request.Language)) + JSONResponse(c, service.FetchPluginReadme(request.PluginUniqueIdentifier, request.Language)) }) } @@ -249,7 +249,7 @@ func UninstallPlugin(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` PluginInstallationID string `json:"plugin_installation_id" validate:"required"` }) { - c.JSON(http.StatusOK, service.UninstallPlugin(request.TenantID, request.PluginInstallationID)) + JSONResponse(c, service.UninstallPlugin(request.TenantID, request.PluginInstallationID)) }) } @@ -257,7 +257,7 @@ func FetchPluginFromIdentifier(c *gin.Context) { BindRequest(c, func(request struct { PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `form:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"` }) { - c.JSON(http.StatusOK, service.FetchPluginFromIdentifier(request.PluginUniqueIdentifier)) + JSONResponse(c, service.FetchPluginFromIdentifier(request.PluginUniqueIdentifier)) }) } @@ -267,7 +267,7 @@ func ListPlugins(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListPlugins(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListPlugins(request.TenantID, request.Page, request.PageSize)) }) } @@ -276,7 +276,7 @@ func BatchFetchPluginInstallationByIDs(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` PluginIDs []string `json:"plugin_ids" validate:"required,max=256"` }) { - c.JSON(http.StatusOK, service.BatchFetchPluginInstallationByIDs(request.TenantID, request.PluginIDs)) + JSONResponse(c, service.BatchFetchPluginInstallationByIDs(request.TenantID, request.PluginIDs)) }) } @@ -285,7 +285,7 @@ func FetchMissingPluginInstallations(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,max=256,dive,plugin_unique_identifier"` }) { - c.JSON(http.StatusOK, service.FetchMissingPluginInstallations(request.TenantID, request.PluginUniqueIdentifiers)) + JSONResponse(c, service.FetchMissingPluginInstallations(request.TenantID, request.PluginUniqueIdentifiers)) }) } @@ -311,8 +311,8 @@ func SwitchServerlessEndpoint(c *gin.Context) { FunctionName string `json:"function_name" validate:"required"` FunctionURL string `json:"function_url" validate:"required"` }) { - c.JSON( - http.StatusOK, + JSONResponse( + c, service.SwitchServerlessEndpoint( request.PluginUniqueIdentifier, request.FunctionName, diff --git a/internal/server/controllers/remote_debugging.go b/internal/server/controllers/remote_debugging.go index c146199d7..fbe2fdd65 100644 --- a/internal/server/controllers/remote_debugging.go +++ b/internal/server/controllers/remote_debugging.go @@ -9,7 +9,7 @@ import ( func GetRemoteDebuggingKey(c *gin.Context) { BindRequest( c, func(request requests.RequestGetRemoteDebuggingKey) { - c.JSON(200, service.GetRemoteDebuggingKey(request.TenantID)) + JSONResponse(c, service.GetRemoteDebuggingKey(request.TenantID)) }, ) } diff --git a/internal/server/controllers/tool.go b/internal/server/controllers/tool.go index 38037c34a..e69829f04 100644 --- a/internal/server/controllers/tool.go +++ b/internal/server/controllers/tool.go @@ -1,8 +1,6 @@ package controllers import ( - "net/http" - "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" ) @@ -13,7 +11,7 @@ func ListTools(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListTools(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListTools(request.TenantID, request.Page, request.PageSize)) }) } @@ -23,7 +21,7 @@ func GetTool(c *gin.Context) { PluginID string `form:"plugin_id" validate:"required"` Provider string `form:"provider" validate:"required"` }) { - c.JSON(http.StatusOK, service.GetTool(request.TenantID, request.PluginID, request.Provider)) + JSONResponse(c, service.GetTool(request.TenantID, request.PluginID, request.Provider)) }) } @@ -32,6 +30,6 @@ func CheckToolExistence(c *gin.Context) { TenantID string `uri:"tenant_id" validate:"required"` ProviderIDS []service.RequestCheckToolExistence `json:"provider_ids" validate:"required,dive"` }) { - c.JSON(http.StatusOK, service.CheckToolExistence(request.TenantID, request.ProviderIDS)) + JSONResponse(c, service.CheckToolExistence(request.TenantID, request.ProviderIDS)) }) } diff --git a/internal/server/controllers/trigger.go b/internal/server/controllers/trigger.go index 8eb476ab9..ed4afd3ad 100644 --- a/internal/server/controllers/trigger.go +++ b/internal/server/controllers/trigger.go @@ -1,8 +1,6 @@ package controllers import ( - "net/http" - "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" ) @@ -13,7 +11,7 @@ func ListTriggers(c *gin.Context) { Page int `form:"page" validate:"required,min=1"` PageSize int `form:"page_size" validate:"required,min=1,max=256"` }) { - c.JSON(http.StatusOK, service.ListTriggers(request.TenantID, request.Page, request.PageSize)) + JSONResponse(c, service.ListTriggers(request.TenantID, request.Page, request.PageSize)) }) } @@ -23,6 +21,6 @@ func GetTrigger(c *gin.Context) { PluginID string `form:"plugin_id" validate:"required"` Provider string `form:"provider" validate:"required"` }) { - c.JSON(http.StatusOK, service.GetTrigger(request.TenantID, request.PluginID, request.Provider)) + JSONResponse(c, service.GetTrigger(request.TenantID, request.PluginID, request.Provider)) }) } diff --git a/internal/service/base_sse.go b/internal/service/base_sse.go index bfe99c55c..89b0eaab3 100644 --- a/internal/service/base_sse.go +++ b/internal/service/base_sse.go @@ -2,6 +2,7 @@ package service import ( "errors" + "net/http" "sync/atomic" "time" @@ -18,6 +19,23 @@ import ( "github.com/langgenius/dify-plugin-daemon/pkg/utils/stream" ) +func statusCodeFromResponse(resp *entities.Response) int { + if resp == nil { + return http.StatusInternalServerError + } + + if resp.Code >= 0 { + return http.StatusOK + } + + status := -resp.Code + if status < 100 || status > 599 { + return http.StatusInternalServerError + } + + return status +} + // baseSSEService is a helper function to handle SSE service // it accepts a generator function that returns a stream response to gin context func baseSSEService[R any]( @@ -28,8 +46,6 @@ func baseSSEService[R any]( ) { startTime := time.Now() writer := ctx.Writer - writer.WriteHeader(200) - writer.Header().Set("Content-Type", "text/event-stream") done := make(chan bool) doneClosed := new(int32) @@ -48,7 +64,8 @@ func baseSSEService[R any]( pluginDaemonResponse, err := generator() if err != nil { - writeData(exception.InternalServerError(err).ToResponse()) + resp := exception.InternalServerError(err).ToResponse() + ctx.JSON(statusCodeFromResponse(resp), resp) duration := time.Since(startTime).Seconds() if onCompletion != nil { onCompletion("error", duration) @@ -57,6 +74,9 @@ func baseSSEService[R any]( return } + writer.Header().Set("Content-Type", "text/event-stream") + writer.WriteHeader(http.StatusOK) + routine.Submit(routinepkg.Labels{ routinepkg.RoutineLabelKeyModule: "service", routinepkg.RoutineLabelKeyMethod: "baseSSEService",