diff --git a/internal/api/authz.go b/internal/api/authz.go new file mode 100644 index 0000000..c4414f8 --- /dev/null +++ b/internal/api/authz.go @@ -0,0 +1,100 @@ +package api + +import ( + "net/http" + "os/exec" + "strings" + + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +const composeProjectLabel = "com.docker.compose.project" + +func (s *Server) requireDeploymentAccess(c *gin.Context, deploymentName, level string) bool { + if deploymentName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Deployment name required"}) + return false + } + + actor := auth.GetActorFromContext(c) + // Nil actor is allowed for direct handler tests; production routes set an actor via auth middleware + // or explicit anonymous-admin context when auth is disabled. + if actor == nil { + return true + } + if actor.Role == auth.RoleAdmin { + return true + } + + if !actor.CanAccessDeployment(deploymentName, level) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this deployment"}) + return false + } + + return true +} + +func (s *Server) requireContainerAccess(c *gin.Context, containerID, level string) bool { + if containerID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Container ID required"}) + return false + } + + actor := auth.GetActorFromContext(c) + // Nil actor is allowed for direct handler tests; production routes set an actor via auth middleware + // or explicit anonymous-admin context when auth is disabled. + if actor == nil { + return true + } + if actor.Role == auth.RoleAdmin { + // Admins can see missing-container errors; non-admins below get a non-enumerating 403. + if _, err := containerDeploymentName(containerID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Container not found"}) + return false + } + return true + } + + deploymentName, err := containerDeploymentName(containerID) + if err != nil || deploymentName == "" { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this container"}) + return false + } + + if !actor.CanAccessDeployment(deploymentName, level) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this container"}) + return false + } + + return true +} + +func (s *Server) actorCanAccessContainer(c *gin.Context, containerID, level string) bool { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.Role == auth.RoleAdmin { + return true + } + + deploymentName, err := containerDeploymentName(containerID) + if err != nil || deploymentName == "" { + return false + } + + return actor.CanAccessDeployment(deploymentName, level) +} + +func containerDeploymentName(containerID string) (string, error) { + cmd := exec.Command("docker", "inspect", "--format", "{{ index .Config.Labels \""+composeProjectLabel+"\" }}", containerID) + output, err := cmd.Output() + if err != nil { + return "", err + } + + deploymentName := strings.TrimSpace(string(output)) + if deploymentName == "" { + return "", nil + } + + return deploymentName, nil +} diff --git a/internal/api/authz_test.go b/internal/api/authz_test.go new file mode 100644 index 0000000..f4fe7e0 --- /dev/null +++ b/internal/api/authz_test.go @@ -0,0 +1,321 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/backup" + "github.com/flatrun/agent/internal/contextkeys" + "github.com/flatrun/agent/internal/docker" + "github.com/flatrun/agent/internal/nginx" + "github.com/flatrun/agent/internal/proxy" + "github.com/flatrun/agent/internal/scheduler" + "github.com/flatrun/agent/internal/security" + "github.com/flatrun/agent/internal/traffic" + "github.com/flatrun/agent/pkg/config" + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" +) + +func testActor(role auth.Role, deployments map[string]string) *auth.ActorContext { + return &auth.ActorContext{ + Type: "user", + Role: role, + Deployments: deployments, + } +} + +func actorMiddleware(actor *auth.ActorContext) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(contextkeys.Actor, actor) + c.Next() + } +} + +func TestListVirtualHostsFiltersByDeploymentAccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "vhosts-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + confDir := filepath.Join(tmpDir, "nginx", "conf.d") + if err := os.MkdirAll(confDir, 0755); err != nil { + t.Fatalf("failed to create conf dir: %v", err) + } + for _, name := range []string{"allowed-app.conf", "other-app.conf"} { + if err := os.WriteFile(filepath.Join(confDir, name), []byte("server {}"), 0644); err != nil { + t.Fatalf("failed to write vhost: %v", err) + } + } + + cfg := &config.Config{DeploymentsPath: tmpDir, Nginx: config.NginxConfig{ConfigPath: confDir}} + server := &Server{ + config: cfg, + proxyOrchestrator: proxy.NewOrchestratorWithManagers( + nginx.NewManager(&cfg.Nginx, tmpDir, ""), + nil, + ), + } + + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"allowed-app": auth.AccessLevelRead}))) + router.GET("/proxy/vhosts", server.listVirtualHosts) + + req := httptest.NewRequest(http.MethodGet, "/proxy/vhosts", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + VirtualHosts []struct { + Name string `json:"name"` + } `json:"virtual_hosts"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.VirtualHosts) != 1 || resp.VirtualHosts[0].Name != "allowed-app" { + t.Fatalf("expected only allowed-app vhost, got %#v", resp.VirtualHosts) + } +} + +func TestSyncAllProxiesRequiresAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelWrite}))) + router.POST("/proxy/sync", server.syncAllProxies) + + req := httptest.NewRequest(http.MethodPost, "/proxy/sync", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestSecurityEventsRequireDeploymentFilterForNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "security-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + securityManager, err := security.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create security manager: %v", err) + } + defer securityManager.Close() + + server := &Server{securityManager: securityManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.GET("/security/events", server.listSecurityEvents) + + req := httptest.NewRequest(http.MethodGet, "/security/events", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestTrafficLogsKeepDBTotalForScopedNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "traffic-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + trafficManager, err := traffic.NewManager(tmpDir, 7) + if err != nil { + t.Fatalf("failed to create traffic manager: %v", err) + } + defer trafficManager.Close() + + for i := 0; i < 2; i++ { + if _, err := trafficManager.IngestLog(&traffic.IngestTrafficLog{ + DeploymentName: "app", + RequestPath: "/", + RequestMethod: "GET", + StatusCode: 200, + SourceIP: "203.0.113.10", + }); err != nil { + t.Fatalf("failed to ingest traffic log: %v", err) + } + } + + server := &Server{trafficManager: trafficManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.GET("/traffic/logs", server.getTrafficLogs) + + req := httptest.NewRequest(http.MethodGet, "/traffic/logs?deployment=app&limit=1", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + Total int `json:"total"` + Logs []traffic.TrafficLog `json:"logs"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.Logs) != 1 { + t.Fatalf("expected one paged log, got %d", len(resp.Logs)) + } + if resp.Total != 2 { + t.Fatalf("expected DB total 2, got %d", resp.Total) + } +} + +func TestUpdateContainerResourcesValidatesBodyBeforeContainerLookup(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleAdmin, nil))) + router.PUT("/containers/:id/resources", server.updateContainerResources) + + req := httptest.NewRequest(http.MethodPut, "/containers/does-not-exist/resources", bytes.NewBufferString("{invalid")) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400 before container lookup, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestContainerLifecycleDeniesUnscopedContainerForNonAdmin(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := &Server{} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelWrite}))) + router.POST("/containers/:id/start", server.startContainer) + router.POST("/containers/:id/stop", server.stopContainer) + router.POST("/containers/:id/restart", server.restartContainer) + router.DELETE("/containers/:id", server.removeContainer) + + tests := []struct { + method string + path string + }{ + {http.MethodPost, "/containers/does-not-exist/start"}, + {http.MethodPost, "/containers/does-not-exist/stop"}, + {http.MethodPost, "/containers/does-not-exist/restart"}, + {http.MethodDelete, "/containers/does-not-exist"}, + } + + for _, tt := range tests { + req := httptest.NewRequest(tt.method, tt.path, nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusForbidden { + t.Fatalf("%s %s: expected 403, got %d: %s", tt.method, tt.path, w.Code, w.Body.String()) + } + } +} + +func TestRestoreBackupRequiresWriteOnTargetDeployment(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "backup-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + for _, name := range []string{"source-app", "target-app"} { + createTestDeployment(t, tmpDir, name, &models.ServiceMetadata{Name: name}) + } + + backupManager, err := backup.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create backup manager: %v", err) + } + created, err := backupManager.CreateBackup(context.Background(), "source-app", nil) + if err != nil { + t.Fatalf("failed to create backup: %v", err) + } + + server := &Server{backupManager: backupManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{ + "source-app": auth.AccessLevelRead, + "target-app": auth.AccessLevelRead, + }))) + router.POST("/backups/:id/restore", server.restoreBackup) + + body := bytes.NewBufferString(`{"deployment_name":"target-app"}`) + req := httptest.NewRequest(http.MethodPost, "/backups/"+created.ID+"/restore", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestCreateScheduledTaskRequiresWriteDeploymentAccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "scheduler-authz-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + createTestDeployment(t, tmpDir, "app", &models.ServiceMetadata{Name: "app"}) + + backupManager, err := backup.NewManager(tmpDir) + if err != nil { + t.Fatalf("failed to create backup manager: %v", err) + } + manager := docker.NewManager(tmpDir) + schedulerManager, err := scheduler.NewManager(tmpDir, scheduler.NewExecutor(backupManager, manager)) + if err != nil { + t.Fatalf("failed to create scheduler manager: %v", err) + } + defer schedulerManager.Stop() + + server := &Server{manager: manager, schedulerManager: schedulerManager} + router := gin.New() + router.Use(actorMiddleware(testActor(auth.RoleOperator, map[string]string{"app": auth.AccessLevelRead}))) + router.POST("/scheduler/tasks", server.createScheduledTask) + + body := bytes.NewBufferString(`{"name":"backup","type":"backup","deployment_name":"app","cron_expr":"0 * * * *","enabled":true}`) + req := httptest.NewRequest(http.MethodPost, "/scheduler/tasks", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/internal/api/backup_handlers.go b/internal/api/backup_handlers.go index c8d1c69..75b44db 100644 --- a/internal/api/backup_handlers.go +++ b/internal/api/backup_handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/backup" "github.com/flatrun/agent/pkg/models" "github.com/gin-gonic/gin" @@ -18,6 +19,11 @@ func (s *Server) listBackups(c *gin.Context) { filter := &backup.BackupListFilter{ DeploymentName: c.Query("deployment"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } if limit := c.Query("limit"); limit != "" { if l, err := strconv.Atoi(limit); err == nil { @@ -31,6 +37,17 @@ func (s *Server) listBackups(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := backups[:0] + for _, b := range backups { + if actor.CanAccessDeployment(b.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, b) + } + } + backups = filtered + } + c.JSON(http.StatusOK, gin.H{"backups": backups}) } @@ -46,6 +63,9 @@ func (s *Server) getBackup(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"backup": b}) } @@ -61,6 +81,9 @@ func (s *Server) createBackup(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, req.DeploymentName, auth.AccessLevelWrite) { + return + } deployment, err := s.manager.GetDeployment(req.DeploymentName) if err != nil { @@ -133,6 +156,15 @@ func (s *Server) deleteBackup(c *gin.Context) { } backupID := c.Param("id") + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelWrite) { + return + } + if err := s.backupManager.DeleteBackup(backupID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -148,6 +180,15 @@ func (s *Server) downloadBackup(c *gin.Context) { } backupID := c.Param("id") + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } + backupPath, err := s.backupManager.GetBackupPath(backupID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) @@ -217,6 +258,22 @@ func (s *Server) restoreBackup(c *gin.Context) { } req.BackupID = backupID + b, err := s.backupManager.GetBackup(backupID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + targetDeployment := b.DeploymentName + if req.DeploymentName != "" { + targetDeployment = req.DeploymentName + } + if !s.requireDeploymentAccess(c, b.DeploymentName, auth.AccessLevelRead) { + return + } + if !s.requireDeploymentAccess(c, targetDeployment, auth.AccessLevelWrite) { + return + } + jobID := s.backupManager.StartRestoreJob(&req) c.JSON(http.StatusAccepted, gin.H{"job_id": jobID, "message": "Restore job started"}) } @@ -233,6 +290,9 @@ func (s *Server) getBackupJob(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"}) return } + if !s.requireDeploymentAccess(c, job.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"job": job}) } @@ -244,6 +304,11 @@ func (s *Server) listBackupJobs(c *gin.Context) { } deploymentName := c.Query("deployment") + if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } + } limit := 50 if l := c.Query("limit"); l != "" { if parsed, err := strconv.Atoi(l); err == nil { @@ -252,5 +317,16 @@ func (s *Server) listBackupJobs(c *gin.Context) { } jobs := s.backupManager.ListJobs(deploymentName, limit) + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := jobs[:0] + for _, job := range jobs { + if actor.CanAccessDeployment(job.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, job) + } + } + jobs = filtered + } + c.JSON(http.StatusOK, gin.H{"jobs": jobs}) } diff --git a/internal/api/container_exec.go b/internal/api/container_exec.go index 3dbd711..ace77fc 100644 --- a/internal/api/container_exec.go +++ b/internal/api/container_exec.go @@ -11,6 +11,8 @@ import ( "time" "github.com/creack/pty" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/contextkeys" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) @@ -48,7 +50,6 @@ func (s *Server) containerExec(c *gin.Context) { } defer conn.Close() - // First-message authentication if s.authMiddleware.IsAuthEnabled() { _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) @@ -64,13 +65,28 @@ func (s *Server) containerExec(c *gin.Context) { return } - if !s.authMiddleware.ValidateTokenString(auth.Token) { + actor, err := s.authMiddleware.ActorForTokenString(auth.Token, c.ClientIP()) + if err != nil { sendError(conn, "Invalid or expired token") return } + c.Set(contextkeys.Actor, actor) _ = conn.SetReadDeadline(time.Time{}) + } else { + c.Set(contextkeys.Actor, &auth.ActorContext{Type: "anonymous", Role: auth.RoleAdmin}) + } + actor := auth.GetActorFromContext(c) + if actor == nil || !actor.HasPermission(auth.PermContainersWrite) { + sendError(conn, "Permission denied: containers:write required") + return + } + if !s.actorCanAccessContainer(c, containerID, auth.AccessLevelWrite) { + sendError(conn, "No access to this container") + return + } + if s.authMiddleware.IsAuthEnabled() { if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"auth_success"}`)); err != nil { return } @@ -192,6 +208,9 @@ func sendError(conn *websocket.Conn, msg string) { func (s *Server) containerExecHTTP(c *gin.Context) { containerID := c.Param("id") + if !s.requireContainerAccess(c, containerID, auth.AccessLevelWrite) { + return + } var req struct { Command string `json:"command"` diff --git a/internal/api/container_exec_test.go b/internal/api/container_exec_test.go index e072fb3..c822dbd 100644 --- a/internal/api/container_exec_test.go +++ b/internal/api/container_exec_test.go @@ -3,11 +3,17 @@ package api import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "os" + "strings" "testing" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/pkg/config" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func TestDetectShell(t *testing.T) { @@ -69,6 +75,97 @@ func TestWebSocketUpgrader(t *testing.T) { } } +func TestContainerExecWebSocketDeniesReadOnlyDeploymentAccessBeforeAuthSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "container-exec-ws-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-for-ws-exec", + } + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + defer os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + + authManager, err := auth.NewManager(tmpDir, cfg, true) + if err != nil { + t.Fatalf("failed to create auth manager: %v", err) + } + defer authManager.Close() + + user, err := authManager.CreateUser("readonly-operator", "", "password", auth.RoleOperator, nil) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + admin, err := authManager.GetUserByUsername("admin") + if err != nil { + t.Fatalf("failed to get admin: %v", err) + } + if err := authManager.AssignDeployment(user.ID, "app", auth.AccessLevelRead, admin.ID); err != nil { + t.Fatalf("failed to assign deployment: %v", err) + } + + authMiddleware := auth.NewMiddlewareWithManager(cfg, authManager) + token, err := authMiddleware.GenerateJWTForUser(user, "") + if err != nil { + t.Fatalf("failed to generate token: %v", err) + } + + server := &Server{authMiddleware: authMiddleware} + router := gin.New() + router.GET("/containers/:id/exec", server.containerExec) + + httpServer := newSkippableHTTPServer(t, router) + defer httpServer.Close() + + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/containers/does-not-exist/exec" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial websocket: %v", err) + } + defer conn.Close() + + if err := conn.WriteJSON(authMessage{Type: "auth", Token: token}); err != nil { + t.Fatalf("failed to write auth message: %v", err) + } + + _, message, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read websocket response: %v", err) + } + + text := string(message) + if strings.Contains(text, "auth_success") { + t.Fatalf("received auth_success before authorization denial: %q", text) + } + if !strings.Contains(text, "No access to this container") { + t.Fatalf("expected no-access denial, got %q", text) + } +} + +func newSkippableHTTPServer(t *testing.T, handler http.Handler) *httptest.Server { + t.Helper() + + var srv *httptest.Server + func() { + defer func() { + if r := recover(); r != nil { + msg := fmt.Sprint(r) + if strings.Contains(strings.ToLower(msg), "operation not permitted") { + t.Skipf("httptest listener not permitted in this environment: %v", r) + } + panic(r) + } + }() + srv = httptest.NewServer(handler) + }() + return srv +} + func TestAuthMessageParsing(t *testing.T) { tests := []struct { name string diff --git a/internal/api/resource_handlers.go b/internal/api/resource_handlers.go index c40aee4..70d6f63 100644 --- a/internal/api/resource_handlers.go +++ b/internal/api/resource_handlers.go @@ -3,12 +3,16 @@ package api import ( "net/http" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/docker" "github.com/gin-gonic/gin" ) func (s *Server) getContainerResources(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } resources, err := docker.GetContainerResources(id) if err != nil { @@ -42,6 +46,10 @@ func (s *Server) updateContainerResources(c *gin.Context) { return } + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } + if err := docker.UpdateContainerResources(id, &update); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), diff --git a/internal/api/scheduler_handlers.go b/internal/api/scheduler_handlers.go index 59edda7..ba5b191 100644 --- a/internal/api/scheduler_handlers.go +++ b/internal/api/scheduler_handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/scheduler" "github.com/gin-gonic/gin" ) @@ -20,6 +21,9 @@ func (s *Server) listScheduledTasks(c *gin.Context) { var err error if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } tasks, err = s.schedulerManager.GetTasksByDeployment(deploymentName) } else { tasks, err = s.schedulerManager.GetAllTasks() @@ -30,6 +34,17 @@ func (s *Server) listScheduledTasks(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := tasks[:0] + for _, task := range tasks { + if actor.CanAccessDeployment(task.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, task) + } + } + tasks = filtered + } + c.JSON(http.StatusOK, gin.H{"tasks": tasks}) } @@ -50,6 +65,9 @@ func (s *Server) getScheduledTask(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) return } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"task": task}) } @@ -65,6 +83,9 @@ func (s *Server) createScheduledTask(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if !s.requireDeploymentAccess(c, req.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.ValidateCronExpr(req.CronExpr); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid cron expression: " + err.Error()}) @@ -107,6 +128,15 @@ func (s *Server) updateScheduledTask(c *gin.Context) { return } + existingTask, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, existingTask.DeploymentName, auth.AccessLevelWrite) { + return + } + var req scheduler.UpdateTaskRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -121,11 +151,6 @@ func (s *Server) updateScheduledTask(c *gin.Context) { } if req.Config != nil && req.Config.CommandConfig != nil { - existingTask, err := s.schedulerManager.GetTask(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) - return - } resolved, err := s.resolveService(existingTask.DeploymentName, req.Config.CommandConfig.Service) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -154,6 +179,14 @@ func (s *Server) deleteScheduledTask(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.DeleteTask(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -174,6 +207,14 @@ func (s *Server) runTaskNow(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelWrite) { + return + } if err := s.schedulerManager.RunTaskNow(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -194,6 +235,14 @@ func (s *Server) getTaskExecutions(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"}) return } + task, err := s.schedulerManager.GetTask(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Task not found"}) + return + } + if !s.requireDeploymentAccess(c, task.DeploymentName, auth.AccessLevelRead) { + return + } limit := 50 if l := c.Query("limit"); l != "" { @@ -230,5 +279,26 @@ func (s *Server) getRecentExecutions(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := executions[:0] + taskDeployments := make(map[int64]string) + for _, execution := range executions { + deploymentName, ok := taskDeployments[execution.TaskID] + if !ok { + task, err := s.schedulerManager.GetTask(execution.TaskID) + if err != nil { + continue + } + deploymentName = task.DeploymentName + taskDeployments[execution.TaskID] = deploymentName + } + if actor.CanAccessDeployment(deploymentName, auth.AccessLevelRead) { + filtered = append(filtered, execution) + } + } + executions = filtered + } + c.JSON(http.StatusOK, gin.H{"executions": executions}) } diff --git a/internal/api/security_handlers.go b/internal/api/security_handlers.go index 83fee8a..bfd670f 100644 --- a/internal/api/security_handlers.go +++ b/internal/api/security_handlers.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/security" "github.com/flatrun/agent/pkg/config" "github.com/flatrun/agent/pkg/models" @@ -86,6 +87,17 @@ func (s *Server) listSecurityEvents(c *gin.Context) { SourceIP: c.Query("source_ip"), DeploymentName: c.Query("deployment"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } if limit := c.Query("limit"); limit != "" { if l, err := strconv.Atoi(limit); err == nil { @@ -145,6 +157,9 @@ func (s *Server) getSecurityEvent(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "Event not found"}) return } + if event.DeploymentName != "" && !s.requireDeploymentAccess(c, event.DeploymentName, auth.AccessLevelRead) { + return + } c.JSON(http.StatusOK, gin.H{"event": event}) } diff --git a/internal/api/server.go b/internal/api/server.go index 78cc9d5..54f6001 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -348,9 +348,9 @@ func (s *Server) setupRoutes() { protected.POST("/deployments/:name/certificates/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.renewDeploymentCertificates) // Proxy endpoints - protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.getProxyStatus) - protected.POST("/proxy/setup/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.setupProxy) - protected.DELETE("/proxy/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.teardownProxy) + protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getProxyStatus) + protected.POST("/proxy/setup/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.setupProxy) + protected.DELETE("/proxy/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.teardownProxy) protected.GET("/proxy/vhosts", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.listVirtualHosts) protected.POST("/proxy/sync", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.syncAllProxies) protected.POST("/deployments/:name/ssl/disable", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.disableSSL) @@ -576,7 +576,7 @@ func (s *Server) setupRoutes() { } // Get users with access to a deployment - protected.GET("/deployments/:name/users", s.authMiddleware.RequirePermission(auth.PermUsersRead), s.getDeploymentUsers) + protected.GET("/deployments/:name/users", s.authMiddleware.RequirePermission(auth.PermUsersRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelAdmin), s.getDeploymentUsers) } // DNS plugin routes @@ -1918,6 +1918,10 @@ func (s *Server) connectContainer(c *gin.Context) { return } + if !s.requireContainerAccess(c, req.Container, auth.AccessLevelWrite) { + return + } + if err := s.networksManager.ConnectContainer(networkName, req.Container); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -1946,6 +1950,10 @@ func (s *Server) disconnectContainer(c *gin.Context) { return } + if !s.requireContainerAccess(c, req.Container, auth.AccessLevelWrite) { + return + } + if err := s.networksManager.DisconnectContainer(networkName, req.Container); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -3800,6 +3808,18 @@ func (s *Server) listVirtualHosts(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := vhosts[:0] + for _, vhost := range vhosts { + // VirtualHostInfo.Name is the deployment name derived from .conf. + if actor.CanAccessDeployment(vhost.Name, auth.AccessLevelRead) { + filtered = append(filtered, vhost) + } + } + vhosts = filtered + } + c.JSON(http.StatusOK, gin.H{ "virtual_hosts": vhosts, }) @@ -3814,6 +3834,12 @@ type ProxySyncResult struct { } func (s *Server) syncAllProxies(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Admin access required"}) + return + } + deployments, err := s.manager.ListDeployments() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4325,6 +4351,18 @@ func (s *Server) listContainers(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := containers[:0] + for _, container := range containers { + // Non-admins only see FlatRun/Compose containers assigned through deployment access. + if actor.CanAccessDeployment(container.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, container) + } + } + containers = filtered + } + c.JSON(http.StatusOK, gin.H{ "containers": containers, }) @@ -4332,6 +4370,9 @@ func (s *Server) listContainers(c *gin.Context) { func (s *Server) startContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.StartContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4348,6 +4389,9 @@ func (s *Server) startContainer(c *gin.Context) { func (s *Server) stopContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.StopContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4364,6 +4408,9 @@ func (s *Server) stopContainer(c *gin.Context) { func (s *Server) restartContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelWrite) { + return + } if err := s.networksManager.RestartContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4380,6 +4427,9 @@ func (s *Server) restartContainer(c *gin.Context) { func (s *Server) removeContainer(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelAdmin) { + return + } if err := s.networksManager.RemoveContainer(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4396,6 +4446,9 @@ func (s *Server) removeContainer(c *gin.Context) { func (s *Server) getContainerLogs(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } tailStr := c.DefaultQuery("tail", "100") tail, err := strconv.Atoi(tailStr) @@ -4419,6 +4472,9 @@ func (s *Server) getContainerLogs(c *gin.Context) { func (s *Server) getContainerStats(c *gin.Context) { id := c.Param("id") + if !s.requireContainerAccess(c, id, auth.AccessLevelRead) { + return + } stats, err := docker.GetContainerStats(id) if err != nil { @@ -4442,6 +4498,18 @@ func (s *Server) getAllContainerStats(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + filtered := stats[:0] + for _, stat := range stats { + // Non-admins only see FlatRun/Compose containers assigned through deployment access. + if actor.CanAccessDeployment(stat.DeploymentName, auth.AccessLevelRead) { + filtered = append(filtered, stat) + } + } + stats = filtered + } + c.JSON(http.StatusOK, gin.H{ "stats": stats, }) diff --git a/internal/api/traffic_handlers.go b/internal/api/traffic_handlers.go index 0e7fdc2..5622572 100644 --- a/internal/api/traffic_handlers.go +++ b/internal/api/traffic_handlers.go @@ -5,6 +5,7 @@ import ( "strconv" "time" + "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/traffic" "github.com/gin-gonic/gin" ) @@ -45,6 +46,17 @@ func (s *Server) getTrafficLogs(c *gin.Context) { SourceIP: c.Query("source_ip"), RequestPath: c.Query("path"), } + if filter.DeploymentName != "" { + if !s.requireDeploymentAccess(c, filter.DeploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } if statusCode := c.Query("status_code"); statusCode != "" { if code, err := strconv.Atoi(statusCode); err == nil { @@ -100,6 +112,17 @@ func (s *Server) getTrafficStats(c *gin.Context) { } deploymentName := c.Query("deployment") + if deploymentName != "" { + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelRead) { + return + } + } else { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Deployment filter required"}) + return + } + } since := 24 * time.Hour if sinceStr := c.Query("since"); sinceStr != "" { diff --git a/internal/api/user_deployments.go b/internal/api/user_deployments.go index e35db5c..7505fdd 100644 --- a/internal/api/user_deployments.go +++ b/internal/api/user_deployments.go @@ -135,6 +135,10 @@ func (s *Server) removeUserDeployment(c *gin.Context) { return } + if !s.requireDeploymentAccess(c, deploymentName, auth.AccessLevelAdmin) { + return + } + if err := s.authManager.RemoveDeploymentAccess(userID, deploymentName); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to remove deployment access"}) return diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 3e5b59b..41e84cf 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -489,10 +489,62 @@ func (m *Middleware) GetAuthStatus(c *gin.Context) { } func (m *Middleware) ValidateTokenString(token string) bool { + _, err := m.ActorForTokenString(token, "") + return err == nil +} + +func (m *Middleware) ActorForTokenString(token string, clientIP string) (*ActorContext, error) { if !m.config.Enabled { - return true + return &ActorContext{ + Type: "anonymous", + Role: RoleAdmin, + }, nil + } + + if claims := m.validateJWTWithClaims(token); claims != nil { + if m.manager != nil && claims.SessionID != "" { + session, err := m.manager.GetSessionByID(claims.SessionID) + if err != nil || !session.RevokedAt.IsZero() { + return nil, fmt.Errorf("session revoked or invalid") + } + } + + if m.manager != nil && claims.UserID > 0 { + user, err := m.manager.GetUser(claims.UserID) + if err != nil { + return nil, err + } + if !user.IsActive { + return nil, ErrUserInactive + } + return m.manager.BuildActorContext(user, nil) + } + + return &ActorContext{ + Type: "jwt", + Role: RoleAdmin, + }, nil + } + + if m.manager != nil { + apiKey, user, err := m.manager.ValidateAPIKey(token) + if err == nil { + if clientIP != "" { + _ = m.manager.UpdateAPIKeyLastUsed(apiKey.ID, clientIP) + } + return m.manager.BuildActorContext(user, apiKey) + } } - return m.validateJWT(token) || m.validateAPIKey(token) + + if m.validateAPIKey(token) { + log.Printf("Warning: Legacy API key used. Consider migrating to user-based API keys.") + return &ActorContext{ + Type: "legacy_key", + Role: RoleAdmin, + }, nil + } + + return nil, fmt.Errorf("invalid or expired token") } func (m *Middleware) IsAuthEnabled() bool { diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 0000000..05910a5 --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,53 @@ +package auth + +import ( + "testing" + "time" +) + +func TestActorForTokenStringUpdatesAPIKeyLastUsed(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + manager.config.Enabled = true + + user, _ := manager.CreateUser("wskeyuser", "", "pass", RoleOperator, nil) + key, plainKey, err := manager.CreateAPIKey(user.ID, "WS Key", "", "", nil, nil, time.Time{}) + if err != nil { + t.Fatalf("CreateAPIKey failed: %v", err) + } + + mw := NewMiddlewareWithManager(manager.config, manager) + actor, err := mw.ActorForTokenString(plainKey, "203.0.113.10") + if err != nil { + t.Fatalf("ActorForTokenString failed: %v", err) + } + if actor.APIKey == nil || actor.APIKey.ID != key.ID { + t.Fatalf("expected API key actor, got %#v", actor) + } + + updated, err := manager.GetAPIKey(key.ID) + if err != nil { + t.Fatalf("GetAPIKey failed: %v", err) + } + if updated.LastUsedAt.IsZero() { + t.Fatal("expected LastUsedAt to be set") + } + if updated.LastUsedIP != "203.0.113.10" { + t.Fatalf("expected LastUsedIP to be 203.0.113.10, got %q", updated.LastUsedIP) + } +} + +func TestActorForTokenStringLegacyKeyReturnsAdminActor(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + manager.config.Enabled = true + + mw := NewMiddlewareWithManager(manager.config, manager) + actor, err := mw.ActorForTokenString("legacy-key-1", "203.0.113.10") + if err != nil { + t.Fatalf("ActorForTokenString failed: %v", err) + } + if actor.Role != RoleAdmin || actor.Type != "legacy_key" { + t.Fatalf("expected legacy admin actor, got %#v", actor) + } +} diff --git a/internal/docker/stats.go b/internal/docker/stats.go index db353b9..6910f94 100644 --- a/internal/docker/stats.go +++ b/internal/docker/stats.go @@ -2,23 +2,25 @@ package docker import ( "encoding/json" + "log" "os/exec" "strconv" "strings" ) type ContainerStats struct { - ContainerID string `json:"container_id"` - Name string `json:"name"` - CPUPercent float64 `json:"cpu_percent"` - MemoryUsage uint64 `json:"memory_usage"` - MemoryLimit uint64 `json:"memory_limit"` - MemoryPercent float64 `json:"memory_percent"` - NetworkRx uint64 `json:"network_rx"` - NetworkTx uint64 `json:"network_tx"` - BlockRead uint64 `json:"block_read"` - BlockWrite uint64 `json:"block_write"` - PIDs int `json:"pids"` + ContainerID string `json:"container_id"` + Name string `json:"name"` + DeploymentName string `json:"deployment_name,omitempty"` + CPUPercent float64 `json:"cpu_percent"` + MemoryUsage uint64 `json:"memory_usage"` + MemoryLimit uint64 `json:"memory_limit"` + MemoryPercent float64 `json:"memory_percent"` + NetworkRx uint64 `json:"network_rx"` + NetworkTx uint64 `json:"network_tx"` + BlockRead uint64 `json:"block_read"` + BlockWrite uint64 `json:"block_write"` + PIDs int `json:"pids"` } type dockerStatsJSON struct { @@ -48,6 +50,7 @@ func GetContainerStats(containerID string) (*ContainerStats, error) { } func GetAllContainerStats() ([]ContainerStats, error) { + deployments := listContainerDeploymentLabels() cmd := exec.Command("docker", "stats", "--no-stream", "--format", "{{json .}}") output, err := cmd.Output() if err != nil { @@ -64,12 +67,40 @@ func GetAllContainerStats() ([]ContainerStats, error) { if err := json.Unmarshal([]byte(line), &raw); err != nil { continue } - stats = append(stats, *parseStats(&raw)) + stat := parseStats(&raw) + if deploymentName := deployments[stat.ContainerID]; deploymentName != "" { + stat.DeploymentName = deploymentName + } else if deploymentName := deployments[stat.Name]; deploymentName != "" { + stat.DeploymentName = deploymentName + } + stats = append(stats, *stat) } return stats, nil } +func listContainerDeploymentLabels() map[string]string { + labels := make(map[string]string) + cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Label \"com.docker.compose.project\"}}") + output, err := cmd.Output() + if err != nil { + log.Printf("warning: failed to list container deployment labels: %v", err) + return labels + } + for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") { + if line == "" { + continue + } + parts := strings.SplitN(line, "|", 3) + if len(parts) != 3 || parts[2] == "" { + continue + } + labels[parts[0]] = parts[2] + labels[parts[1]] = parts[2] + } + return labels +} + func GetDeploymentStats(projectName string) ([]ContainerStats, error) { if projectName == "" { return []ContainerStats{}, nil diff --git a/internal/networks/manager.go b/internal/networks/manager.go index 4e8608f..5322d22 100644 --- a/internal/networks/manager.go +++ b/internal/networks/manager.go @@ -192,13 +192,14 @@ func (m *Manager) EnsureContainerOnNetwork(networkName, containerName string) er } type ContainerInfo struct { - ID string `json:"id"` - Name string `json:"name"` - Image string `json:"image"` - State string `json:"state"` - Status string `json:"status"` - Ports []string `json:"ports"` - Created string `json:"created"` + ID string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + State string `json:"state"` + Status string `json:"status"` + Ports []string `json:"ports"` + Created string `json:"created"` + DeploymentName string `json:"deployment_name,omitempty"` } type ImageInfo struct { @@ -292,7 +293,7 @@ func (m *Manager) GetVolumeStats() (map[string]int, error) { } func (m *Manager) ListContainers() ([]ContainerInfo, error) { - cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Image}}|{{.State}}|{{.Status}}|{{.Ports}}|{{.CreatedAt}}") + cmd := exec.Command("docker", "ps", "-a", "--format", "{{.ID}}|{{.Names}}|{{.Image}}|{{.State}}|{{.Status}}|{{.Ports}}|{{.CreatedAt}}|{{.Label \"com.docker.compose.project\"}}") output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("failed to list containers: %w", err) @@ -305,7 +306,7 @@ func (m *Manager) ListContainers() ([]ContainerInfo, error) { if line == "" { continue } - parts := strings.SplitN(line, "|", 7) + parts := strings.SplitN(line, "|", 8) if len(parts) < 7 { continue } @@ -320,14 +321,20 @@ func (m *Manager) ListContainers() ([]ContainerInfo, error) { } } + deploymentName := "" + if len(parts) >= 8 { + deploymentName = parts[7] + } + containers = append(containers, ContainerInfo{ - ID: parts[0], - Name: parts[1], - Image: parts[2], - State: parts[3], - Status: parts[4], - Ports: ports, - Created: parts[6], + ID: parts[0], + Name: parts[1], + Image: parts[2], + State: parts[3], + Status: parts[4], + Ports: ports, + Created: parts[6], + DeploymentName: deploymentName, }) }