From 48b23c09de47c54da94499614f63ff923810a7c3 Mon Sep 17 00:00:00 2001 From: nfebe Date: Sat, 9 May 2026 10:33:09 +0100 Subject: [PATCH] fix(authz): Enforce per-deployment scope on container and resource actions Operators with read-only access to a deployment could previously perform write actions on its containers (SSH, lifecycle, resources) and see or mutate domains, SSL, backups, scheduler tasks, security events, traffic logs, and virtual hosts belonging to deployments they had no access to. Each protected action now checks the actor's access level for the target deployment, list endpoints filter results to the actor's accessible deployments, and global proxy sync is restricted to admins. WebSocket container exec rejects the session before signaling auth success when the actor lacks write access. API-key usage on the WebSocket path now updates last-used metadata for audit parity with HTTP requests. --- internal/api/authz.go | 100 +++++++++ internal/api/authz_test.go | 321 ++++++++++++++++++++++++++++ internal/api/backup_handlers.go | 76 +++++++ internal/api/container_exec.go | 23 +- internal/api/container_exec_test.go | 97 +++++++++ internal/api/resource_handlers.go | 8 + internal/api/scheduler_handlers.go | 80 ++++++- internal/api/security_handlers.go | 15 ++ internal/api/server.go | 76 ++++++- internal/api/traffic_handlers.go | 23 ++ internal/api/user_deployments.go | 4 + internal/auth/middleware.go | 56 ++++- internal/auth/middleware_test.go | 53 +++++ internal/docker/stats.go | 55 +++-- internal/networks/manager.go | 39 ++-- 15 files changed, 985 insertions(+), 41 deletions(-) create mode 100644 internal/api/authz.go create mode 100644 internal/api/authz_test.go create mode 100644 internal/auth/middleware_test.go diff --git a/internal/api/authz.go b/internal/api/authz.go new file mode 100644 index 0000000..c4414f8 --- /dev/null +++ b/internal/api/authz.go @@ -0,0 +1,100 @@ +package api + +import ( + "net/http" + "os/exec" + "strings" + + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +const composeProjectLabel = "com.docker.compose.project" + +func (s *Server) requireDeploymentAccess(c *gin.Context, deploymentName, level string) bool { + if deploymentName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Deployment name required"}) + return false + } + + actor := auth.GetActorFromContext(c) + // Nil actor is allowed for direct handler tests; production routes set an actor via auth middleware + // or explicit anonymous-admin context when auth is disabled. + if actor == nil { + return true + } + if actor.Role == auth.RoleAdmin { + return true + } + + if !actor.CanAccessDeployment(deploymentName, level) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this deployment"}) + return false + } + + return true +} + +func (s *Server) requireContainerAccess(c *gin.Context, containerID, level string) bool { + if containerID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Container ID required"}) + return false + } + + actor := auth.GetActorFromContext(c) + // Nil actor is allowed for direct handler tests; production routes set an actor via auth middleware + // or explicit anonymous-admin context when auth is disabled. + if actor == nil { + return true + } + if actor.Role == auth.RoleAdmin { + // Admins can see missing-container errors; non-admins below get a non-enumerating 403. + if _, err := containerDeploymentName(containerID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Container not found"}) + return false + } + return true + } + + deploymentName, err := containerDeploymentName(containerID) + if err != nil || deploymentName == "" { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this container"}) + return false + } + + if !actor.CanAccessDeployment(deploymentName, level) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this container"}) + return false + } + + return true +} + +func (s *Server) actorCanAccessContainer(c *gin.Context, containerID, level string) bool { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.Role == auth.RoleAdmin { + return true + } + + deploymentName, err := containerDeploymentName(containerID) + if err != nil || deploymentName == "" { + return false + } + + return actor.CanAccessDeployment(deploymentName, level) +} + +func containerDeploymentName(containerID string) (string, error) { + cmd := exec.Command("docker", "inspect", "--format", "{{ index .Config.Labels \""+composeProjectLabel+"\" }}", containerID) + output, err := cmd.Output() + if err != nil { + return "", err + } + + deploymentName := strings.TrimSpace(string(output)) + if deploymentName == "" { + return "", nil + } + + return deploymentName, nil +} diff --git a/internal/api/authz_test.go b/internal/api/authz_test.go new file mode 100644 index 0000000..f4fe7e0 --- /dev/null +++ b/internal/api/authz_test.go @@ -0,0 +1,321 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/backup" + "github.com/flatrun/agent/internal/contextkeys" + "github.com/flatrun/agent/internal/docker" + "github.com/flatrun/agent/internal/nginx" + "github.com/flatrun/agent/internal/proxy" + "github.com/flatrun/agent/internal/scheduler" + "github.com/flatrun/agent/internal/security" + "github.com/flatrun/agent/internal/traffic" + "github.com/flatrun/agent/pkg/config" + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" +) + +func testActor(role auth.Role, deployments map[string]string) *auth.ActorContext { + return &auth.ActorContext{ + Type: "user", + Role: role, + Deployments: deployments, + } +} + +func actorMiddleware(actor *auth.ActorContext) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(contextkeys.Actor, actor) + c.Next() + } +} + +func TestListVirtualHostsFiltersByDeploymentAccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "vhosts-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + confDir := filepath.Join(tmpDir, "nginx", "conf.d") + if err := os.MkdirAll(confDir, 0755); err != nil { + t.Fatalf("failed to create conf dir: %v", err) + } + for _, name := range []string{"allowed-app.conf", "other-app.conf"} { + if err := os.WriteFile(filepath.Join(confDir, name), []byte("server {}"), 0644); err != nil { + t.Fatalf("failed to write vhost: %v", err) + } + } + + cfg := &config.Config{DeploymentsPath: tmpDir, Nginx: config.NginxConfig{ConfigPath: confDir}} + server := &Server{ + config: cfg, + proxyOrchestrator: proxy.NewOrchestratorWithManagers( + nginx.NewManager(&cfg.Nginx, tmpDir, ""), + nil, + ), + } + + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"allowed-app": auth.AccessLevelRead}))) + router.GET("/proxy/vhosts", server.listVirtualHosts) + + req := httptest.NewRequest(http.MethodGet, "/proxy/vhosts", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + VirtualHosts []struct { + Name string `json:"name"` + } `json:"virtual_hosts"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.VirtualHosts) != 1 || resp.VirtualHosts[0].Name != "allowed-app" { + t.Fatalf("expected only allowed-app vhost, got %#v", resp.VirtualHosts) + } +} + +func TestSyncAllProxiesRequiresAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelWrite}))) + router.POST("/proxy/sync", server.syncAllProxies) + + req := httptest.NewRequest(http.MethodPost, "/proxy/sync", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestSecurityEventsRequireDeploymentFilterForNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "security-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + securityManager, err := security.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create security manager: %v", err) + } + defer securityManager.Close() + + server := &Server{securityManager: securityManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.GET("/security/events", server.listSecurityEvents) + + req := httptest.NewRequest(http.MethodGet, "/security/events", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestTrafficLogsKeepDBTotalForScopedNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "traffic-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + trafficManager, err := traffic.NewManager(tmpDir, 7) + if err != nil { + t.Fatalf("failed to create traffic manager: %v", err) + } + defer trafficManager.Close() + + for i := 0; i < 2; i++ { + if _, err := trafficManager.IngestLog(&traffic.IngestTrafficLog{ + DeploymentName: "app", + RequestPath: "/", + RequestMethod: "GET", + StatusCode: 200, + SourceIP: "203.0.113.10", + }); err != nil { + t.Fatalf("failed to ingest traffic log: %v", err) + } + } + + server := &Server{trafficManager: trafficManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.GET("/traffic/logs", server.getTrafficLogs) + + req := httptest.NewRequest(http.MethodGet, "/traffic/logs?deployment=app&limit=1", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + Total int `json:"total"` + Logs []traffic.TrafficLog `json:"logs"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.Logs) != 1 { + t.Fatalf("expected one paged log, got %d", len(resp.Logs)) + } + if resp.Total != 2 { + t.Fatalf("expected DB total 2, got %d", resp.Total) + } +} + +func TestUpdateContainerResourcesValidatesBodyBeforeContainerLookup(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleAdmin, nil))) + router.PUT("/containers/:id/resources", server.updateContainerResources) + + req := httptest.NewRequest(http.MethodPut, "/containers/does-not-exist/resources", bytes.NewBufferString("{invalid")) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400 before container lookup, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestContainerLifecycleDeniesUnscopedContainerForNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelWrite}))) + router.POST("/containers/:id/start", server.startContainer) + router.POST("/containers/:id/stop", server.stopContainer) + router.POST("/containers/:id/restart", server.restartContainer) + router.DELETE("/containers/:id", server.removeContainer) + + tests := []struct { + method string + path string + }{ + {http.MethodPost, "/containers/does-not-exist/start"}, + {http.MethodPost, "/containers/does-not-exist/stop"}, + {http.MethodPost, "/containers/does-not-exist/restart"}, + {http.MethodDelete, "/containers/does-not-exist"}, + } + + for _, tt := range tests { + req := httptest.NewRequest(tt.method, tt.path, nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusForbidden { + t.Fatalf("%s %s: expected 403, got %d: %s", tt.method, tt.path, w.Code, w.Body.String()) + } + } +} + +func TestRestoreBackupRequiresWriteOnTargetDeployment(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "backup-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + for _, name := range []string{"source-app", "target-app"} { + createTestDeployment(t, tmpDir, name, &models.ServiceMetadata{Name: name}) + } + + backupManager, err := backup.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create backup manager: %v", err) + } + created, err := backupManager.CreateBackup(context.Background(), "source-app", nil) + if err != nil { + t.Fatalf("failed to create backup: %v", err) + } + + server := &Server{backupManager: backupManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{ + "source-app": auth.AccessLevelRead, + "target-app": auth.AccessLevelRead, + }))) + router.POST("/backups/:id/restore", server.restoreBackup) + + body := bytes.NewBufferString(`{"deployment_name":"target-app"}`) + req := httptest.NewRequest(http.MethodPost, "/backups/"+created.ID+"/restore", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestCreateScheduledTaskRequiresWriteDeploymentAccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "scheduler-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + createTestDeployment(t, tmpDir, "app", &models.ServiceMetadata{Name: "app"}) + + backupManager, err := backup.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create backup manager: %v", err) + } + manager := docker.NewManager(tmpDir) + schedulerManager, err := scheduler.NewManager(tmpDir, scheduler.NewExecutor(backupManager, manager)) + if err != nil { + t.Fatalf("failed to create scheduler manager: %v", err) + } + defer schedulerManager.Stop() + + server := &Server{manager: manager, schedulerManager: schedulerManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.POST("/scheduler/tasks", server.createScheduledTask) + + body := bytes.NewBufferString(`{"name":"backup","type":"backup","deployment_name":"app","cron_expr":"0 * * * *","enabled":true}`) + req := httptest.NewRequest(http.MethodPost, "/scheduler/tasks", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/internal/api/backup_handlers.go b/internal/api/backup_handlers.go index c8d1c69..75b44db 100644 --- a/internal/api/backup_handlers.go +++ b/internal/api/backup_handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/backup" "github.com/flatrun/agent/pkg/models" "github.com/gin-gonic/gin" @@ -18,6 +19,11 @@ func (s *Server) listBackups(c *gin.Context) { filter := &backup.BackupListFilter{ DeploymentName: c.Query("deployment"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } if limit := c.Query("limit"); limit != "" { if l, err := strconv.Atoi(limit); err == nil { @@ -31,6 +37,17 @@ func (s *Server) listBackups(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := backups[:0] + for _, b := range backups { + if actor.CanAccessDeployment(b.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, b) + } + } + backups = filtered + } + c.JSON(http.StatusOK, gin.H{"backups": backups}) } @@ -46,6 +63,9 @@ func (s *Server) getBackup(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"backup": b}) } @@ -61,6 +81,9 @@ func (s *Server) createBackup(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, req.DeploymentName, auth.AccessLevelWrite) { + return + } deployment, err := s.manager.GetDeployment(req.DeploymentName) if err != nil { @@ -133,6 +156,15 @@ func (s *Server) deleteBackup(c *gin.Context) { } backupID := c.Param("id") + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelWrite) { + return + } + if err := s.backupManager.DeleteBackup(backupID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -148,6 +180,15 @@ func (s *Server) downloadBackup(c *gin.Context) { } backupID := c.Param("id") + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } + backupPath, err := s.backupManager.GetBackupPath(backupID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) @@ -217,6 +258,22 @@ func (s *Server) restoreBackup(c *gin.Context) { } req.BackupID = backupID + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + targetDeployment := b.DeploymentName + if req.DeploymentName != "" { + targetDeployment = req.DeploymentName + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } + if !s.requireDeploymentAccess(c, targetDeployment, auth.AccessLevelWrite) { + return + } + jobID := s.backupManager.StartRestoreJob(&req) c.JSON(http.StatusAccepted, gin.H{"job_id": jobID, "message": "Restore job started"}) } @@ -233,6 +290,9 @@ func (s *Server) getBackupJob(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"}) return } + if !s.requireDeploymentAccess(c, job.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"job": job}) } @@ -244,6 +304,11 @@ func (s *Server) listBackupJobs(c *gin.Context) { } deploymentName := c.Query("deployment") + if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } + } limit := 50 if l := c.Query("limit"); l != "" { if parsed, err := strconv.Atoi(l); err == nil { @@ -252,5 +317,16 @@ func (s *Server) listBackupJobs(c *gin.Context) { } jobs := s.backupManager.ListJobs(deploymentName, limit) + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := jobs[:0] + for _, job := range jobs { + if actor.CanAccessDeployment(job.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, job) + } + } + jobs = filtered + } + c.JSON(http.StatusOK, gin.H{"jobs": jobs}) } diff --git a/internal/api/container_exec.go b/internal/api/container_exec.go index 3dbd711..ace77fc 100644 --- a/internal/api/container_exec.go +++ b/internal/api/container_exec.go @@ -11,6 +11,8 @@ import ( "time" "github.com/creack/pty" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/contextkeys" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) @@ -48,7 +50,6 @@ func (s *Server) containerExec(c *gin.Context) { } defer conn.Close() - // First-message authentication if s.authMiddleware.IsAuthEnabled() { _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) @@ -64,13 +65,28 @@ func (s *Server) containerExec(c *gin.Context) { return } - if !s.authMiddleware.ValidateTokenString(auth.Token) { + actor, err := s.authMiddleware.ActorForTokenString(auth.Token, c.ClientIP()) + if err != nil { sendError(conn, "Invalid or expired token") return } + c.Set(contextkeys.Actor, actor) _ = conn.SetReadDeadline(time.Time{}) + } else { + c.Set(contextkeys.Actor, &auth.ActorContext{Type: "anonymous", Role: auth.RoleAdmin}) + } + actor := auth.GetActorFromContext(c) + if actor == nil || !actor.HasPermission(auth.PermContainersWrite) { + sendError(conn, "Permission denied: containers:write required") + return + } + if !s.actorCanAccessContainer(c, containerID, auth.AccessLevelWrite) { + sendError(conn, "No access to this container") + return + } + if s.authMiddleware.IsAuthEnabled() { if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"auth_success"}`)); err != nil { return } @@ -192,6 +208,9 @@ func sendError(conn *websocket.Conn, msg string) { func (s *Server) containerExecHTTP(c *gin.Context) { containerID := c.Param("id") + if !s.requireContainerAccess(c, containerID, auth.AccessLevelWrite) { + return + } var req struct { Command string `json:"command"` diff --git a/internal/api/container_exec_test.go b/internal/api/container_exec_test.go index e072fb3..c822dbd 100644 --- a/internal/api/container_exec_test.go +++ b/internal/api/container_exec_test.go @@ -3,11 +3,17 @@ package api import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "os" + "strings" "testing" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/pkg/config" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func TestDetectShell(t *testing.T) { @@ -69,6 +75,97 @@ func TestWebSocketUpgrader(t *testing.T) { } } +func TestContainerExecWebSocketDeniesReadOnlyDeploymentAccessBeforeAuthSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "container-exec-ws-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-for-ws-exec", + } + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + defer os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + + authManager, err := auth.NewManager(tmpDir, cfg, true) + if err != nil { + t.Fatalf("failed to create auth manager: %v", err) + } + defer authManager.Close() + + user, err := authManager.CreateUser("readonly-operator", "", "password", auth.RoleOperator, nil) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + admin, err := authManager.GetUserByUsername("admin") + if err != nil { + t.Fatalf("failed to get admin: %v", err) + } + if err := authManager.AssignDeployment(user.ID, "app", auth.AccessLevelRead, admin.ID); err != nil { + t.Fatalf("failed to assign deployment: %v", err) + } + + authMiddleware := auth.NewMiddlewareWithManager(cfg, authManager) + token, err := authMiddleware.GenerateJWTForUser(user, "") + if err != nil { + t.Fatalf("failed to generate token: %v", err) + } + + server := &Server{authMiddleware: authMiddleware} + router := gin.New() + router.GET("/containers/:id/exec", server.containerExec) + + httpServer := newSkippableHTTPServer(t, router) + defer httpServer.Close() + + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/containers/does-not-exist/exec" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial websocket: %v", err) + } + defer conn.Close() + + if err := conn.WriteJSON(authMessage{Type: "auth", Token: token}); err != nil { + t.Fatalf("failed to write auth message: %v", err) + } + + _, message, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read websocket response: %v", err) + } + + text := string(message) + if strings.Contains(text, "auth_success") { + t.Fatalf("received auth_success before authorization denial: %q", text) + } + if !strings.Contains(text, "No access to this container") { + t.Fatalf("expected no-access denial, got %q", text) + } +} + +func newSkippableHTTPServer(t *testing.T, handler http.Handler) *httptest.Server { + t.Helper() + + var srv *httptest.Server + func() { + defer func() { + if r := recover(); r != nil { + msg := fmt.Sprint(r) + if strings.Contains(strings.ToLower(msg), "operation not permitted") { + t.Skipf("httptest listener not permitted in this environment: %v", r) + } + panic(r) + } + }() + srv = httptest.NewServer(handler) + }() + return srv +} + func TestAuthMessageParsing(t *testing.T) { tests := []struct { name string diff --git a/internal/api/resource_handlers.go b/internal/api/resource_handlers.go index c40aee4..70d6f63 100644 --- a/internal/api/resource_handlers.go +++ b/internal/api/resource_handlers.go @@ -3,12 +3,16 @@ package api import ( "net/http" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/docker" "github.com/gin-gonic/gin" ) func (s *Server) getContainerResources(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } resources, err := docker.GetContainerResources(id) if err != nil { @@ -42,6 +46,10 @@ func (s *Server) updateContainerResources(c *gin.Context) { return } + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } + if err := docker.UpdateContainerResources(id, &update); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), diff --git a/internal/api/scheduler_handlers.go b/internal/api/scheduler_handlers.go index 59edda7..ba5b191 100644 --- a/internal/api/scheduler_handlers.go +++ b/internal/api/scheduler_handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/scheduler" "github.com/gin-gonic/gin" ) @@ -20,6 +21,9 @@ func (s *Server) listScheduledTasks(c *gin.Context) { var err error if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } tasks, err = s.schedulerManager.GetTasksByDeployment(deploymentName) } else { tasks, err = s.schedulerManager.GetAllTasks() @@ -30,6 +34,17 @@ func (s *Server) listScheduledTasks(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := tasks[:0] + for _, task := range tasks { + if actor.CanAccessDeployment(task.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, task) + } + } + tasks = filtered + } + c.JSON(http.StatusOK, gin.H{"tasks": tasks}) } @@ -50,6 +65,9 @@ func (s *Server) getScheduledTask(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) return } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"task": task}) } @@ -65,6 +83,9 @@ func (s *Server) createScheduledTask(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, req.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.ValidateCronExpr(req.CronExpr); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid cron expression: " + err.Error()}) @@ -107,6 +128,15 @@ func (s *Server) updateScheduledTask(c *gin.Context) { return } + existingTask, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, existingTask.DeploymentName, auth.AccessLevelWrite) { + return + } + var req scheduler.UpdateTaskRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -121,11 +151,6 @@ func (s *Server) updateScheduledTask(c *gin.Context) { } if req.Config != nil && req.Config.CommandConfig != nil { - existingTask, err := s.schedulerManager.GetTask(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) - return - } resolved, err := s.resolveService(existingTask.DeploymentName, req.Config.CommandConfig.Service) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -154,6 +179,14 @@ func (s *Server) deleteScheduledTask(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.DeleteTask(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -174,6 +207,14 @@ func (s *Server) runTaskNow(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.RunTaskNow(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -194,6 +235,14 @@ func (s *Server) getTaskExecutions(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelRead) { + return + } limit := 50 if l := c.Query("limit"); l != "" { @@ -230,5 +279,26 @@ func (s *Server) getRecentExecutions(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := executions[:0] + taskDeployments := make(map[int64]string) + for _, execution := range executions { + deploymentName, ok := taskDeployments[execution.TaskID] + if !ok { + task, err := s.schedulerManager.GetTask(execution.TaskID) + if err != nil { + continue + } + deploymentName = task.DeploymentName + taskDeployments[execution.TaskID] = deploymentName + } + if actor.CanAccessDeployment(deploymentName, auth.AccessLevelRead) { + filtered = append(filtered, execution) + } + } + executions = filtered + } + c.JSON(http.StatusOK, gin.H{"executions": executions}) } diff --git a/internal/api/security_handlers.go b/internal/api/security_handlers.go index 83fee8a..bfd670f 100644 --- a/internal/api/security_handlers.go +++ b/internal/api/security_handlers.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/security" "github.com/flatrun/agent/pkg/config" "github.com/flatrun/agent/pkg/models" @@ -86,6 +87,17 @@ func (s *Server) listSecurityEvents(c *gin.Context) { SourceIP: c.Query("source_ip"), DeploymentName: c.Query("deployment"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } if limit := c.Query("limit"); limit != "" { if l, err := strconv.Atoi(limit); err == nil { @@ -145,6 +157,9 @@ func (s *Server) getSecurityEvent(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Event not found"}) return } + if event.DeploymentName != "" && !s.requireDeploymentAccess(c, event.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"event": event}) } diff --git a/internal/api/server.go b/internal/api/server.go index 78cc9d5..54f6001 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -348,9 +348,9 @@ func (s *Server) setupRoutes() { protected.POST("/deployments/:name/certificates/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.renewDeploymentCertificates) // Proxy endpoints - protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.getProxyStatus) - protected.POST("/proxy/setup/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.setupProxy) - protected.DELETE("/proxy/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.teardownProxy) + protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getProxyStatus) + protected.POST("/proxy/setup/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.setupProxy) + protected.DELETE("/proxy/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.teardownProxy) protected.GET("/proxy/vhosts", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.listVirtualHosts) protected.POST("/proxy/sync", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.syncAllProxies) protected.POST("/deployments/:name/ssl/disable", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.disableSSL) @@ -576,7 +576,7 @@ func (s *Server) setupRoutes() { } // Get users with access to a deployment - protected.GET("/deployments/:name/users", s.authMiddleware.RequirePermission(auth.PermUsersRead), s.getDeploymentUsers) + protected.GET("/deployments/:name/users", s.authMiddleware.RequirePermission(auth.PermUsersRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelAdmin), s.getDeploymentUsers) } // DNS plugin routes @@ -1918,6 +1918,10 @@ func (s *Server) connectContainer(c *gin.Context) { return } + if !s.requireContainerAccess(c, req.Container, auth.AccessLevelWrite) { + return + } + if err := s.networksManager.ConnectContainer(networkName, req.Container); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -1946,6 +1950,10 @@ func (s *Server) disconnectContainer(c *gin.Context) { return } + if !s.requireContainerAccess(c, req.Container, auth.AccessLevelWrite) { + return + } + if err := s.networksManager.DisconnectContainer(networkName, req.Container); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -3800,6 +3808,18 @@ func (s *Server) listVirtualHosts(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := vhosts[:0] + for _, vhost := range vhosts { + // VirtualHostInfo.Name is the deployment name derived from .conf. + if actor.CanAccessDeployment(vhost.Name, auth.AccessLevelRead) { + filtered = append(filtered, vhost) + } + } + vhosts = filtered + } + c.JSON(http.StatusOK, gin.H{ "virtual_hosts": vhosts, }) @@ -3814,6 +3834,12 @@ type ProxySyncResult struct { } func (s *Server) syncAllProxies(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Admin access required"}) + return + } + deployments, err := s.manager.ListDeployments() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4325,6 +4351,18 @@ func (s *Server) listContainers(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := containers[:0] + for _, container := range containers { + // Non-admins only see FlatRun/Compose containers assigned through deployment access. + if actor.CanAccessDeployment(container.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, container) + } + } + containers = filtered + } + c.JSON(http.StatusOK, gin.H{ "containers": containers, }) @@ -4332,6 +4370,9 @@ func (s *Server) listContainers(c *gin.Context) { func (s *Server) startContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.StartContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4348,6 +4389,9 @@ func (s *Server) startContainer(c *gin.Context) { func (s *Server) stopContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.StopContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4364,6 +4408,9 @@ func (s *Server) stopContainer(c *gin.Context) { func (s *Server) restartContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.RestartContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4380,6 +4427,9 @@ func (s *Server) restartContainer(c *gin.Context) { func (s *Server) removeContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelAdmin) { + return + } if err := s.networksManager.RemoveContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4396,6 +4446,9 @@ func (s *Server) removeContainer(c *gin.Context) { func (s *Server) getContainerLogs(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } tailStr := c.DefaultQuery("tail", "100") tail, err := strconv.Atoi(tailStr) @@ -4419,6 +4472,9 @@ func (s *Server) getContainerLogs(c *gin.Context) { func (s *Server) getContainerStats(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } stats, err := docker.GetContainerStats(id) if err != nil { @@ -4442,6 +4498,18 @@ func (s *Server) getAllContainerStats(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := stats[:0] + for _, stat := range stats { + // Non-admins only see FlatRun/Compose containers assigned through deployment access. + if actor.CanAccessDeployment(stat.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, stat) + } + } + stats = filtered + } + c.JSON(http.StatusOK, gin.H{ "stats": stats, }) diff --git a/internal/api/traffic_handlers.go b/internal/api/traffic_handlers.go index 0e7fdc2..5622572 100644 --- a/internal/api/traffic_handlers.go +++ b/internal/api/traffic_handlers.go @@ -5,6 +5,7 @@ import ( "strconv" "time" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/traffic" "github.com/gin-gonic/gin" ) @@ -45,6 +46,17 @@ func (s *Server) getTrafficLogs(c *gin.Context) { SourceIP: c.Query("source_ip"), RequestPath: c.Query("path"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } if statusCode := c.Query("status_code"); statusCode != "" { if code, err := strconv.Atoi(statusCode); err == nil { @@ -100,6 +112,17 @@ func (s *Server) getTrafficStats(c *gin.Context) { } deploymentName := c.Query("deployment") + if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } since := 24 * time.Hour if sinceStr := c.Query("since"); sinceStr != "" { diff --git a/internal/api/user_deployments.go b/internal/api/user_deployments.go index e35db5c..7505fdd 100644 --- a/internal/api/user_deployments.go +++ b/internal/api/user_deployments.go @@ -135,6 +135,10 @@ func (s *Server) removeUserDeployment(c *gin.Context) { return } + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelAdmin) { + return + } + if err := s.authManager.RemoveDeploymentAccess(userID, deploymentName); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to remove deployment access"}) return diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 3e5b59b..41e84cf 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -489,10 +489,62 @@ func (m *Middleware) GetAuthStatus(c *gin.Context) { } func (m *Middleware) ValidateTokenString(token string) bool { + _, err := m.ActorForTokenString(token, "") + return err == nil +} + +func (m *Middleware) ActorForTokenString(token string, clientIP string) (*ActorContext, error) { if !m.config.Enabled { - return true + return &ActorContext{ + Type: "anonymous", + Role: RoleAdmin, + }, nil + } + + if claims := m.validateJWTWithClaims(token); claims != nil { + if m.manager != nil && claims.SessionID != "" { + session, err := m.manager.GetSessionByID(claims.SessionID) + if err != nil || !session.RevokedAt.IsZero() { + return nil, fmt.Errorf("session revoked or invalid") + } + } + + if m.manager != nil && claims.UserID > 0 { + user, err := m.manager.GetUser(claims.UserID) + if err != nil { + return nil, err + } + if !user.IsActive { + return nil, ErrUserInactive + } + return m.manager.BuildActorContext(user, nil) + } + + return &ActorContext{ + Type: "jwt", + Role: RoleAdmin, + }, nil + } + + if m.manager != nil { + apiKey, user, err := m.manager.ValidateAPIKey(token) + if err == nil { + if clientIP != "" { + _ = m.manager.UpdateAPIKeyLastUsed(apiKey.ID, clientIP) + } + return m.manager.BuildActorContext(user, apiKey) + } } - return m.validateJWT(token) || m.validateAPIKey(token) + + if m.validateAPIKey(token) { + log.Printf("Warning: Legacy API key used. Consider migrating to user-based API keys.") + return &ActorContext{ + Type: "legacy_key", + Role: RoleAdmin, + }, nil + } + + return nil, fmt.Errorf("invalid or expired token") } func (m *Middleware) IsAuthEnabled() bool { diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 0000000..05910a5 --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,53 @@ +package auth + +import ( + "testing" + "time" +) + +func TestActorForTokenStringUpdatesAPIKeyLastUsed(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + manager.config.Enabled = true + + user, _ := manager.CreateUser("wskeyuser", "", "pass", RoleOperator, nil) + key, plainKey, err := manager.CreateAPIKey(user.ID, "WS Key", "", "", nil, nil, time.Time{}) + if err != nil { + t.Fatalf("CreateAPIKey failed: %v", err) + } + + mw := NewMiddlewareWithManager(manager.config, manager) + actor, err := mw.ActorForTokenString(plainKey, "203.0.113.10") + if err != nil { + t.Fatalf("ActorForTokenString failed: %v", err) + } + if actor.APIKey == nil || actor.APIKey.ID != key.ID { + t.Fatalf("expected API key actor, got %#v", actor) + } + + updated, err := manager.GetAPIKey(key.ID) + if err != nil { + t.Fatalf("GetAPIKey failed: %v", err) + } + if updated.LastUsedAt.IsZero() { + t.Fatal("expected LastUsedAt to be set") + } + if updated.LastUsedIP != "203.0.113.10" { + t.Fatalf("expected LastUsedIP to be 203.0.113.10, got %q", updated.LastUsedIP) + } +} + +func TestActorForTokenStringLegacyKeyReturnsAdminActor(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + manager.config.Enabled = true + + mw := NewMiddlewareWithManager(manager.config, manager) + actor, err := mw.ActorForTokenString("legacy-key-1", "203.0.113.10") + if err != nil { + t.Fatalf("ActorForTokenString failed: %v", err) + } + if actor.Role != RoleAdmin || actor.Type != "legacy_key" { + t.Fatalf("expected legacy admin actor, got %#v", actor) + } +} diff --git a/internal/docker/stats.go b/internal/docker/stats.go index db353b9..6910f94 100644 --- a/internal/docker/stats.go +++ b/internal/docker/stats.go @@ -2,23 +2,25 @@ package docker import ( "encoding/json" + "log" "os/exec" "strconv" "strings" ) type ContainerStats struct { - ContainerID string `json:"container_id"` - Name string `json:"name"` - CPUPercent float64 `json:"cpu_percent"` - MemoryUsage uint64 `json:"memory_usage"` - MemoryLimit uint64 `json:"memory_limit"` - MemoryPercent float64 `json:"memory_percent"` - NetworkRx uint64 `json:"network_rx"` - NetworkTx uint64 `json:"network_tx"` - BlockRead uint64 `json:"block_read"` - BlockWrite uint64 `json:"block_write"` - PIDs int `json:"pids"` + ContainerID string `json:"container_id"` + Name string `json:"name"` + DeploymentName string `json:"deployment_name,omitempty"` + CPUPercent float64 `json:"cpu_percent"` + MemoryUsage uint64 `json:"memory_usage"` + MemoryLimit uint64 `json:"memory_limit"` + MemoryPercent float64 `json:"memory_percent"` + NetworkRx uint64 `json:"network_rx"` + NetworkTx uint64 `json:"network_tx"` + BlockRead uint64 `json:"block_read"` + BlockWrite uint64 `json:"block_write"` + PIDs int `json:"pids"` } type dockerStatsJSON struct { @@ -48,6 +50,7 @@ func GetContainerStats(containerID string) (*ContainerStats, error) { } func GetAllContainerStats() ([]ContainerStats, error) { + deployments := listContainerDeploymentLabels() cmd := exec.Command("docker", "stats", "--no-stream", "--format", "{{json .}}") output, err := cmd.Output() if err != nil { @@ -64,12 +67,40 @@ func GetAllContainerStats() ([]ContainerStats, error) { if err := json.Unmarshal([]byte(line), &raw); err != nil { continue } - stats = append(stats, *parseStats(&raw)) + stat := parseStats(&raw) + if deploymentName := deployments[stat.ContainerID]; deploymentName != "" { + stat.DeploymentName = deploymentName + } else if deploymentName := deployments[stat.Name]; deploymentName != "" { + stat.DeploymentName = deploymentName + } + stats = append(stats, *stat) } return stats, nil } +func listContainerDeploymentLabels() map[string]string { + labels := make(map[string]string) + cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Label \"com.docker.compose.project\"}}") + output, err := cmd.Output() + if err != nil { + log.Printf("warning: failed to list container deployment labels: %v", err) + return labels + } + for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") { + if line == "" { + continue + } + parts := strings.SplitN(line, "|", 3) + if len(parts) != 3 || parts[2] == "" { + continue + } + labels[parts[0]] = parts[2] + labels[parts[1]] = parts[2] + } + return labels +} + func GetDeploymentStats(projectName string) ([]ContainerStats, error) { if projectName == "" { return []ContainerStats{}, nil diff --git a/internal/networks/manager.go b/internal/networks/manager.go index 4e8608f..5322d22 100644 --- a/internal/networks/manager.go +++ b/internal/networks/manager.go @@ -192,13 +192,14 @@ func (m *Manager) EnsureContainerOnNetwork(networkName, containerName string) er } type ContainerInfo struct { - ID string `json:"id"` - Name string `json:"name"` - Image string `json:"image"` - State string `json:"state"` - Status string `json:"status"` - Ports []string `json:"ports"` - Created string `json:"created"` + ID string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + State string `json:"state"` + Status string `json:"status"` + Ports []string `json:"ports"` + Created string `json:"created"` + DeploymentName string `json:"deployment_name,omitempty"` } type ImageInfo struct { @@ -292,7 +293,7 @@ func (m *Manager) GetVolumeStats() (map[string]int, error) { } func (m *Manager) ListContainers() ([]ContainerInfo, error) { - cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Image}}|{{.State}}|{{.Status}}|{{.Ports}}|{{.CreatedAt}}") + cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Image}}|{{.State}}|{{.Status}}|{{.Ports}}|{{.CreatedAt}}|{{.Label \"com.docker.compose.project\"}}") output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("failed to list containers: %w", err) @@ -305,7 +306,7 @@ func (m *Manager) ListContainers() ([]ContainerInfo, error) { if line == "" { continue } - parts := strings.SplitN(line, "|", 7) + parts := strings.SplitN(line, "|", 8) if len(parts) < 7 { continue } @@ -320,14 +321,20 @@ func (m *Manager) ListContainers() ([]ContainerInfo, error) { } } + deploymentName := "" + if len(parts) >= 8 { + deploymentName = parts[7] + } + containers = append(containers, ContainerInfo{ - ID: parts[0], - Name: parts[1], - Image: parts[2], - State: parts[3], - Status: parts[4], - Ports: ports, - Created: parts[6], + ID: parts[0], + Name: parts[1], + Image: parts[2], + State: parts[3], + Status: parts[4], + Ports: ports, + Created: parts[6], + DeploymentName: deploymentName, }) }