diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ff925e50..c1793381 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -47,7 +47,7 @@ jobs:
     services:
       # PostgreSQL for comparison tests
       postgres:
-        image: postgres:16-alpine
+        image: public.ecr.aws/docker/library/postgres:16-alpine
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
@@ -62,7 +62,7 @@ jobs:
 
       # DuckLake metadata store
       ducklake-metadata:
-        image: postgres:16-alpine
+        image: public.ecr.aws/docker/library/postgres:16-alpine
         env:
           POSTGRES_USER: ducklake
           POSTGRES_PASSWORD: ducklake
@@ -138,7 +138,7 @@ jobs:
 
     services:
       postgres:
-        image: postgres:16-alpine
+        image: public.ecr.aws/docker/library/postgres:16-alpine
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
@@ -171,7 +171,7 @@ jobs:
 
     services:
       postgres:
-        image: postgres:16-alpine
+        image: public.ecr.aws/docker/library/postgres:16-alpine
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
@@ -204,10 +204,11 @@ jobs:
     timeout-minutes: 30
     env:
       DUCKGRES_KIND_CLUSTER_NAME: duckgres
+      DUCKGRES_KIND_NODE_IMAGE: kindest/node:v1.31.0@sha256:53df588e04085fd41ae12de0c3fe4c72f7013bba32a20e7325357a1ac94ba865
 
     services:
       postgres:
-        image: postgres:16-alpine
+        image: public.ecr.aws/docker/library/postgres:16-alpine
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
@@ -237,5 +238,19 @@ jobs:
           chmod +x /tmp/kind
           sudo mv /tmp/kind /usr/local/bin/kind
           kind --version
+      - name: Clear Docker Hub credentials for kind pulls
+        run: |
+          docker logout registry-1.docker.io || true
+          docker logout docker.io || true
+          docker logout https://index.docker.io/v1/ || true
+      - name: Pre-pull kind node image
+        run: |
+          for attempt in 1 2 3; do
+            if docker pull "${DUCKGRES_KIND_NODE_IMAGE}"; then
+              exit 0
+            fi
+            sleep $((attempt * 5))
+          done
+          exit 1
       - name: Run Kubernetes integration tests
         run: just test-k8s-integration
diff --git a/README.md b/README.md
index b3fdce1c..2a7a368e 100644
--- a/README.md
+++ b/README.md
@@ -199,6 +199,7 @@ Run with config file:
 | `DUCKGRES_THREADS` | DuckDB threads per session | `runtime.NumCPU()` |
 | `DUCKGRES_PROCESS_ISOLATION` | Enable process isolation (`1` or `true`) | `false` |
 | `DUCKGRES_IDLE_TIMEOUT` | Connection idle timeout (e.g., `30m`, `1h`, `-1` to disable) | `24h` |
+| `DUCKGRES_HANDOVER_DRAIN_TIMEOUT` | Max time to drain connections during planned shutdowns and upgrades before forcing exit | `24h` in process mode, `15m` in remote K8s mode |
 | `DUCKGRES_K8S_SHARED_WARM_TARGET` | Neutral shared warm-worker target for K8s multi-tenant mode (`0` disables prewarm) | `0` |
 | `DUCKGRES_DUCKLAKE_METADATA_STORE` | DuckLake metadata connection string | - |
 | `POSTHOG_API_KEY` | PostHog project API key (`phc_...`); enables log export | - |
@@ -598,7 +599,7 @@ kill -USR2
 ```
 
 ### Remote Worker Backend
 
-In Kubernetes environments, `--worker-backend remote` is now the multitenant path only. It requires `--config-store`, and the control plane then spawns worker pods via the Kubernetes API, communicates with them over gRPC (Arrow Flight SQL), and uses owner references for automatic garbage collection when the control plane pod is deleted.
+In Kubernetes environments, `--worker-backend remote` is the multitenant path and requires `--config-store`. Control-plane replicas coordinate through durable runtime rows in the config-store Postgres DB, spawn worker pods via the Kubernetes API, and communicate with them over gRPC (Arrow Flight SQL).
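The durable coordination rows are ordinary Postgres tables managed by `controlplane/configstore`. As a rough sketch of the replica side (the helper below is illustrative rather than the actual `ControlPlaneRuntimeTracker`, and the five-second interval is an assumption), each replica upserts its instance row and keeps the heartbeat fresh:

```go
package main

import (
	"context"
	"log/slog"
	"time"

	"github.com/posthog/duckgres/controlplane/configstore"
)

// runHeartbeat keeps one control-plane replica's runtime row fresh so the
// janitor does not expire it. Illustrative only; the interval is assumed.
func runHeartbeat(ctx context.Context, cs *configstore.ConfigStore, inst configstore.ControlPlaneInstance) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		inst.State = configstore.ControlPlaneInstanceStateActive
		inst.LastHeartbeatAt = time.Now()
		if err := cs.UpsertControlPlaneInstance(&inst); err != nil {
			slog.Warn("Heartbeat upsert failed.", "error", err)
		}
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}
```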
Planned rolling deploys mark old replicas draining, fail readiness, and wait up to `handover_drain_timeout` before forcing shutdown. Unplanned control-plane failure still drops live pgwire connections; Flight may reconnect with a durable session token if the worker survives and the token is still valid. When a shared warm-worker target is configured (`--k8s-shared-warm-target`), the pool keeps workers neutral at startup, reserves them per org, activates tenant runtime over the activation RPC, and retires them after use. The full lifecycle is: idle → reserved → activating → hot → draining → retired. diff --git a/controlplane/configstore/models.go b/controlplane/configstore/models.go index 439190da..d704877a 100644 --- a/controlplane/configstore/models.go +++ b/controlplane/configstore/models.go @@ -186,6 +186,92 @@ type QueryLogConfig struct { func (QueryLogConfig) TableName() string { return "duckgres_query_log_config" } +// ControlPlaneInstanceState describes the liveness state of a control-plane instance. +type ControlPlaneInstanceState string + +const ( + ControlPlaneInstanceStateActive ControlPlaneInstanceState = "active" + ControlPlaneInstanceStateDraining ControlPlaneInstanceState = "draining" + ControlPlaneInstanceStateExpired ControlPlaneInstanceState = "expired" +) + +// ControlPlaneInstance is a runtime coordination record for one control-plane process. +// These rows live in the runtime schema, not the snapshot-backed config tables. +type ControlPlaneInstance struct { + ID string `gorm:"primaryKey;size:255" json:"id"` + PodName string `gorm:"size:255;not null" json:"pod_name"` + PodUID string `gorm:"size:255;not null" json:"pod_uid"` + BootID string `gorm:"size:255;not null" json:"boot_id"` + State ControlPlaneInstanceState `gorm:"size:32;not null" json:"state"` + StartedAt time.Time `json:"started_at"` + LastHeartbeatAt time.Time `gorm:"index" json:"last_heartbeat_at"` + DrainingAt *time.Time `json:"draining_at,omitempty"` + ExpiredAt *time.Time `json:"expired_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (ControlPlaneInstance) TableName() string { return "cp_instances" } + +// WorkerState is the durable lifecycle state for a worker pod. +type WorkerState string + +const ( + WorkerStateSpawning WorkerState = "spawning" + WorkerStateIdle WorkerState = "idle" + WorkerStateReserved WorkerState = "reserved" + WorkerStateActivating WorkerState = "activating" + WorkerStateHot WorkerState = "hot" + WorkerStateDraining WorkerState = "draining" + WorkerStateRetired WorkerState = "retired" + WorkerStateLost WorkerState = "lost" +) + +// WorkerRecord is the durable runtime coordination record for one worker pod. 
+type WorkerRecord struct { + WorkerID int `gorm:"primaryKey" json:"worker_id"` + PodName string `gorm:"size:255;not null;uniqueIndex" json:"pod_name"` + PodUID string `gorm:"size:255" json:"pod_uid"` + State WorkerState `gorm:"size:32;not null;index" json:"state"` + OrgID string `gorm:"size:255;index" json:"org_id"` + OwnerCPInstanceID string `gorm:"size:255;index" json:"owner_cp_instance_id"` + OwnerEpoch int64 `gorm:"not null" json:"owner_epoch"` + ActivationStartedAt *time.Time `json:"activation_started_at,omitempty"` + LastHeartbeatAt time.Time `json:"last_heartbeat_at"` + RetireReason string `gorm:"size:64" json:"retire_reason"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (WorkerRecord) TableName() string { return "worker_records" } + +// FlightSessionState is the durable reconnect state for Flight-only sessions. +type FlightSessionState string + +const ( + FlightSessionStateActive FlightSessionState = "active" + FlightSessionStateReconnecting FlightSessionState = "reconnecting" + FlightSessionStateExpired FlightSessionState = "expired" + FlightSessionStateClosed FlightSessionState = "closed" +) + +// FlightSessionRecord is the durable reconnect record for Flight sessions. +type FlightSessionRecord struct { + SessionToken string `gorm:"primaryKey;size:255" json:"session_token"` + Username string `gorm:"size:255;not null" json:"username"` + OrgID string `gorm:"size:255;not null" json:"org_id"` + WorkerID int `gorm:"not null;index" json:"worker_id"` + OwnerEpoch int64 `gorm:"not null" json:"owner_epoch"` + CPInstanceID string `gorm:"size:255" json:"cp_instance_id"` + State FlightSessionState `gorm:"size:32;not null" json:"state"` + ExpiresAt time.Time `gorm:"index" json:"expires_at"` + LastSeenAt time.Time `json:"last_seen_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (FlightSessionRecord) TableName() string { return "flight_session_records" } + // OrgConfig is a convenience view combining org metadata with resource limits. type OrgConfig struct { Name string diff --git a/controlplane/configstore/store.go b/controlplane/configstore/store.go index 88353e6b..66e3949e 100644 --- a/controlplane/configstore/store.go +++ b/controlplane/configstore/store.go @@ -2,17 +2,23 @@ package configstore import ( "context" + "errors" "fmt" + "hash/fnv" "log/slog" + "strings" "sync" "time" "golang.org/x/crypto/bcrypt" "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/logger" ) +var ErrWorkerOwnerEpochMismatch = errors.New("worker owner epoch mismatch") + // Snapshot holds a point-in-time copy of all config data for fast lookups. type Snapshot struct { Orgs map[string]*OrgConfig @@ -26,11 +32,12 @@ type Snapshot struct { // ConfigStore manages configuration stored in a PostgreSQL database. 
type ConfigStore struct { - db *gorm.DB - mu sync.RWMutex - snapshot *Snapshot - pollInterval time.Duration - onChange []func(old, new *Snapshot) + db *gorm.DB + runtimeSchema string + mu sync.RWMutex + snapshot *Snapshot + pollInterval time.Duration + onChange []func(old, new *Snapshot) } // NewConfigStore connects to the PostgreSQL config store, runs migrations, @@ -60,6 +67,17 @@ func NewConfigStore(connStr string, pollInterval time.Duration) (*ConfigStore, e return nil, fmt.Errorf("auto-migrate config store: %w", err) } + runtimeSchema, err := resolveRuntimeSchema(db) + if err != nil { + return nil, fmt.Errorf("resolve runtime schema: %w", err) + } + if err := ensureRuntimeSchema(db, runtimeSchema); err != nil { + return nil, fmt.Errorf("ensure runtime schema: %w", err) + } + if err := autoMigrateRuntimeTables(db, runtimeSchema); err != nil { + return nil, fmt.Errorf("auto-migrate runtime schema: %w", err) + } + // Ensure singleton rows exist with defaults db.FirstOrCreate(&GlobalConfig{}, GlobalConfig{ID: 1}) db.FirstOrCreate(&DuckLakeConfig{}, DuckLakeConfig{ID: 1}) @@ -67,8 +85,9 @@ func NewConfigStore(connStr string, pollInterval time.Duration) (*ConfigStore, e db.FirstOrCreate(&QueryLogConfig{}, QueryLogConfig{ID: 1}) cs := &ConfigStore{ - db: db, - pollInterval: pollInterval, + db: db, + runtimeSchema: runtimeSchema, + pollInterval: pollInterval, } // Load initial snapshot @@ -242,11 +261,523 @@ func (cs *ConfigStore) UpdateWarehouseState(orgID string, expectedState ManagedW return nil } +func resolveRuntimeSchema(db *gorm.DB) (string, error) { + var currentSchema string + if err := db.Raw("SELECT current_schema()").Scan(¤tSchema).Error; err != nil { + return "", err + } + if currentSchema == "" || currentSchema == "public" { + return "cp_runtime", nil + } + return currentSchema + "_runtime", nil +} + +func ensureRuntimeSchema(db *gorm.DB, runtimeSchema string) error { + return db.Exec(`CREATE SCHEMA IF NOT EXISTS "` + quoteIdentifier(runtimeSchema) + `"`).Error +} + +func autoMigrateRuntimeTables(db *gorm.DB, runtimeSchema string) error { + for _, spec := range []struct { + table string + model any + }{ + {table: runtimeSchema + ".cp_instances", model: &ControlPlaneInstance{}}, + {table: runtimeSchema + ".worker_records", model: &WorkerRecord{}}, + {table: runtimeSchema + ".flight_session_records", model: &FlightSessionRecord{}}, + } { + if err := db.Table(spec.table).AutoMigrate(spec.model); err != nil { + return err + } + } + return nil +} + +func quoteIdentifier(v string) string { + return strings.ReplaceAll(v, `"`, `""`) +} + // DB exposes the GORM database for direct CRUD operations (used by admin API). func (cs *ConfigStore) DB() *gorm.DB { return cs.db } +// RuntimeSchema returns the dedicated runtime coordination schema name. +func (cs *ConfigStore) RuntimeSchema() string { + return cs.runtimeSchema +} + +func (cs *ConfigStore) runtimeTable(base string) string { + return cs.runtimeSchema + "." + base +} + +// UpsertControlPlaneInstance inserts or updates a runtime control-plane instance row. 
+func (cs *ConfigStore) UpsertControlPlaneInstance(instance *ControlPlaneInstance) error { + if err := cs.db.Table(cs.runtimeTable(instance.TableName())).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "id"}}, + DoUpdates: clause.AssignmentColumns([]string{"pod_name", "pod_uid", "boot_id", "state", "started_at", "last_heartbeat_at", "draining_at", "expired_at", "updated_at"}), + }).Create(instance).Error; err != nil { + return fmt.Errorf("upsert control plane instance: %w", err) + } + return nil +} + +// GetControlPlaneInstance returns a runtime control-plane instance row by id. +func (cs *ConfigStore) GetControlPlaneInstance(id string) (*ControlPlaneInstance, error) { + var instance ControlPlaneInstance + if err := cs.db.Table(cs.runtimeTable(instance.TableName())).First(&instance, "id = ?", id).Error; err != nil { + return nil, fmt.Errorf("get control plane instance: %w", err) + } + return &instance, nil +} + +// ExpireControlPlaneInstances marks stale control-plane instance rows as expired. +func (cs *ConfigStore) ExpireControlPlaneInstances(cutoff time.Time) (int64, error) { + now := time.Now() + result := cs.db.Table(cs.runtimeTable((&ControlPlaneInstance{}).TableName())). + Where("state <> ? AND last_heartbeat_at < ?", ControlPlaneInstanceStateExpired, cutoff). + Updates(map[string]any{ + "state": ControlPlaneInstanceStateExpired, + "expired_at": now, + "updated_at": now, + }) + if result.Error != nil { + return 0, fmt.Errorf("expire control plane instances: %w", result.Error) + } + return result.RowsAffected, nil +} + +// ExpireDrainingControlPlaneInstances marks draining control-plane rows expired +// once their draining_at timestamp exceeds the configured handover timeout. +func (cs *ConfigStore) ExpireDrainingControlPlaneInstances(before time.Time) (int64, error) { + now := time.Now() + result := cs.db.Table(cs.runtimeTable((&ControlPlaneInstance{}).TableName())). + Where("state = ? AND draining_at IS NOT NULL AND draining_at <= ?", ControlPlaneInstanceStateDraining, before). + Updates(map[string]any{ + "state": ControlPlaneInstanceStateExpired, + "expired_at": now, + "updated_at": now, + }) + if result.Error != nil { + return 0, fmt.Errorf("expire draining control plane instances: %w", result.Error) + } + return result.RowsAffected, nil +} + +// UpsertWorkerRecord inserts or updates a runtime worker row. +func (cs *ConfigStore) UpsertWorkerRecord(record *WorkerRecord) error { + if err := cs.db.Table(cs.runtimeTable(record.TableName())).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "worker_id"}}, + DoUpdates: clause.AssignmentColumns([]string{"pod_name", "pod_uid", "state", "org_id", "owner_cp_instance_id", "owner_epoch", "activation_started_at", "last_heartbeat_at", "retire_reason", "updated_at"}), + }).Create(record).Error; err != nil { + return fmt.Errorf("upsert worker record: %w", err) + } + return nil +} + +// GetWorkerRecord returns a runtime worker row by worker id. +func (cs *ConfigStore) GetWorkerRecord(workerID int) (*WorkerRecord, error) { + var record WorkerRecord + if err := cs.db.Table(cs.runtimeTable(record.TableName())).First(&record, "worker_id = ?", workerID).Error; err != nil { + return nil, fmt.Errorf("get worker record: %w", err) + } + return &record, nil +} + +// ClaimIdleWorker atomically claims one idle worker row for a control-plane instance. +// The selected row is locked with SKIP LOCKED and transitioned to reserved while +// incrementing owner_epoch. 
When maxOrgWorkers is set, org claims are serialized +// under the same advisory lock used for spawn-slot allocation. +func (cs *ConfigStore) ClaimIdleWorker(ownerCPInstanceID, orgID string, maxOrgWorkers int) (*WorkerRecord, error) { + var claimed *WorkerRecord + err := cs.db.Transaction(func(tx *gorm.DB) error { + if orgID != "" { + if err := tx.Exec("SELECT pg_advisory_xact_lock(?)", advisoryLockKey("duckgres:org:"+orgID)).Error; err != nil { + return err + } + } + if maxOrgWorkers > 0 && orgID != "" { + count, err := cs.countActiveWorkers(tx, "org_id = ?", orgID) + if err != nil { + return err + } + if count >= int64(maxOrgWorkers) { + return nil + } + } + + var current WorkerRecord + err := tx.Table(cs.runtimeTable(current.TableName())). + Clauses(clause.Locking{Strength: "UPDATE", Options: "SKIP LOCKED"}). + Where("state = ?", WorkerStateIdle). + Order("worker_id ASC"). + Take(¤t).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil + } + return err + } + + now := time.Now() + if err := tx.Table(cs.runtimeTable(current.TableName())). + Where("worker_id = ?", current.WorkerID). + Updates(map[string]any{ + "state": WorkerStateReserved, + "org_id": orgID, + "owner_cp_instance_id": ownerCPInstanceID, + "owner_epoch": gorm.Expr("owner_epoch + 1"), + "updated_at": now, + }).Error; err != nil { + return err + } + + if err := tx.Table(cs.runtimeTable(current.TableName())). + First(¤t, "worker_id = ?", current.WorkerID).Error; err != nil { + return err + } + claimed = ¤t + return nil + }) + if err != nil { + return nil, fmt.Errorf("claim idle worker: %w", err) + } + return claimed, nil +} + +// TakeOverWorker transfers durable worker ownership to a new control-plane +// instance when the caller still has the expected prior owner_epoch. +func (cs *ConfigStore) TakeOverWorker(workerID int, ownerCPInstanceID, orgID string, expectedOwnerEpoch int64) (*WorkerRecord, error) { + var claimed *WorkerRecord + err := cs.db.Transaction(func(tx *gorm.DB) error { + var current WorkerRecord + err := tx.Table(cs.runtimeTable(current.TableName())). + Clauses(clause.Locking{Strength: "UPDATE"}). + Where("worker_id = ?", workerID). + Take(¤t).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil + } + return err + } + if current.OwnerEpoch != expectedOwnerEpoch { + return ErrWorkerOwnerEpochMismatch + } + now := time.Now() + if err := tx.Table(cs.runtimeTable(current.TableName())). + Where("worker_id = ?", current.WorkerID). + Updates(map[string]any{ + "state": WorkerStateReserved, + "org_id": orgID, + "owner_cp_instance_id": ownerCPInstanceID, + "owner_epoch": gorm.Expr("owner_epoch + 1"), + "updated_at": now, + }).Error; err != nil { + return err + } + if err := tx.Table(cs.runtimeTable(current.TableName())). + First(¤t, "worker_id = ?", current.WorkerID).Error; err != nil { + return err + } + claimed = ¤t + return nil + }) + if err != nil { + return nil, fmt.Errorf("take over worker: %w", err) + } + return claimed, nil +} + +// CreateSpawningWorkerSlot creates a durable spawning worker row under advisory-lock +// protected org/global capacity checks. A nil result means capacity blocked the spawn. 
+func (cs *ConfigStore) CreateSpawningWorkerSlot(ownerCPInstanceID, orgID string, ownerEpoch int64, podNamePrefix string, maxOrgWorkers, maxGlobalWorkers int) (*WorkerRecord, error) { + if strings.TrimSpace(podNamePrefix) == "" { + return nil, fmt.Errorf("pod name prefix is required") + } + + var created *WorkerRecord + err := cs.db.Transaction(func(tx *gorm.DB) error { + if orgID != "" { + if err := tx.Exec("SELECT pg_advisory_xact_lock(?)", advisoryLockKey("duckgres:org:"+orgID)).Error; err != nil { + return err + } + } + if err := tx.Exec("SELECT pg_advisory_xact_lock(?)", advisoryLockKey("duckgres:global-worker-capacity")).Error; err != nil { + return err + } + + if maxOrgWorkers > 0 && orgID != "" { + count, err := cs.countActiveWorkers(tx, "org_id = ?", orgID) + if err != nil { + return err + } + if count >= int64(maxOrgWorkers) { + return nil + } + } + + if maxGlobalWorkers > 0 { + count, err := cs.countActiveWorkers(tx) + if err != nil { + return err + } + if count >= int64(maxGlobalWorkers) { + return nil + } + } + + var workerID int64 + if err := tx.Raw("SELECT COALESCE(MAX(worker_id), 0) + 1 FROM " + cs.runtimeTable((&WorkerRecord{}).TableName())).Scan(&workerID).Error; err != nil { + return err + } + now := time.Now() + record := &WorkerRecord{ + WorkerID: int(workerID), + PodName: fmt.Sprintf("%s-%d", podNamePrefix, workerID), + State: WorkerStateSpawning, + OrgID: orgID, + OwnerCPInstanceID: ownerCPInstanceID, + OwnerEpoch: ownerEpoch, + LastHeartbeatAt: now, + } + if err := tx.Table(cs.runtimeTable(record.TableName())).Create(record).Error; err != nil { + return err + } + created = record + return nil + }) + if err != nil { + return nil, fmt.Errorf("create spawning worker slot: %w", err) + } + return created, nil +} + +// CreateNeutralWarmWorkerSlot creates a durable spawning worker row for the shared +// neutral warm pool under advisory-lock protected cluster-wide warm-target and +// global capacity checks. A nil result means capacity already satisfies the target +// or the global worker cap blocked the spawn. 
+func (cs *ConfigStore) CreateNeutralWarmWorkerSlot(ownerCPInstanceID, podNamePrefix string, targetWarmWorkers, maxGlobalWorkers int) (*WorkerRecord, error) { + if strings.TrimSpace(podNamePrefix) == "" { + return nil, fmt.Errorf("pod name prefix is required") + } + + var created *WorkerRecord + err := cs.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec("SELECT pg_advisory_xact_lock(?)", advisoryLockKey("duckgres:shared-warm-target")).Error; err != nil { + return err + } + if err := tx.Exec("SELECT pg_advisory_xact_lock(?)", advisoryLockKey("duckgres:global-worker-capacity")).Error; err != nil { + return err + } + + if targetWarmWorkers > 0 { + count, err := cs.countNeutralWarmWorkers(tx) + if err != nil { + return err + } + if count >= int64(targetWarmWorkers) { + return nil + } + } + + if maxGlobalWorkers > 0 { + count, err := cs.countActiveWorkers(tx) + if err != nil { + return err + } + if count >= int64(maxGlobalWorkers) { + return nil + } + } + + var workerID int64 + if err := tx.Raw("SELECT COALESCE(MAX(worker_id), 0) + 1 FROM " + cs.runtimeTable((&WorkerRecord{}).TableName())).Scan(&workerID).Error; err != nil { + return err + } + now := time.Now() + record := &WorkerRecord{ + WorkerID: int(workerID), + PodName: fmt.Sprintf("%s-%d", podNamePrefix, workerID), + State: WorkerStateSpawning, + OrgID: "", + OwnerCPInstanceID: ownerCPInstanceID, + OwnerEpoch: 0, + LastHeartbeatAt: now, + } + if err := tx.Table(cs.runtimeTable(record.TableName())).Create(record).Error; err != nil { + return err + } + created = record + return nil + }) + if err != nil { + return nil, fmt.Errorf("create neutral warm worker slot: %w", err) + } + return created, nil +} + +// ListOrphanedWorkers returns workers whose owning control-plane instance has +// already been marked expired long enough ago to pass the orphan grace cutoff. +// Retired/lost rows are included so a replacement janitor can finish deleting +// worker pods when the original control plane died after persisting retirement +// but before the Kubernetes delete completed. +func (cs *ConfigStore) ListOrphanedWorkers(before time.Time) ([]WorkerRecord, error) { + var workers []WorkerRecord + cleanupStates := []WorkerState{ + WorkerStateSpawning, + WorkerStateIdle, + WorkerStateReserved, + WorkerStateActivating, + WorkerStateHot, + WorkerStateDraining, + WorkerStateRetired, + WorkerStateLost, + } + workerTable := cs.runtimeTable((&WorkerRecord{}).TableName()) + cpTable := cs.runtimeTable((&ControlPlaneInstance{}).TableName()) + err := cs.db.Table(workerTable+" AS w"). + Select("w.*"). + Joins("JOIN "+cpTable+" AS cp ON cp.id = w.owner_cp_instance_id"). + Where("w.state IN ?", cleanupStates). + Where("cp.state = ?", ControlPlaneInstanceStateExpired). + Where("cp.expired_at IS NOT NULL AND cp.expired_at <= ?", before). + Order("w.worker_id ASC"). + Find(&workers).Error + if err != nil { + return nil, fmt.Errorf("list orphaned workers: %w", err) + } + return workers, nil +} + +// ListStuckWorkers returns workers stuck in spawning, reserved, or activating +// beyond their respective cutoffs. +func (cs *ConfigStore) ListStuckWorkers(spawningBefore, activatingBefore time.Time) ([]WorkerRecord, error) { + var workers []WorkerRecord + workerTable := cs.runtimeTable((&WorkerRecord{}).TableName()) + cpTable := cs.runtimeTable((&ControlPlaneInstance{}).TableName()) + err := cs.db.Table(workerTable+" AS w"). + Select("w.*"). + Joins("LEFT JOIN "+cpTable+" AS cp ON cp.id = w.owner_cp_instance_id"). + Where("(w.state = ? AND w.updated_at <= ?) 
OR (w.state IN ? AND w.updated_at <= ?)", + WorkerStateSpawning, + spawningBefore, + []WorkerState{WorkerStateReserved, WorkerStateActivating}, + activatingBefore, + ). + Where("cp.id IS NULL OR cp.state <> ?", ControlPlaneInstanceStateExpired). + Find(&workers).Error + if err != nil { + return nil, fmt.Errorf("list stuck workers: %w", err) + } + return workers, nil +} + +// ExpireFlightSessionRecords marks reconnectable Flight sessions expired when +// their reconnect deadline has passed. +func (cs *ConfigStore) ExpireFlightSessionRecords(before time.Time) (int64, error) { + result := cs.db.Table(cs.runtimeTable((&FlightSessionRecord{}).TableName())). + Where("state NOT IN ?", []FlightSessionState{FlightSessionStateExpired, FlightSessionStateClosed}). + Where("expires_at <= ?", before). + Updates(map[string]any{ + "state": FlightSessionStateExpired, + "updated_at": time.Now(), + }) + if result.Error != nil { + return 0, fmt.Errorf("expire flight session records: %w", result.Error) + } + return result.RowsAffected, nil +} + +func (cs *ConfigStore) countActiveWorkers(tx *gorm.DB, where ...any) (int64, error) { + var count int64 + activeStates := []WorkerState{ + WorkerStateSpawning, + WorkerStateIdle, + WorkerStateReserved, + WorkerStateActivating, + WorkerStateHot, + WorkerStateDraining, + } + query := tx.Table(cs.runtimeTable((&WorkerRecord{}).TableName())).Where("state IN ?", activeStates) + if len(where) > 0 { + if clauseStr, ok := where[0].(string); ok { + query = query.Where(clauseStr, where[1:]...) + } + } + if err := query.Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +func (cs *ConfigStore) countNeutralWarmWorkers(tx *gorm.DB) (int64, error) { + var count int64 + if err := tx.Table(cs.runtimeTable((&WorkerRecord{}).TableName())). + Where("org_id = ''"). + Where("state IN ?", []WorkerState{WorkerStateIdle, WorkerStateSpawning}). + Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +func advisoryLockKey(s string) int64 { + h := fnv.New64a() + _, _ = h.Write([]byte(s)) + return int64(h.Sum64() & 0x7fffffffffffffff) +} + +// UpsertFlightSessionRecord inserts or updates a durable Flight reconnect row. +func (cs *ConfigStore) UpsertFlightSessionRecord(record *FlightSessionRecord) error { + if err := cs.db.Table(cs.runtimeTable(record.TableName())).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "session_token"}}, + DoUpdates: clause.AssignmentColumns([]string{"username", "org_id", "worker_id", "owner_epoch", "cp_instance_id", "state", "expires_at", "last_seen_at", "updated_at"}), + }).Create(record).Error; err != nil { + return fmt.Errorf("upsert flight session record: %w", err) + } + return nil +} + +// GetFlightSessionRecord returns a durable Flight reconnect row by session token. +func (cs *ConfigStore) GetFlightSessionRecord(sessionToken string) (*FlightSessionRecord, error) { + var record FlightSessionRecord + err := cs.db.Table(cs.runtimeTable(record.TableName())).First(&record, "session_token = ?", sessionToken).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } + return nil, fmt.Errorf("get flight session record: %w", err) + } + return &record, nil +} + +func (cs *ConfigStore) TouchFlightSessionRecord(sessionToken string, lastSeenAt time.Time) error { + result := cs.db.Table(cs.runtimeTable((&FlightSessionRecord{}).TableName())). + Where("session_token = ?", sessionToken). 
+ Updates(map[string]any{ + "last_seen_at": lastSeenAt, + "updated_at": time.Now(), + }) + if result.Error != nil { + return fmt.Errorf("touch flight session record: %w", result.Error) + } + return nil +} + +func (cs *ConfigStore) CloseFlightSessionRecord(sessionToken string, closedAt time.Time) error { + result := cs.db.Table(cs.runtimeTable((&FlightSessionRecord{}).TableName())). + Where("session_token = ?", sessionToken). + Updates(map[string]any{ + "state": FlightSessionStateClosed, + "last_seen_at": closedAt, + "updated_at": time.Now(), + }) + if result.Error != nil { + return fmt.Errorf("close flight session record: %w", result.Error) + } + return nil +} + // Reload forces an immediate config reload from the database. func (cs *ConfigStore) Reload() error { newSnap, err := cs.load() diff --git a/controlplane/control.go b/controlplane/control.go index c517d889..6fcdcd1c 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -21,6 +21,7 @@ import ( "time" "github.com/cloudflare/tableflip" + "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" "github.com/posthog/duckgres/server/flightsqlingress" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -61,7 +62,6 @@ type ControlPlaneConfig struct { // InternalSecret is the shared secret for API authentication. // When empty, a random secret is generated and logged at startup. InternalSecret string - } type ProcessConfig struct { @@ -71,23 +71,23 @@ type ProcessConfig struct { // K8sConfig holds Kubernetes worker backend configuration. type K8sConfig struct { - WorkerImage string // Container image for worker pods (required) - WorkerNamespace string // K8s namespace (default: auto-detect from service account) - ControlPlaneID string // Unique CP identifier for labeling worker pods (default: os.Hostname()) - WorkerPort int // gRPC port on worker pods (default: 8816) - WorkerSecret string // K8s Secret name containing bearer token - WorkerConfigMap string // ConfigMap name for duckgres.yaml - ImagePullPolicy string // Image pull policy for worker pods (e.g., "Never", "IfNotPresent", "Always") - ServiceAccount string // ServiceAccount name for worker pods (default: "default") - MaxWorkers int // Global cap for the shared K8s worker pool (0 = auto-derived) - SharedWarmTarget int // Neutral shared warm-worker target for K8s multi-tenant mode (0 = disabled) - WorkerCPURequest string // CPU request for worker pods (e.g., "500m") - WorkerMemoryRequest string // Memory request for worker pods (e.g., "1Gi") + WorkerImage string // Container image for worker pods (required) + WorkerNamespace string // K8s namespace (default: auto-detect from service account) + ControlPlaneID string // Unique CP identifier for labeling worker pods (default: os.Hostname()) + WorkerPort int // gRPC port on worker pods (default: 8816) + WorkerSecret string // K8s Secret name containing bearer token + WorkerConfigMap string // ConfigMap name for duckgres.yaml + ImagePullPolicy string // Image pull policy for worker pods (e.g., "Never", "IfNotPresent", "Always") + ServiceAccount string // ServiceAccount name for worker pods (default: "default") + MaxWorkers int // Global cap for the shared K8s worker pool (0 = auto-derived) + SharedWarmTarget int // Neutral shared warm-worker target for K8s multi-tenant mode (0 = disabled) + WorkerCPURequest string // CPU request for worker pods (e.g., "500m") + WorkerMemoryRequest string // Memory request for worker pods (e.g., "1Gi") WorkerNodeSelector string // JSON map for 
worker pod nodeSelector (e.g., '{"posthog.com/nodepool":"workers"}') WorkerTolerationKey string // Taint key for worker pod NoSchedule toleration WorkerTolerationValue string // Taint value for worker pod NoSchedule toleration - WorkerExclusiveNode bool // One worker per node via pod anti-affinity - AWSRegion string // AWS region for STS client + WorkerExclusiveNode bool // One worker per node via pod anti-affinity + AWSRegion string // AWS region for STS client } // ControlPlane manages the TCP listener and routes connections to Flight SQL workers. @@ -116,9 +116,11 @@ type ControlPlane struct { acmeDNSManager *server.ACMEDNSManager // ACME manager for DNS-01 (nil when not using DNS challenges) // Multi-tenant fields (non-nil in remote multitenant mode) - orgRouter OrgRouterInterface - configStore ConfigStoreInterface - apiServer *http.Server // API server on :8080 (shut down on graceful exit) + orgRouter OrgRouterInterface + configStore ConfigStoreInterface + apiServer *http.Server // API server on :8080 (shut down on graceful exit) + runtimeTracker *ControlPlaneRuntimeTracker + janitorLeader *JanitorLeaderManager } // ConfigStoreInterface abstracts the config store for the control plane. @@ -126,11 +128,16 @@ type ControlPlane struct { type ConfigStoreInterface interface { ValidateUser(username, password string) (orgID string, ok bool) OrgForUser(username string) string + UpsertFlightSessionRecord(record *configstore.FlightSessionRecord) error + GetFlightSessionRecord(sessionToken string) (*configstore.FlightSessionRecord, error) + TouchFlightSessionRecord(sessionToken string, lastSeenAt time.Time) error + CloseFlightSessionRecord(sessionToken string, closedAt time.Time) error } // OrgRouterInterface abstracts the org router for the control plane. type OrgRouterInterface interface { StackForUser(username string) (pool WorkerPool, sessions *SessionManager, rebalancer *MemoryRebalancer, ok bool) + StackForOrg(orgID string) (pool WorkerPool, sessions *SessionManager, rebalancer *MemoryRebalancer, ok bool) IsMigratingForUser(username string) (migrating bool, orgID string) ShutdownAll() } @@ -151,7 +158,11 @@ func RunControlPlane(cfg ControlPlaneConfig) { cfg.WorkerIdleTimeout = 5 * time.Minute } if cfg.HandoverDrainTimeout == 0 { - cfg.HandoverDrainTimeout = 24 * time.Hour + if cfg.WorkerBackend == "remote" { + cfg.HandoverDrainTimeout = 15 * time.Minute + } else { + cfg.HandoverDrainTimeout = 24 * time.Hour + } } // Enforce secure defaults for control-plane mode. 
@@ -327,7 +338,7 @@ func RunControlPlane(cfg ControlPlaneConfig) { // Multi-tenant mode: config store + per-org pools (K8s remote backend only) if cfg.WorkerBackend == "remote" { - store, adapter, apiServer, err := SetupMultiTenant(cfg, srv, memBudget, k8sMaxWorkers) + store, adapter, apiServer, runtimeTracker, janitorLeader, err := SetupMultiTenant(cfg, srv, memBudget, k8sMaxWorkers) if err != nil { slog.Error("Failed to set up multi-tenant config store.", "error", err) os.Exit(1) @@ -335,8 +346,22 @@ func RunControlPlane(cfg ControlPlaneConfig) { cp.configStore = store cp.orgRouter = adapter cp.apiServer = apiServer + cp.runtimeTracker = runtimeTracker + cp.janitorLeader = janitorLeader cp.cfg = cfg _ = store // keep linter happy + if cp.runtimeTracker != nil { + if err := cp.runtimeTracker.Start(context.Background()); err != nil { + slog.Error("Failed to start control-plane runtime tracker.", "error", err) + os.Exit(1) + } + } + if cp.janitorLeader != nil { + if err := cp.janitorLeader.Start(context.Background()); err != nil { + slog.Error("Failed to start janitor leader election.", "error", err) + os.Exit(1) + } + } } else { // Single-tenant mode: one shared process pool + session manager procPool := NewFlightWorkerPool(cfg.SocketDir, cfg.ConfigPath, processMinWorkers, processMaxWorkers) @@ -446,7 +471,19 @@ func RunControlPlane(cfg ControlPlaneConfig) { os.Exit(0) } slog.Info("Received shutdown signal.", "signal", s) - cp.shutdown() + if cp.runtimeTracker != nil { + if err := cp.runtimeTracker.MarkDraining(); err != nil { + slog.Warn("Failed to mark control plane draining.", "error", err) + } + } + if cp.janitorLeader != nil { + cp.janitorLeader.Stop() + } + if isK8s { + cp.drainAndShutdown(cp.cfg.HandoverDrainTimeout) + } else { + cp.shutdown() + } os.Exit(0) }() @@ -763,6 +800,11 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { sessions = cp.sessions rebalancer = cp.rebalancer } + if cp.isDraining() { + _ = server.WriteErrorResponse(writer, "FATAL", "57P03", "control plane is draining, retry shortly") + _ = writer.Flush() + return + } // Feed initial parameters and backend key data to the client IMMEDIATELY. // This keeps JDBC drivers happy while we perform the slow worker acquisition. 
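A draining control plane now refuses new pgwire connections with SQLSTATE `57P03` (`cannot_connect_now`), which clients can treat as retryable during a rolling deploy. A minimal client-side sketch, assuming pgx as the driver and a fixed one-second backoff (both illustrative, not part of this change):

```go
package main

import (
	"context"
	"errors"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// connectWithRetry treats the control plane's draining rejection (SQLSTATE
// 57P03) as transient and retries until the context expires. The driver
// choice and backoff are assumptions.
func connectWithRetry(ctx context.Context, dsn string) (*pgx.Conn, error) {
	for {
		conn, err := pgx.Connect(ctx, dsn)
		if err == nil {
			return conn, nil
		}
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) || pgErr.Code != "57P03" {
			return nil, err
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(time.Second):
		}
	}
}
```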
@@ -940,22 +982,88 @@ func fullRead(conn net.Conn, buf []byte) (int, error) { } func (cp *ControlPlane) shutdown() { - cp.closeMu.Lock() - cp.closed = true - cp.closeMu.Unlock() - - if cp.pgListener != nil { - _ = cp.pgListener.Close() - } + cp.stopAcceptingPGConnections() if cp.flight != nil { cp.flight.Shutdown() cp.flight = nil } + if cp.janitorLeader != nil { + cp.janitorLeader.Stop() + } // Wait for in-flight connections to finish slog.Info("Waiting for connections to drain...") cp.wg.Wait() + cp.shutdownRuntimeResources() +} + +func (cp *ControlPlane) drainAndShutdown(timeout time.Duration) { + cp.stopAcceptingPGConnections() + if cp.flight != nil { + cp.flight.BeginDrain() + } + slog.Info("Waiting for planned shutdown drain.", "timeout", timeout) + if cp.waitForDrain(timeout) { + slog.Info("All pgwire connections and Flight sessions drained before shutdown.") + } else { + slog.Warn("Planned shutdown drain timeout exceeded, forcing shutdown.", "timeout", timeout) + } + if cp.flight != nil { + cp.flight.Shutdown() + cp.flight = nil + } + cp.shutdownRuntimeResources() +} + +func (cp *ControlPlane) stopAcceptingPGConnections() { + cp.closeMu.Lock() + cp.closed = true + cp.closeMu.Unlock() + + if cp.pgListener != nil { + _ = cp.pgListener.Close() + } +} + +func (cp *ControlPlane) waitForDrain(timeout time.Duration) bool { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + pgDone := make(chan struct{}) + go func() { + cp.wg.Wait() + close(pgDone) + }() + + flightDone := make(chan bool, 1) + go func() { + if cp.flight != nil { + flightDone <- cp.flight.WaitForZeroSessions(ctx) + return + } + flightDone <- true + }() + + pgClosed := false + flightClosed := false + for !pgClosed || !flightClosed { + select { + case <-ctx.Done(): + return false + case <-pgDone: + pgClosed = true + case drained := <-flightDone: + if !drained { + return false + } + flightClosed = true + } + } + return true +} + +func (cp *ControlPlane) shutdownRuntimeResources() { slog.Info("Shutting down workers...") if cp.orgRouter != nil { cp.orgRouter.ShutdownAll() @@ -984,6 +1092,10 @@ func (cp *ControlPlane) shutdown() { slog.Info("Control plane shutdown complete.") } +func (cp *ControlPlane) isDraining() bool { + return cp.runtimeTracker != nil && cp.runtimeTracker.Draining() +} + func (cp *ControlPlane) stopQueryLogger() { if cp.srv != nil && cp.srv.QueryLogger() != nil { cp.srv.QueryLogger().Stop() @@ -1177,8 +1289,9 @@ func (cp *ControlPlane) startFlightIngress() { return ok }) provider = &orgRoutedSessionProvider{ - orgRouter: cp.orgRouter, - pidSession: make(map[int32]*SessionManager), + orgRouter: cp.orgRouter, + configStore: cp.configStore, + pidSession: make(map[int32]flightOwnedSession), } case cp.sessions != nil: // Single-tenant: static users map, single session manager. 
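The durable reconnect state sits behind a small store interface, so tests can swap in an in-memory implementation. A sketch assuming the four-method contract that `configStoreFlightSessionStore` below satisfies; the `"closed"` literal mirrors `configstore.FlightSessionStateClosed` and is an assumption about the `flightsqlingress` state values:

```go
package main

import (
	"sync"
	"time"

	"github.com/posthog/duckgres/server/flightsqlingress"
)

// memSessionStore is a test stand-in for the Postgres-backed session store.
type memSessionStore struct {
	mu   sync.Mutex
	rows map[string]flightsqlingress.DurableSessionRecord
}

func newMemSessionStore() *memSessionStore {
	return &memSessionStore{rows: make(map[string]flightsqlingress.DurableSessionRecord)}
}

func (s *memSessionStore) UpsertSession(r flightsqlingress.DurableSessionRecord) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rows[r.SessionToken] = r
	return nil
}

// GetSession returns (nil, nil) for unknown tokens, matching the
// config-store contract for missing rows.
func (s *memSessionStore) GetSession(token string) (*flightsqlingress.DurableSessionRecord, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if r, ok := s.rows[token]; ok {
		rec := r
		return &rec, nil
	}
	return nil, nil
}

func (s *memSessionStore) TouchSession(token string, lastSeenAt time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if r, ok := s.rows[token]; ok {
		r.LastSeenAt = lastSeenAt
		s.rows[token] = r
	}
	return nil
}

func (s *memSessionStore) CloseSession(token string, closedAt time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if r, ok := s.rows[token]; ok {
		r.State = flightsqlingress.DurableSessionState("closed") // assumed value
		r.LastSeenAt = closedAt
		s.rows[token] = r
	}
	return nil
}
```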
diff --git a/controlplane/flight_ingress.go b/controlplane/flight_ingress.go index 69065392..da4809b6 100644 --- a/controlplane/flight_ingress.go +++ b/controlplane/flight_ingress.go @@ -3,10 +3,13 @@ package controlplane import ( "context" "crypto/tls" + "errors" "fmt" "log/slog" "sync" + "time" + "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" "github.com/posthog/duckgres/server/flightsqlingress" ) @@ -44,13 +47,19 @@ func (p *flightSessionProvider) DestroySession(pid int32) { p.sm.DestroySession(pid) } +type flightOwnedSession struct { + orgID string + sessions *SessionManager +} + // orgRoutedSessionProvider routes Flight SQL session operations to the correct // org's SessionManager based on the username→org mapping in the config store. type orgRoutedSessionProvider struct { - orgRouter OrgRouterInterface + orgRouter OrgRouterInterface + configStore ConfigStoreInterface mu sync.RWMutex - pidSession map[int32]*SessionManager // pid → owning session manager + pidSession map[int32]flightOwnedSession // pid → owning session manager } func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { @@ -69,8 +78,12 @@ func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username s sessions.SetProtocol(workerPID, "flight") + orgID := "" + if p.configStore != nil { + orgID = p.configStore.OrgForUser(username) + } p.mu.Lock() - p.pidSession[workerPID] = sessions + p.pidSession[workerPID] = flightOwnedSession{orgID: orgID, sessions: sessions} p.mu.Unlock() return workerPID, executor, nil @@ -78,16 +91,111 @@ func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username s func (p *orgRoutedSessionProvider) DestroySession(pid int32) { p.mu.RLock() - sm, ok := p.pidSession[pid] + owned, ok := p.pidSession[pid] p.mu.RUnlock() if !ok { slog.Warn("Flight SQL destroy: no session manager for pid.", "pid", pid) return } - sm.DestroySession(pid) + owned.sessions.DestroySession(pid) p.mu.Lock() delete(p.pidSession, pid) p.mu.Unlock() } + +func (p *orgRoutedSessionProvider) DurableSessionMetadata(pid int32, username string) (flightsqlingress.DurableSessionMetadata, error) { + p.mu.RLock() + owned, ok := p.pidSession[pid] + p.mu.RUnlock() + if !ok { + return flightsqlingress.DurableSessionMetadata{}, fmt.Errorf("no session manager for pid %d", pid) + } + workerID := owned.sessions.WorkerIDForPID(pid) + if workerID < 0 { + return flightsqlingress.DurableSessionMetadata{}, fmt.Errorf("worker not found for pid %d", pid) + } + worker, ok := owned.sessions.pool.Worker(workerID) + if !ok { + return flightsqlingress.DurableSessionMetadata{}, fmt.Errorf("worker %d not found for pid %d", workerID, pid) + } + return flightsqlingress.DurableSessionMetadata{ + Username: username, + OrgID: owned.orgID, + WorkerID: workerID, + OwnerEpoch: worker.OwnerEpoch(), + CPInstanceID: worker.OwnerCPInstanceID(), + }, nil +} + +func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record flightsqlingress.DurableSessionRecord) (int32, *server.FlightExecutor, error) { + _, sessions, _, ok := p.orgRouter.StackForOrg(record.OrgID) + if !ok { + return 0, nil, fmt.Errorf("no org stack for org %q", record.OrgID) + } + + pid, executor, err := sessions.ReconnectFlightSession(ctx, record.Username, record.WorkerID, record.OwnerEpoch) + if err != nil { + if errors.Is(err, configstore.ErrWorkerOwnerEpochMismatch) { + return 0, nil, 
flightsqlingress.MarkDurableReconnectTerminal(err) + } + return 0, nil, err + } + + p.mu.Lock() + p.pidSession[pid] = flightOwnedSession{orgID: record.OrgID, sessions: sessions} + p.mu.Unlock() + return pid, executor, nil +} + +func (p *orgRoutedSessionProvider) DurableSessionStore() flightsqlingress.DurableSessionStore { + if p == nil || p.configStore == nil { + return nil + } + return &configStoreFlightSessionStore{store: p.configStore} +} + +type configStoreFlightSessionStore struct { + store ConfigStoreInterface +} + +func (s *configStoreFlightSessionStore) UpsertSession(record flightsqlingress.DurableSessionRecord) error { + return s.store.UpsertFlightSessionRecord(&configstore.FlightSessionRecord{ + SessionToken: record.SessionToken, + Username: record.Username, + OrgID: record.OrgID, + WorkerID: record.WorkerID, + OwnerEpoch: record.OwnerEpoch, + CPInstanceID: record.CPInstanceID, + State: configstore.FlightSessionState(record.State), + ExpiresAt: record.ExpiresAt, + LastSeenAt: record.LastSeenAt, + }) +} + +func (s *configStoreFlightSessionStore) GetSession(sessionToken string) (*flightsqlingress.DurableSessionRecord, error) { + record, err := s.store.GetFlightSessionRecord(sessionToken) + if err != nil || record == nil { + return nil, err + } + return &flightsqlingress.DurableSessionRecord{ + SessionToken: record.SessionToken, + Username: record.Username, + OrgID: record.OrgID, + WorkerID: record.WorkerID, + OwnerEpoch: record.OwnerEpoch, + CPInstanceID: record.CPInstanceID, + State: flightsqlingress.DurableSessionState(record.State), + ExpiresAt: record.ExpiresAt, + LastSeenAt: record.LastSeenAt, + }, nil +} + +func (s *configStoreFlightSessionStore) TouchSession(sessionToken string, lastSeenAt time.Time) error { + return s.store.TouchFlightSessionRecord(sessionToken, lastSeenAt) +} + +func (s *configStoreFlightSessionStore) CloseSession(sessionToken string, closedAt time.Time) error { + return s.store.CloseFlightSessionRecord(sessionToken, closedAt) +} diff --git a/controlplane/flight_ingress_test.go b/controlplane/flight_ingress_test.go index 638bebcf..0a103f90 100644 --- a/controlplane/flight_ingress_test.go +++ b/controlplane/flight_ingress_test.go @@ -1,89 +1,59 @@ -//go:build !kubernetes - package controlplane import ( "context" - "sync" "testing" -) -func TestOrgRoutedSessionProviderCreateSessionTeamNotFound(t *testing.T) { - provider := &orgRoutedSessionProvider{ - orgRouter: &mockOrgRouter{ok: false}, - pidSession: make(map[int32]*SessionManager), - } + "github.com/posthog/duckgres/server/flightsqlingress" +) - _, _, err := provider.CreateSession(context.Background(), "unknown", 0, "", 0) - if err == nil { - t.Fatal("expected error for unknown org") - } - if err.Error() != `no org configured for user "unknown"` { - t.Fatalf("unexpected error message: %v", err) - } +type reconnectTestOrgRouter struct { + stackByUserCalls int + stackByOrgCalls int + orgID string } -func TestOrgRoutedSessionProviderDestroySessionRemovesPid(t *testing.T) { - sm := NewSessionManager(nil, nil) - - provider := &orgRoutedSessionProvider{ - orgRouter: &mockOrgRouter{sessions: sm, ok: true}, - pidSession: map[int32]*SessionManager{ - 42: sm, - }, - } - - // Destroy known pid — should remove from map. - // sm.DestroySession(42) is a no-op for unknown internal session, which is fine; - // we're testing the adapter's pid map bookkeeping. 
- provider.DestroySession(42) - - provider.mu.RLock() - _, exists := provider.pidSession[42] - provider.mu.RUnlock() - if exists { - t.Fatal("expected pid 42 to be removed from pidSession map after destroy") - } +func (r *reconnectTestOrgRouter) StackForUser(username string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) { + r.stackByUserCalls++ + return nil, nil, nil, false } -func TestOrgRoutedSessionProviderDestroyUnknownPidNoOp(t *testing.T) { - provider := &orgRoutedSessionProvider{ - orgRouter: &mockOrgRouter{ok: true}, - pidSession: make(map[int32]*SessionManager), - } - - // Should not panic. - provider.DestroySession(999) +func (r *reconnectTestOrgRouter) StackForOrg(orgID string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) { + r.stackByOrgCalls++ + return nil, nil, nil, false } -func TestOrgRoutedSessionProviderConcurrentDestroys(t *testing.T) { - sm := NewSessionManager(nil, nil) +func (r *reconnectTestOrgRouter) IsMigratingForUser(string) (bool, string) { return false, "" } +func (r *reconnectTestOrgRouter) ShutdownAll() {} +func TestOrgRoutedSessionProviderReconnectSessionUsesDurableOrgID(t *testing.T) { + router := &reconnectTestOrgRouter{ + orgID: "analytics", + } provider := &orgRoutedSessionProvider{ - orgRouter: &mockOrgRouter{sessions: sm, ok: true}, - pidSession: make(map[int32]*SessionManager), + orgRouter: router, + pidSession: make(map[int32]flightOwnedSession), + configStore: nil, + } + + pid, _, err := provider.ReconnectSession(context.Background(), flightsqlingress.DurableSessionRecord{ + SessionToken: "flight-token", + Username: "alice", + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-z", + }) + if err == nil { + t.Fatal("expected reconnect to fail without a live org stack") } - - // Pre-populate - for i := int32(0); i < 100; i++ { - provider.pidSession[i] = sm + if pid != 0 { + t.Fatalf("expected pid 0 on failed reconnect, got %d", pid) } - - // Concurrent destroys should not race. 
- var wg sync.WaitGroup - for i := int32(0); i < 100; i++ { - wg.Add(1) - go func(pid int32) { - defer wg.Done() - provider.DestroySession(pid) - }(i) + if router.stackByUserCalls != 0 { + t.Fatalf("expected reconnect not to use StackForUser, got %d calls", router.stackByUserCalls) } - wg.Wait() - - provider.mu.RLock() - remaining := len(provider.pidSession) - provider.mu.RUnlock() - if remaining != 0 { - t.Fatalf("expected all pids removed, got %d remaining", remaining) + if router.stackByOrgCalls != 1 { + t.Fatalf("expected reconnect to use StackForOrg once, got %d", router.stackByOrgCalls) } } diff --git a/controlplane/janitor.go b/controlplane/janitor.go new file mode 100644 index 00000000..e6634595 --- /dev/null +++ b/controlplane/janitor.go @@ -0,0 +1,130 @@ +package controlplane + +import ( + "context" + "log/slog" + "time" + + "github.com/posthog/duckgres/controlplane/configstore" +) + +const ( + janitorRetireReasonOrphaned = "orphaned" + janitorRetireReasonStuckActivating = "stuck_activating" +) + +type controlPlaneExpiryStore interface { + ExpireControlPlaneInstances(cutoff time.Time) (int64, error) + ExpireDrainingControlPlaneInstances(before time.Time) (int64, error) + ListOrphanedWorkers(before time.Time) ([]configstore.WorkerRecord, error) + ListStuckWorkers(spawningBefore, activatingBefore time.Time) ([]configstore.WorkerRecord, error) + ExpireFlightSessionRecords(before time.Time) (int64, error) +} + +type ControlPlaneJanitor struct { + store controlPlaneExpiryStore + interval time.Duration + expiryTimeout time.Duration + orphanGrace time.Duration + spawnTimeout time.Duration + activateTimeout time.Duration + maxDrainTimeout time.Duration + now func() time.Time + retireWorker func(record configstore.WorkerRecord, reason string) + reconcileWarmCapacity func() +} + +func NewControlPlaneJanitor(store controlPlaneExpiryStore, interval, expiryTimeout time.Duration) *ControlPlaneJanitor { + if interval <= 0 { + interval = 5 * time.Second + } + if expiryTimeout <= 0 { + expiryTimeout = 20 * time.Second + } + return &ControlPlaneJanitor{ + store: store, + interval: interval, + expiryTimeout: expiryTimeout, + orphanGrace: 30 * time.Second, + spawnTimeout: 2 * time.Minute, + activateTimeout: 2 * time.Minute, + maxDrainTimeout: 15 * time.Minute, + now: time.Now, + } +} + +func (j *ControlPlaneJanitor) Run(ctx context.Context) { + if j == nil || j.store == nil { + return + } + + j.runOnce() + + ticker := time.NewTicker(j.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + j.runOnce() + } + } +} + +func (j *ControlPlaneJanitor) runOnce() { + cutoff := j.now().Add(-j.expiryTimeout) + expired, err := j.store.ExpireControlPlaneInstances(cutoff) + if err != nil { + slog.Warn("Janitor failed to expire stale control-plane instances.", "error", err) + } else if expired > 0 { + slog.Info("Janitor expired stale control-plane instances.", "count", expired, "cutoff", cutoff) + } + + if j.maxDrainTimeout > 0 { + drainingBefore := j.now().Add(-j.maxDrainTimeout) + expiredDraining, err := j.store.ExpireDrainingControlPlaneInstances(drainingBefore) + if err != nil { + slog.Warn("Janitor failed to expire overdue draining control-plane instances.", "error", err) + } else if expiredDraining > 0 { + slog.Info("Janitor expired overdue draining control-plane instances.", "count", expiredDraining, "cutoff", drainingBefore) + } + } + + orphanedBefore := j.now().Add(-j.orphanGrace) + orphaned, err := j.store.ListOrphanedWorkers(orphanedBefore) + if err != nil 
{ + slog.Warn("Janitor failed to list orphaned workers.", "error", err) + } else { + for _, worker := range orphaned { + j.retireRuntimeWorker(worker, janitorRetireReasonOrphaned) + } + } + + spawningBefore := j.now().Add(-j.spawnTimeout) + activatingBefore := j.now().Add(-j.activateTimeout) + stuckWorkers, err := j.store.ListStuckWorkers(spawningBefore, activatingBefore) + if err != nil { + slog.Warn("Janitor failed to list stuck workers.", "error", err) + } else { + for _, worker := range stuckWorkers { + j.retireRuntimeWorker(worker, janitorRetireReasonStuckActivating) + } + } + + if _, err := j.store.ExpireFlightSessionRecords(j.now()); err != nil { + slog.Warn("Janitor failed to expire stale Flight sessions.", "error", err) + } + + if j.reconcileWarmCapacity != nil { + j.reconcileWarmCapacity() + } +} + +func (j *ControlPlaneJanitor) retireRuntimeWorker(record configstore.WorkerRecord, reason string) { + if j == nil || j.retireWorker == nil { + return + } + j.retireWorker(record, reason) +} diff --git a/controlplane/janitor_leader_k8s.go b/controlplane/janitor_leader_k8s.go new file mode 100644 index 00000000..13c3a48f --- /dev/null +++ b/controlplane/janitor_leader_k8s.go @@ -0,0 +1,144 @@ +//go:build kubernetes + +package controlplane + +import ( + "context" + "fmt" + "log/slog" + "os" + "sync" + "time" + + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + coordinationv1client "k8s.io/client-go/kubernetes/typed/coordination/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/leaderelection" + "k8s.io/client-go/tools/leaderelection/resourcelock" +) + +const defaultJanitorLeaseName = "duckgres-janitor" + +type leaderElectorRunner interface { + Run(context.Context) +} + +type JanitorLeaderManager struct { + elector leaderElectorRunner + leaderLoop *leaderOnlyLoop + mu sync.Mutex + cancel context.CancelFunc +} + +func NewJanitorLeaderManager(namespace, identity string, janitor *ControlPlaneJanitor) (*JanitorLeaderManager, error) { + if janitor == nil { + return nil, nil + } + if namespace == "" { + return nil, fmt.Errorf("leader election namespace is required") + } + if identity == "" { + hostname, _ := os.Hostname() + identity = hostname + } + + restCfg, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("load in-cluster config for leader election: %w", err) + } + clientset, err := kubernetes.NewForConfig(restCfg) + if err != nil { + return nil, fmt.Errorf("create kubernetes client for leader election: %w", err) + } + + return newJanitorLeaderManagerFromClients( + namespace, + defaultJanitorLeaseName, + identity, + clientset.CoreV1(), + clientset.CoordinationV1(), + janitor, + ) +} + +func newJanitorLeaderManagerFromClients( + namespace, leaseName, identity string, + coreClient corev1client.CoreV1Interface, + coordClient coordinationv1client.CoordinationV1Interface, + janitor *ControlPlaneJanitor, +) (*JanitorLeaderManager, error) { + lock, err := resourcelock.New( + resourcelock.LeasesResourceLock, + namespace, + leaseName, + coreClient, + coordClient, + resourcelock.ResourceLockConfig{ + Identity: identity, + }, + ) + if err != nil { + return nil, fmt.Errorf("create leader-election lease lock: %w", err) + } + + leaderLoop := newLeaderOnlyLoop(janitor.Run) + elector, err := leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{ + Lock: lock, + LeaseDuration: 20 * time.Second, + RenewDeadline: 15 * time.Second, + RetryPeriod: 5 * time.Second, + ReleaseOnCancel: true, + Name: "duckgres-janitor", + 
Callbacks: leaderelection.LeaderCallbacks{ + OnStartedLeading: leaderLoop.onStartedLeading, + OnStoppedLeading: func() { + leaderLoop.onStoppedLeading() + slog.Info("Lost janitor leadership.") + }, + OnNewLeader: func(current string) { + slog.Debug("Janitor leader observed.", "identity", current) + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("create leader elector: %w", err) + } + + return &JanitorLeaderManager{ + elector: elector, + leaderLoop: leaderLoop, + }, nil +} + +func (m *JanitorLeaderManager) Start(ctx context.Context) error { + if m == nil || m.elector == nil { + return nil + } + runCtx, cancel := context.WithCancel(ctx) + m.mu.Lock() + if m.cancel != nil { + m.cancel() + } + m.cancel = cancel + m.mu.Unlock() + go m.elector.Run(runCtx) + return nil +} + +func (m *JanitorLeaderManager) Stop() { + if m == nil { + return + } + m.mu.Lock() + cancel := m.cancel + m.cancel = nil + m.mu.Unlock() + if cancel != nil { + cancel() + } + if m.leaderLoop == nil { + return + } + m.leaderLoop.onStoppedLeading() +} diff --git a/controlplane/janitor_leader_k8s_test.go b/controlplane/janitor_leader_k8s_test.go new file mode 100644 index 00000000..d84d72b0 --- /dev/null +++ b/controlplane/janitor_leader_k8s_test.go @@ -0,0 +1,37 @@ +//go:build kubernetes + +package controlplane + +import ( + "context" + "testing" + "time" +) + +type captureLeaderElector struct { + ctxDone chan struct{} +} + +func (e *captureLeaderElector) Run(ctx context.Context) { + <-ctx.Done() + close(e.ctxDone) +} + +func TestJanitorLeaderManagerStopCancelsLeaderElection(t *testing.T) { + elector := &captureLeaderElector{ctxDone: make(chan struct{})} + manager := &JanitorLeaderManager{ + elector: elector, + leaderLoop: newLeaderOnlyLoop(func(context.Context) {}), + } + if err := manager.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + + manager.Stop() + + select { + case <-elector.ctxDone: + case <-time.After(time.Second): + t.Fatal("expected Stop to cancel leader-election context") + } +} diff --git a/controlplane/janitor_leader_stub.go b/controlplane/janitor_leader_stub.go new file mode 100644 index 00000000..78c56150 --- /dev/null +++ b/controlplane/janitor_leader_stub.go @@ -0,0 +1,17 @@ +//go:build !kubernetes + +package controlplane + +import "context" + +type JanitorLeaderManager struct{} + +func NewJanitorLeaderManager(namespace, identity string, janitor *ControlPlaneJanitor) (*JanitorLeaderManager, error) { + return nil, nil +} + +func (m *JanitorLeaderManager) Start(ctx context.Context) error { + return nil +} + +func (m *JanitorLeaderManager) Stop() {} diff --git a/controlplane/janitor_test.go b/controlplane/janitor_test.go new file mode 100644 index 00000000..0cb73b5c --- /dev/null +++ b/controlplane/janitor_test.go @@ -0,0 +1,243 @@ +package controlplane + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/posthog/duckgres/controlplane/configstore" +) + +type captureControlPlaneExpiryStore struct { + mu sync.Mutex + cutoffs []time.Time + count int64 + expireErr error + drainingCutoffs []time.Time + drainingCount int64 + orphanedBefore []time.Time + orphanedWorkers []configstore.WorkerRecord + stuckSpawningBefore []time.Time + stuckActivatingBefore []time.Time + stuckWorkers []configstore.WorkerRecord + expiredSessionsBefore []time.Time +} + +func (s *captureControlPlaneExpiryStore) ExpireControlPlaneInstances(cutoff time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.cutoffs = append(s.cutoffs, cutoff) + return s.count, s.expireErr +} + +func 
(s *captureControlPlaneExpiryStore) ExpireDrainingControlPlaneInstances(before time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.drainingCutoffs = append(s.drainingCutoffs, before) + return s.drainingCount, nil +} + +func (s *captureControlPlaneExpiryStore) snapshot() []time.Time { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]time.Time, len(s.cutoffs)) + copy(out, s.cutoffs) + return out +} + +func (s *captureControlPlaneExpiryStore) ListOrphanedWorkers(before time.Time) ([]configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.orphanedBefore = append(s.orphanedBefore, before) + out := make([]configstore.WorkerRecord, len(s.orphanedWorkers)) + copy(out, s.orphanedWorkers) + return out, nil +} + +func (s *captureControlPlaneExpiryStore) ListStuckWorkers(spawningBefore, activatingBefore time.Time) ([]configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.stuckSpawningBefore = append(s.stuckSpawningBefore, spawningBefore) + s.stuckActivatingBefore = append(s.stuckActivatingBefore, activatingBefore) + out := make([]configstore.WorkerRecord, len(s.stuckWorkers)) + copy(out, s.stuckWorkers) + return out, nil +} + +func (s *captureControlPlaneExpiryStore) ExpireFlightSessionRecords(before time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.expiredSessionsBefore = append(s.expiredSessionsBefore, before) + return 0, nil +} + +func TestControlPlaneJanitorRunExpiresStaleInstances(t *testing.T) { + store := &captureControlPlaneExpiryStore{} + now := time.Date(2026, time.March, 26, 15, 0, 0, 0, time.UTC) + janitor := NewControlPlaneJanitor(store, 10*time.Millisecond, 20*time.Second) + janitor.now = func() time.Time { return now } + janitor.maxDrainTimeout = 15 * time.Minute + + janitor.runOnce() + janitor.runOnce() + + calls := store.snapshot() + if len(calls) != 2 { + t.Fatalf("expected janitor to run exactly twice, got %d", len(calls)) + } + wantCutoff := now.Add(-20 * time.Second) + for i, cutoff := range calls { + if !cutoff.Equal(wantCutoff) { + t.Fatalf("call %d expected cutoff %v, got %v", i, wantCutoff, cutoff) + } + } + wantDrainCutoff := now.Add(-15 * time.Minute) + if len(store.drainingCutoffs) != 2 { + t.Fatalf("expected janitor to expire overdue draining instances exactly twice, got %d", len(store.drainingCutoffs)) + } + for i, cutoff := range store.drainingCutoffs { + if !cutoff.Equal(wantDrainCutoff) { + t.Fatalf("draining call %d expected cutoff %v, got %v", i, wantDrainCutoff, cutoff) + } + } +} + +func TestControlPlaneJanitorRunRetiresOrphanedAndStuckWorkers(t *testing.T) { + store := &captureControlPlaneExpiryStore{ + orphanedWorkers: []configstore.WorkerRecord{ + {WorkerID: 7, PodName: "duckgres-worker-7"}, + }, + stuckWorkers: []configstore.WorkerRecord{ + {WorkerID: 9, PodName: "duckgres-worker-9", State: configstore.WorkerStateActivating}, + }, + } + now := time.Date(2026, time.March, 26, 16, 0, 0, 0, time.UTC) + janitor := NewControlPlaneJanitor(store, 10*time.Millisecond, 20*time.Second) + janitor.now = func() time.Time { return now } + + var mu sync.Mutex + var retired []struct { + id int + reason string + } + janitor.retireWorker = func(record configstore.WorkerRecord, reason string) { + mu.Lock() + defer mu.Unlock() + retired = append(retired, struct { + id int + reason string + }{id: record.WorkerID, reason: reason}) + } + + janitor.runOnce() + + mu.Lock() + defer mu.Unlock() + if len(retired) != 2 { + t.Fatalf("expected janitor to retire exactly two workers, got %d", len(retired)) + } + if 
retired[0].id != 7 || retired[0].reason != janitorRetireReasonOrphaned { + t.Fatalf("expected orphaned worker 7 with orphaned reason, got %+v", retired[0]) + } + if retired[1].id != 9 || retired[1].reason != janitorRetireReasonStuckActivating { + t.Fatalf("expected stuck worker 9 with stuck_activating, got %+v", retired[1]) + } + if len(store.orphanedBefore) == 0 { + t.Fatal("expected orphaned worker cutoff lookup") + } + wantOrphanedBefore := now.Add(-30 * time.Second) + for i, cutoff := range store.orphanedBefore { + if !cutoff.Equal(wantOrphanedBefore) { + t.Fatalf("orphaned cutoff call %d expected %v, got %v", i, wantOrphanedBefore, cutoff) + } + } + if len(store.stuckSpawningBefore) == 0 || len(store.stuckActivatingBefore) == 0 { + t.Fatal("expected stuck worker cutoff lookup") + } + if len(store.expiredSessionsBefore) == 0 { + t.Fatal("expected expired flight session cleanup") + } +} + +func TestControlPlaneJanitorRunReconcilesWarmCapacity(t *testing.T) { + store := &captureControlPlaneExpiryStore{} + now := time.Date(2026, time.March, 27, 14, 0, 0, 0, time.UTC) + janitor := NewControlPlaneJanitor(store, 10*time.Millisecond, 20*time.Second) + janitor.now = func() time.Time { return now } + + var mu sync.Mutex + calls := 0 + janitor.reconcileWarmCapacity = func() { + mu.Lock() + defer mu.Unlock() + calls++ + } + + janitor.runOnce() + + mu.Lock() + defer mu.Unlock() + if calls != 1 { + t.Fatalf("expected janitor to reconcile warm capacity exactly once, got %d", calls) + } +} + +func TestControlPlaneJanitorRunOnceContinuesAfterExpireError(t *testing.T) { + store := &captureControlPlaneExpiryStore{ + expireErr: errors.New("boom"), + orphanedWorkers: []configstore.WorkerRecord{ + {WorkerID: 7, PodName: "duckgres-worker-7"}, + }, + stuckWorkers: []configstore.WorkerRecord{ + {WorkerID: 9, PodName: "duckgres-worker-9", State: configstore.WorkerStateActivating}, + }, + } + now := time.Date(2026, time.March, 27, 18, 0, 0, 0, time.UTC) + janitor := NewControlPlaneJanitor(store, time.Second, 20*time.Second) + janitor.now = func() time.Time { return now } + + var mu sync.Mutex + var retired []struct { + id int + reason string + } + reconciled := 0 + janitor.retireWorker = func(record configstore.WorkerRecord, reason string) { + mu.Lock() + defer mu.Unlock() + retired = append(retired, struct { + id int + reason string + }{id: record.WorkerID, reason: reason}) + } + janitor.reconcileWarmCapacity = func() { + mu.Lock() + defer mu.Unlock() + reconciled++ + } + + janitor.runOnce() + + if len(store.cutoffs) != 1 { + t.Fatalf("expected stale control-plane expiry to run once, got %d", len(store.cutoffs)) + } + if len(store.orphanedBefore) != 1 { + t.Fatalf("expected orphaned worker lookup despite expiry error, got %d", len(store.orphanedBefore)) + } + if len(store.stuckSpawningBefore) != 1 || len(store.stuckActivatingBefore) != 1 { + t.Fatalf("expected stuck worker lookup despite expiry error, got spawning=%d activating=%d", len(store.stuckSpawningBefore), len(store.stuckActivatingBefore)) + } + if len(store.expiredSessionsBefore) != 1 { + t.Fatalf("expected flight session expiry despite expiry error, got %d", len(store.expiredSessionsBefore)) + } + + mu.Lock() + defer mu.Unlock() + if len(retired) != 2 { + t.Fatalf("expected orphaned and stuck workers to be retired, got %+v", retired) + } + if reconciled != 1 { + t.Fatalf("expected warm capacity reconciliation despite expiry error, got %d", reconciled) + } +} diff --git a/controlplane/k8s_pool.go b/controlplane/k8s_pool.go index 494b9adc..34a634d8 
100644 --- a/controlplane/k8s_pool.go +++ b/controlplane/k8s_pool.go @@ -5,15 +5,19 @@ package controlplane import ( "context" "crypto/rand" + "crypto/sha1" "encoding/hex" + stderrors "errors" "fmt" "log/slog" "os" "strconv" + "strings" "sync" "time" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -42,35 +46,38 @@ type K8sWorkerPool struct { shuttingDown bool shutdownCh chan struct{} - clientset kubernetes.Interface - namespace string - cpID string - cpUID types.UID - workerImage string - workerPort int - secretName string - configMap string - configPath string - imagePullPolicy corev1.PullPolicy - serviceAccount string - workerCPURequest string // CPU request for worker pods (e.g., "500m") - workerMemoryRequest string // memory request for worker pods (e.g., "1Gi") + clientset kubernetes.Interface + namespace string + cpID string + cpInstanceID string + cpUID types.UID + workerImage string + workerPort int + secretName string + configMap string + configPath string + imagePullPolicy corev1.PullPolicy + serviceAccount string + workerCPURequest string // CPU request for worker pods (e.g., "500m") + workerMemoryRequest string // memory request for worker pods (e.g., "1Gi") workerNodeSelector map[string]string // node selector for worker pods workerTolerationKey string // taint key for NoSchedule toleration workerTolerationValue string // taint value for NoSchedule toleration - workerExclusiveNode bool // one worker per node via anti-affinity - orgID string // org ID for pod labels (multi-tenant mode) - workerIDGenerator func() int // shared ID generator across orgs (nil = internal counter) - cachedToken string // cached bearer token (immutable after setup) - informer cache.SharedIndexInformer - stopInform chan struct{} - spawnSem chan struct{} // limits concurrent pod creates to avoid overwhelming the K8s API - podReady sync.Map // podName -> chan string (pod IP); signaled by informer + workerExclusiveNode bool // one worker per node via anti-affinity + orgID string // org ID for pod labels (multi-tenant mode) + workerIDGenerator func() int // shared ID generator across orgs (nil = internal counter) + cachedToken string // cached bearer token (immutable after setup) + informer cache.SharedIndexInformer + stopInform chan struct{} + spawnSem chan struct{} // limits concurrent pod creates to avoid overwhelming the K8s API + podReady sync.Map // podName -> chan string (pod IP); signaled by informer spawnWarmWorkerFunc func(ctx context.Context, id int) error spawnWarmWorkerBackgroundFunc func(id int) activateTenantFunc func(ctx context.Context, worker *ManagedWorker, payload TenantActivationPayload) error healthCheckFunc func(context.Context, *ManagedWorker) error + connectWorkerFunc func(ctx context.Context, podName, podIP, bearerToken string) (*flightsql.Client, error) + runtimeStore RuntimeWorkerStore activatingTimeout time.Duration // max time a worker can stay in reserved/activating before being reaped } @@ -119,36 +126,41 @@ func newK8sWorkerPool(cfg K8sWorkerPoolConfig, clientset kubernetes.Interface) ( // Allow up to 3 concurrent pod creates to limit K8s API pressure. 
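// NOTE (illustrative sketch, not part of this diff): the spawn semaphore
// below is the standard buffered-channel pattern for bounding concurrency.
// A self-contained model, with hypothetical names, of how a channel of
// capacity three serializes pod creates:
//
//	sem := make(chan struct{}, 3) // capacity = max concurrent creates
//	var wg sync.WaitGroup
//	for i := 0; i < 10; i++ {
//		wg.Add(1)
//		go func(id int) {
//			defer wg.Done()
//			sem <- struct{}{}        // acquire: blocks while 3 creates are in flight
//			defer func() { <-sem }() // release
//			createPod(id)            // hypothetical stand-in for the K8s API call
//		}(i)
//	}
//	wg.Wait()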
spawnConcurrency := 3 pool := &K8sWorkerPool{ - workers: make(map[int]*ManagedWorker), - maxWorkers: cfg.MaxWorkers, - idleTimeout: cfg.IdleTimeout, - shutdownCh: make(chan struct{}), - stopInform: make(chan struct{}), - clientset: clientset, - namespace: cfg.Namespace, - cpID: cfg.CPID, - workerImage: cfg.WorkerImage, - workerPort: cfg.WorkerPort, - secretName: cfg.SecretName, - configMap: cfg.ConfigMap, - configPath: cfg.ConfigPath, - imagePullPolicy: corev1.PullPolicy(cfg.ImagePullPolicy), - serviceAccount: cfg.ServiceAccount, - workerCPURequest: cfg.WorkerCPURequest, - workerMemoryRequest: cfg.WorkerMemoryRequest, + workers: make(map[int]*ManagedWorker), + maxWorkers: cfg.MaxWorkers, + idleTimeout: cfg.IdleTimeout, + shutdownCh: make(chan struct{}), + stopInform: make(chan struct{}), + clientset: clientset, + namespace: cfg.Namespace, + cpID: cfg.CPID, + cpInstanceID: cfg.CPInstanceID, + workerImage: cfg.WorkerImage, + workerPort: cfg.WorkerPort, + secretName: cfg.SecretName, + configMap: cfg.ConfigMap, + configPath: cfg.ConfigPath, + imagePullPolicy: corev1.PullPolicy(cfg.ImagePullPolicy), + serviceAccount: cfg.ServiceAccount, + workerCPURequest: cfg.WorkerCPURequest, + workerMemoryRequest: cfg.WorkerMemoryRequest, workerNodeSelector: cfg.WorkerNodeSelector, workerTolerationKey: cfg.WorkerTolerationKey, workerTolerationValue: cfg.WorkerTolerationValue, - workerExclusiveNode: cfg.WorkerExclusiveNode, - orgID: cfg.OrgID, - workerIDGenerator: cfg.WorkerIDGenerator, - spawnSem: make(chan struct{}, spawnConcurrency), + workerExclusiveNode: cfg.WorkerExclusiveNode, + orgID: cfg.OrgID, + workerIDGenerator: cfg.WorkerIDGenerator, + runtimeStore: cfg.RuntimeStore, + spawnSem: make(chan struct{}, spawnConcurrency), } // Resolve CP pod UID for owner references if err := pool.resolveCPUID(context.Background()); err != nil { slog.Warn("Could not resolve CP pod UID for owner references. Worker pods will not be garbage-collected if CP is deleted.", "error", err) } + if pool.cpInstanceID == "" { + pool.cpInstanceID = pool.cpID + } // Ensure bearer token secret exists if err := pool.ensureBearerTokenSecret(context.Background()); err != nil { @@ -174,11 +186,11 @@ func (p *K8sWorkerPool) resolveCPUID(ctx context.Context) error { } // ensureBearerTokenSecret ensures the bearer token K8s Secret exists. -// If no secret name is configured, it generates one named "duckgres-worker-token-". +// If no secret name is configured, it uses the shared default "duckgres-worker-token". // If the secret doesn't exist, it creates one with a random 32-byte hex token. 
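// NOTE (illustrative sketch, not part of this diff): "random 32-byte hex
// token" means 32 bytes from crypto/rand, hex-encoded into a 64-character
// string. A minimal self-contained illustration:
//
//	buf := make([]byte, 32)
//	if _, err := rand.Read(buf); err != nil { // crypto/rand, never math/rand, for tokens
//		return fmt.Errorf("generate bearer token: %w", err)
//	}
//	token := hex.EncodeToString(buf) // 64 hex characters
//
// Switching from a per-CP secret name to the single shared name also appears
// to be what enables cross-replica worker adoption below: every control-plane
// replica can present the same bearer token to pods spawned by its peers.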
func (p *K8sWorkerPool) ensureBearerTokenSecret(ctx context.Context) error { if p.secretName == "" { - p.secretName = "duckgres-worker-token-" + p.cpID + p.secretName = "duckgres-worker-token" } existing, err := p.clientset.CoreV1().Secrets(p.namespace).Get(ctx, p.secretName, metav1.GetOptions{}) @@ -365,24 +377,13 @@ func (p *K8sWorkerPool) SpawnWorker(ctx context.Context, id int) error { podName := p.podNameForWorker(id) - // Build owner references for GC on CP deletion - var ownerRefs []metav1.OwnerReference - if p.cpUID != "" { - ownerRefs = []metav1.OwnerReference{ - { - APIVersion: "v1", - Kind: "Pod", - Name: p.cpID, - UID: p.cpUID, - }, - } - } - // Build pod labels podLabels := map[string]string{ - "app": "duckgres-worker", - "duckgres/control-plane": p.cpID, - "duckgres/worker-id": strconv.Itoa(id), + "app": "duckgres-worker", + "duckgres/control-plane": p.cpID, + "duckgres/cp-instance-id": controlPlaneIDLabelValue(p.cpInstanceID), + "duckgres/worker-id": strconv.Itoa(id), + "duckgres/owner-epoch": "0", } if p.orgID != "" { podLabels["duckgres/org"] = p.orgID @@ -391,10 +392,9 @@ func (p *K8sWorkerPool) SpawnWorker(ctx context.Context, id int) error { // Build pod spec pod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: podName, - Namespace: p.namespace, - Labels: podLabels, - OwnerReferences: ownerRefs, + Name: podName, + Namespace: p.namespace, + Labels: podLabels, }, Spec: corev1.PodSpec{ RestartPolicy: corev1.RestartPolicyNever, @@ -512,6 +512,8 @@ func (p *K8sWorkerPool) SpawnWorker(ctx context.Context, id int) error { GracePeriodSeconds: int64Ptr(0), }) + p.persistWorkerRecord(p.workerRecordFor(id, nil, 0, configstore.WorkerStateSpawning, "", nil)) + // Create pod with exponential backoff on transient errors. if err := p.createPodWithBackoff(ctx, pod); err != nil { return err @@ -539,6 +541,7 @@ func (p *K8sWorkerPool) SpawnWorker(ctx context.Context, id int) error { done := make(chan struct{}) w := &ManagedWorker{ ID: id, + podName: podName, bearerToken: token, client: client, done: done, @@ -549,12 +552,55 @@ func (p *K8sWorkerPool) SpawnWorker(ctx context.Context, id int) error { workerCount := len(p.workers) observeWarmPoolLifecycleGauges(p.workers) p.mu.Unlock() + p.persistWorkerRecord(p.workerRecordFor(id, w, w.OwnerEpoch(), configstore.WorkerStateIdle, "", nil)) observeControlPlaneWorkers(workerCount) slog.Info("K8s worker spawned.", "id", id, "pod", podName, "addr", addr) return nil } +func controlPlaneIDLabelValue(cpInstanceID string) string { + sanitized := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= 'A' && r <= 'Z': + return r + case r >= '0' && r <= '9': + return r + case r == '-' || r == '_' || r == '.': + return r + default: + return '-' + } + }, cpInstanceID) + + sanitized = strings.TrimFunc(sanitized, func(r rune) bool { + return !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) + }) + if sanitized == "" { + sum := sha1.Sum([]byte(cpInstanceID)) + return hex.EncodeToString(sum[:])[:12] + } + if len(sanitized) <= 63 { + return sanitized + } + + sum := sha1.Sum([]byte(cpInstanceID)) + suffix := hex.EncodeToString(sum[:])[:12] + prefixLen := 63 - len(suffix) - 1 + if prefixLen < 1 { + prefixLen = 1 + } + prefix := strings.TrimRightFunc(sanitized[:prefixLen], func(r rune) bool { + return !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) + }) + if prefix == "" { + return suffix + } + return prefix + "-" + suffix +} + // createPodWithBackoff creates a 
pod, retrying transient K8s API errors // with exponential backoff (500ms, 1s, 2s, 4s). func (p *K8sWorkerPool) createPodWithBackoff(ctx context.Context, pod *corev1.Pod) error { @@ -700,7 +746,7 @@ func (p *K8sWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, erro w.peakSessions = w.activeSessions } if canSpawn { - id := p.allocateWorkerIDLocked() + id := p.allocateBackgroundSpawnIDLocked() p.spawning++ p.mu.Unlock() slog.Debug("Assigned to least-loaded worker, spawning new worker in background.", @@ -717,7 +763,7 @@ func (p *K8sWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, erro // 3. No live workers at all (cold start or all dead) — must block on spawn if canSpawn { - id := p.allocateWorkerIDLocked() + id := p.allocateBackgroundSpawnIDLocked() p.spawning++ p.mu.Unlock() @@ -761,7 +807,7 @@ func (p *K8sWorkerPool) spawnWorkerBackground(id int) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() - err := p.SpawnWorker(ctx, id) + err := p.spawnWarmWorker(ctx, id) p.mu.Lock() p.spawning-- @@ -775,14 +821,16 @@ func (p *K8sWorkerPool) spawnWorkerBackground(id int) { // ReleaseWorker decrements the active session count for a worker. func (p *K8sWorkerPool) ReleaseWorker(id int) { p.mu.Lock() - defer p.mu.Unlock() w, ok := p.workers[id] - if ok { - if w.activeSessions > 0 { - w.activeSessions-- - } - w.lastUsed = time.Now() + if !ok { + p.mu.Unlock() + return + } + if w.activeSessions > 0 { + w.activeSessions-- } + w.lastUsed = time.Now() + p.mu.Unlock() } // RetireWorker removes a worker from the pool and deletes its pod. @@ -847,15 +895,26 @@ func (p *K8sWorkerPool) Worker(id int) (*ManagedWorker, bool) { func (p *K8sWorkerPool) ActivateReservedWorker(ctx context.Context, worker *ManagedWorker, payload TenantActivationPayload) error { p.mu.Lock() var err error + var prevState SharedWorkerState + hadPrevState := false + var activatingRecord *configstore.WorkerRecord switch worker.SharedState().NormalizedLifecycle() { case WorkerLifecycleReserved: + prevState = worker.SharedState() + hadPrevState = true nextState, transitionErr := worker.SharedState().Transition(WorkerLifecycleActivating, nil) if transitionErr == nil { transitionErr = worker.SetSharedState(nextState) } + if transitionErr == nil { + now := time.Now() + activatingRecord = p.workerRecordFor(worker.ID, worker, worker.OwnerEpoch(), configstore.WorkerStateActivating, "", &now) + } err = transitionErr case WorkerLifecycleActivating: err = nil + now := time.Now() + activatingRecord = p.workerRecordFor(worker.ID, worker, worker.OwnerEpoch(), configstore.WorkerStateActivating, "", &now) default: err = fmt.Errorf("worker %d is not reserved for activation", worker.ID) } @@ -863,36 +922,51 @@ func (p *K8sWorkerPool) ActivateReservedWorker(ctx context.Context, worker *Mana if err != nil { return err } + p.persistWorkerRecord(activatingRecord) activate := p.activateTenantFunc if activate == nil { activate = func(ctx context.Context, worker *ManagedWorker, payload TenantActivationPayload) error { return worker.ActivateTenant(ctx, server.WorkerActivationPayload{ - OrgID: payload.OrgID, - LeaseExpiresAt: payload.LeaseExpiresAt, - DuckLake: payload.DuckLake, + WorkerControlMetadata: server.WorkerControlMetadata{ + WorkerID: worker.ID, + OwnerEpoch: worker.OwnerEpoch(), + CPInstanceID: worker.OwnerCPInstanceID(), + }, + OrgID: payload.OrgID, + DuckLake: payload.DuckLake, }) } } if err := activate(ctx, worker, payload); err != nil { + if hadPrevState { + p.mu.Lock() + _ = 
worker.SetSharedState(prevState) + p.mu.Unlock() + } p.retireWorkerWithReason(worker.ID, RetireReasonActivationFailure) return err } p.mu.Lock() - defer p.mu.Unlock() if worker.SharedState().NormalizedLifecycle() == WorkerLifecycleHot { + p.mu.Unlock() return nil } nextState, err := worker.SharedState().Transition(WorkerLifecycleHot, nil) if err != nil { + p.mu.Unlock() return err } if setErr := worker.SetSharedState(nextState); setErr != nil { + p.mu.Unlock() return setErr } + hotRecord := p.workerRecordFor(worker.ID, worker, worker.OwnerEpoch(), configstore.WorkerStateHot, "", nil) observeWarmPoolLifecycleGauges(p.workers) + p.mu.Unlock() + p.persistWorkerRecord(hotRecord) return nil } @@ -909,6 +983,22 @@ func (p *K8sWorkerPool) ReserveSharedWorker(ctx context.Context, assignment *Wor default: } + if p.runtimeStore != nil { + claimed, err := p.runtimeStore.ClaimIdleWorker(p.cpInstanceID, assignment.OrgID, assignment.MaxWorkers) + if err != nil { + return nil, err + } + if claimed != nil { + worker, reserveErr := p.reserveClaimedWorker(ctx, claimed, assignment) + if reserveErr == nil { + return worker, nil + } + slog.Warn("Claimed idle worker could not be reserved, retiring claimed pod.", "worker_id", claimed.WorkerID, "pod", claimed.PodName, "error", reserveErr) + p.retireClaimedWorker(claimed, RetireReasonCrash) + continue + } + } + p.mu.Lock() if p.shuttingDown { p.mu.Unlock() @@ -928,15 +1018,19 @@ func (p *K8sWorkerPool) ReserveSharedWorker(ctx context.Context, assignment *Wor p.mu.Unlock() return nil, err } + idle.SetOwnerCPInstanceID(p.cpInstanceID) + idle.IncrementOwnerEpoch() idle.reservedAt = time.Now() + reservedRecord := p.workerRecordFor(idle.ID, idle, idle.OwnerEpoch(), configstore.WorkerStateReserved, "", nil) observeWarmPoolLifecycleGauges(p.workers) shouldReplenish := p.shouldReplenishWarmCapacityLocked() var replenishID int if shouldReplenish { - replenishID = p.allocateWorkerIDLocked() + replenishID = p.allocateBackgroundSpawnIDLocked() p.spawning++ } p.mu.Unlock() + p.persistWorkerRecord(reservedRecord) if err := p.checkReservedWorkerLiveness(ctx, idle); err != nil { slog.Warn("Reserved warm worker failed liveness recheck.", "worker", idle.ID, "error", err) @@ -957,6 +1051,47 @@ func (p *K8sWorkerPool) ReserveSharedWorker(ctx context.Context, assignment *Wor liveCount := p.liveWorkerCountLocked() if p.maxWorkers == 0 || liveCount < p.maxWorkers { + if p.runtimeStore != nil { + slot, err := p.runtimeStore.CreateSpawningWorkerSlot( + p.cpInstanceID, + assignment.OrgID, + 1, + p.workerPodNamePrefix(), + assignment.MaxWorkers, + p.maxWorkers, + ) + if err != nil { + p.mu.Unlock() + return nil, err + } + if slot == nil { + p.mu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + } + continue + } + p.spawning++ + p.mu.Unlock() + + err = p.spawnWarmWorker(ctx, slot.WorkerID) + + p.mu.Lock() + p.spawning-- + if err == nil { + observeWarmPoolLifecycleGauges(p.workers) + } + p.mu.Unlock() + + if err != nil { + p.retireClaimedWorker(slot, RetireReasonCrash) + return nil, err + } + return p.reserveClaimedWorker(ctx, slot, assignment) + } + id := p.allocateWorkerIDLocked() p.spawning++ p.mu.Unlock() @@ -985,6 +1120,146 @@ func (p *K8sWorkerPool) ReserveSharedWorker(ctx context.Context, assignment *Wor } } +func (p *K8sWorkerPool) reserveClaimedWorker(ctx context.Context, claimed *configstore.WorkerRecord, assignment *WorkerAssignment) (*ManagedWorker, error) { + p.mu.Lock() + if p.shuttingDown { + p.mu.Unlock() + return 
nil, fmt.Errorf("pool is shutting down") + } + p.cleanDeadWorkersLocked() + worker, ok := p.workers[claimed.WorkerID] + p.mu.Unlock() + + if !ok { + adopted, err := p.adoptClaimedWorker(ctx, claimed) + if err != nil { + return nil, err + } + p.mu.Lock() + if existing, exists := p.workers[claimed.WorkerID]; exists { + p.mu.Unlock() + if adopted.client != nil { + _ = adopted.client.Close() + } + worker = existing + } else { + p.workers[claimed.WorkerID] = adopted + p.mu.Unlock() + worker = adopted + } + } + + p.mu.Lock() + var reservedRecord *configstore.WorkerRecord + if p.shuttingDown { + p.mu.Unlock() + return nil, fmt.Errorf("pool is shutting down") + } + if claimed.PodName != "" { + worker.podName = claimed.PodName + } + worker.SetOwnerCPInstanceID(claimed.OwnerCPInstanceID) + worker.SetOwnerEpoch(claimed.OwnerEpoch) + nextState, err := worker.SharedState().Transition(WorkerLifecycleReserved, assignment) + if err != nil { + p.mu.Unlock() + return nil, err + } + if err := worker.SetSharedState(nextState); err != nil { + p.mu.Unlock() + return nil, err + } + worker.reservedAt = time.Now() + observeWarmPoolLifecycleGauges(p.workers) + if claimed.State != configstore.WorkerStateReserved { + reservedRecord = p.workerRecordFor(worker.ID, worker, worker.OwnerEpoch(), configstore.WorkerStateReserved, "", nil) + } + p.mu.Unlock() + if reservedRecord != nil { + p.persistWorkerRecord(reservedRecord) + } + if err := p.checkReservedWorkerLiveness(ctx, worker); err != nil { + slog.Warn("Claimed worker failed liveness recheck.", "worker", worker.ID, "pod", worker.PodName(), "error", err) + p.retireWorkerWithReason(worker.ID, RetireReasonCrash) + return nil, err + } + return worker, nil +} + +func (p *K8sWorkerPool) claimSpecificWorker(ctx context.Context, workerID int, expectedOwnerEpoch int64, assignment *WorkerAssignment) (*ManagedWorker, error) { + if p.runtimeStore == nil { + return nil, fmt.Errorf("runtime worker store is not configured") + } + if err := validateWorkerAssignment(assignment); err != nil { + return nil, err + } + record, err := p.runtimeStore.TakeOverWorker(workerID, p.cpInstanceID, assignment.OrgID, expectedOwnerEpoch) + if err != nil { + if stderrors.Is(err, configstore.ErrWorkerOwnerEpochMismatch) { + return nil, fmt.Errorf("worker %d ownership changed before takeover: %w", workerID, err) + } + return nil, err + } + if record == nil { + return nil, fmt.Errorf("worker %d could not be claimed", workerID) + } + return p.reserveClaimedWorker(ctx, record, assignment) +} + +func (p *K8sWorkerPool) adoptClaimedWorker(ctx context.Context, claimed *configstore.WorkerRecord) (*ManagedWorker, error) { + token, err := p.readBearerToken(ctx) + if err != nil { + return nil, fmt.Errorf("read bearer token: %w", err) + } + pod, err := p.clientset.CoreV1().Pods(p.namespace).Get(ctx, claimed.PodName, metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("get claimed worker pod %s: %w", claimed.PodName, err) + } + client, err := p.connectWorker(ctx, claimed.PodName, pod.Status.PodIP, token) + if err != nil { + return nil, err + } + worker := &ManagedWorker{ + ID: claimed.WorkerID, + podName: claimed.PodName, + bearerToken: token, + client: client, + done: make(chan struct{}), + } + worker.SetOwnerCPInstanceID(claimed.OwnerCPInstanceID) + worker.SetOwnerEpoch(claimed.OwnerEpoch) + return worker, nil +} + +func (p *K8sWorkerPool) connectWorker(ctx context.Context, podName, podIP, bearerToken string) (*flightsql.Client, error) { + if p.connectWorkerFunc != nil { + return 
p.connectWorkerFunc(ctx, podName, podIP, bearerToken) + } + if podIP == "" { + return nil, fmt.Errorf("worker pod %s has no IP", podName) + } + addr := fmt.Sprintf("%s:%d", podIP, p.workerPort) + client, err := waitForWorkerTCP(addr, bearerToken, 30*time.Second) + if err != nil { + return nil, fmt.Errorf("connect to claimed worker %s: %w", podName, err) + } + return client, nil +} + +func (p *K8sWorkerPool) retireClaimedWorker(claimed *configstore.WorkerRecord, reason string) { + worker := &ManagedWorker{ + ID: claimed.WorkerID, + podName: claimed.PodName, + done: make(chan struct{}), + } + worker.SetOwnerCPInstanceID(claimed.OwnerCPInstanceID) + worker.SetOwnerEpoch(claimed.OwnerEpoch) + p.mu.Lock() + p.markWorkerRetiredLocked(worker, reason) + p.mu.Unlock() + go p.retireWorkerPod(worker.ID, worker) +} + func (p *K8sWorkerPool) checkReservedWorkerLiveness(ctx context.Context, worker *ManagedWorker) error { check := p.healthCheckFunc if check == nil { @@ -994,7 +1269,7 @@ func (p *K8sWorkerPool) checkReservedWorkerLiveness(ctx context.Context, worker } hctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - _, err := doHealthCheck(hctx, worker.client) + _, err := doHealthCheckWithMetadata(hctx, worker.client, p.healthCheckPayloadForWorker(worker)) return err } } @@ -1007,6 +1282,49 @@ func (p *K8sWorkerPool) SpawnMinWorkers(count int) error { return nil } + if p.runtimeStore != nil { + p.mu.Lock() + if count > p.minWorkers { + p.minWorkers = count + } + p.mu.Unlock() + + ctx := context.Background() + for i := 0; i < count; i++ { + slot, err := p.runtimeStore.CreateNeutralWarmWorkerSlot( + p.cpInstanceID, + p.workerPodNamePrefix(), + count, + p.maxWorkers, + ) + if err != nil { + return err + } + if slot == nil { + break + } + + p.mu.Lock() + p.spawning++ + p.mu.Unlock() + + err = p.spawnWarmWorker(ctx, slot.WorkerID) + + p.mu.Lock() + p.spawning-- + if err == nil { + observeWarmPoolLifecycleGauges(p.workers) + } + p.mu.Unlock() + + if err != nil { + p.retireClaimedWorker(slot, RetireReasonCrash) + return err + } + } + return nil + } + p.mu.Lock() if count > p.minWorkers { p.minWorkers = count @@ -1022,25 +1340,25 @@ func (p *K8sWorkerPool) SpawnMinWorkers(count int) error { ids := make([]int, 0, missing) for i := 0; i < missing; i++ { - ids = append(ids, p.allocateWorkerIDLocked()) + ids = append(ids, p.allocateBackgroundSpawnIDLocked()) p.spawning++ } p.mu.Unlock() ctx := context.Background() - for _, id := range ids { - if err := p.spawnWarmWorker(ctx, id); err != nil { - p.mu.Lock() - p.spawning-- - p.mu.Unlock() - return err - } + for _, id := range ids { + if err := p.spawnWarmWorker(ctx, id); err != nil { p.mu.Lock() p.spawning-- - observeWarmPoolLifecycleGauges(p.workers) p.mu.Unlock() + return err } + p.mu.Lock() + p.spawning-- + observeWarmPoolLifecycleGauges(p.workers) + p.mu.Unlock() + } return nil } @@ -1114,7 +1432,7 @@ func (p *K8sWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Durat func() { defer recoverWorkerPanic(&healthErr) hctx, cancel := context.WithTimeout(ctx, 3*time.Second) - hcResult, healthErr = doHealthCheck(hctx, w.client) + hcResult, healthErr = doHealthCheckWithMetadata(hctx, w.client, p.healthCheckPayloadForWorker(w)) cancel() }() @@ -1195,7 +1513,7 @@ func (p *K8sWorkerPool) ShutdownAll() { ctx := context.Background() for _, w := range workers { - podName := p.podNameForWorker(w.ID) + podName := p.workerPodName(w) gracePeriod := int64(10) slog.Info("Shutting down K8s worker.", "id", w.ID, "pod", podName) _ = 
p.clientset.CoreV1().Pods(p.namespace).Delete(ctx, podName, metav1.DeleteOptions{ @@ -1218,7 +1536,7 @@ func (p *K8sWorkerPool) retireWorkerPod(id int, w *ManagedWorker) { if w.client != nil { _ = w.client.Close() } - podName := p.podNameForWorker(id) + podName := p.workerPodName(w) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _ = p.clientset.CoreV1().Pods(p.namespace).Delete(ctx, podName, metav1.DeleteOptions{ @@ -1319,7 +1637,7 @@ func (p *K8sWorkerPool) reapStuckActivatingWorkers() { var spawnIDs []int for range toRetire { if p.shouldReplenishWarmCapacityLocked() { - id := p.allocateWorkerIDLocked() + id := p.allocateBackgroundSpawnIDLocked() p.spawning++ spawnIDs = append(spawnIDs, id) } @@ -1404,14 +1722,14 @@ func (p *K8sWorkerPool) cleanDeadWorkersLocked() { go func(c *flightsql.Client) { _ = c.Close() }(w.client) } // Delete the failed pod from K8s to avoid accumulating terminated pods - go func(workerID int) { - podName := p.podNameForWorker(workerID) + go func(worker *ManagedWorker) { + podName := p.workerPodName(worker) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() _ = p.clientset.CoreV1().Pods(p.namespace).Delete(ctx, podName, metav1.DeleteOptions{ GracePeriodSeconds: int64Ptr(0), }) - }(id) + }(w) default: } } @@ -1454,6 +1772,13 @@ func (p *K8sWorkerPool) allocateWorkerIDLocked() int { return id } +func (p *K8sWorkerPool) allocateBackgroundSpawnIDLocked() int { + if p.runtimeStore != nil { + return 0 + } + return p.allocateWorkerIDLocked() +} + func (p *K8sWorkerPool) removeWorkerLocked(id int) (*ManagedWorker, int, int, bool) { w, ok := p.workers[id] if !ok { @@ -1465,7 +1790,7 @@ func (p *K8sWorkerPool) removeWorkerLocked(id int) (*ManagedWorker, int, int, bo if !p.shouldReplenishWarmCapacityLocked() { return w, workerCount, 0, false } - replacementID := p.allocateWorkerIDLocked() + replacementID := p.allocateBackgroundSpawnIDLocked() p.spawning++ return w, workerCount, replacementID, true } @@ -1508,6 +1833,9 @@ func (p *K8sWorkerPool) idleWarmWorkerCountLocked() int { } func (p *K8sWorkerPool) shouldReplenishWarmCapacityLocked() bool { + if p.runtimeStore != nil { + return false + } if p.minWorkers <= 0 { return false } @@ -1519,6 +1847,21 @@ func (p *K8sWorkerPool) shouldReplenishWarmCapacityLocked() bool { } func (p *K8sWorkerPool) spawnWarmWorker(ctx context.Context, id int) error { + if id <= 0 && p.runtimeStore != nil { + slot, err := p.runtimeStore.CreateNeutralWarmWorkerSlot( + p.cpInstanceID, + p.workerPodNamePrefix(), + p.minWorkers, + p.maxWorkers, + ) + if err != nil { + return err + } + if slot == nil { + return nil + } + id = slot.WorkerID + } if p.spawnWarmWorkerFunc != nil { return p.spawnWarmWorkerFunc(ctx, id) } @@ -1542,10 +1885,67 @@ func (p *K8sWorkerPool) markWorkerRetiredLocked(w *ManagedWorker, reason string) return } _ = w.SetSharedState(nextState) + workerState := configstore.WorkerStateRetired + if reason == RetireReasonCrash { + workerState = configstore.WorkerStateLost + } + p.persistWorkerRecord(p.workerRecordFor(w.ID, w, w.OwnerEpoch(), workerState, reason, nil)) observeWorkerRetirement(reason) observeWarmPoolLifecycleGauges(p.workers) } +func (p *K8sWorkerPool) persistWorkerRecord(record *configstore.WorkerRecord) { + if p.runtimeStore == nil || record == nil { + return + } + if err := p.runtimeStore.UpsertWorkerRecord(record); err != nil { + slog.Warn("Persisting worker runtime record failed.", "worker_id", record.WorkerID, "state", record.State, "error", 
err) + } +} + +func (p *K8sWorkerPool) workerRecordFor(id int, worker *ManagedWorker, ownerEpoch int64, state configstore.WorkerState, retireReason string, activationStartedAt *time.Time) *configstore.WorkerRecord { + record := &configstore.WorkerRecord{ + WorkerID: id, + PodName: p.podNameForWorker(id), + State: state, + OwnerCPInstanceID: p.cpInstanceID, + OwnerEpoch: ownerEpoch, + LastHeartbeatAt: time.Now(), + RetireReason: retireReason, + } + if activationStartedAt != nil { + startedAt := *activationStartedAt + record.ActivationStartedAt = &startedAt + } + if worker == nil { + if state == configstore.WorkerStateIdle { + record.OwnerCPInstanceID = "" + } + return record + } + record.PodName = p.workerPodName(worker) + record.OwnerCPInstanceID = worker.OwnerCPInstanceID() + if assignment := worker.SharedState().Assignment; assignment != nil { + record.OrgID = assignment.OrgID + } + if state == configstore.WorkerStateIdle { + record.OwnerCPInstanceID = "" + record.OrgID = "" + } + return record +} + +func (p *K8sWorkerPool) healthCheckPayloadForWorker(worker *ManagedWorker) server.WorkerHealthCheckPayload { + payload := server.WorkerHealthCheckPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + WorkerID: worker.ID, + OwnerEpoch: worker.OwnerEpoch(), + CPInstanceID: worker.OwnerCPInstanceID(), + }, + } + return payload +} + // podNameForWorker returns the pod name for a given worker ID, // including the org ID if set (multi-tenant mode). func (p *K8sWorkerPool) podNameForWorker(id int) string { @@ -1555,6 +1955,23 @@ func (p *K8sWorkerPool) podNameForWorker(id int) string { return fmt.Sprintf("duckgres-worker-%s-%d", p.cpID, id) } +func (p *K8sWorkerPool) workerPodName(worker *ManagedWorker) string { + if worker != nil && worker.PodName() != "" { + return worker.PodName() + } + if worker == nil { + return "" + } + return p.podNameForWorker(worker.ID) +} + +func (p *K8sWorkerPool) workerPodNamePrefix() string { + if p.orgID != "" { + return fmt.Sprintf("duckgres-worker-%s-%s", p.cpID, p.orgID) + } + return fmt.Sprintf("duckgres-worker-%s", p.cpID) +} + // SetMaxWorkers updates the maximum number of workers. 0 means unlimited. 
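// NOTE (illustrative sketch, not part of this diff): workerRecordFor above
// deliberately blanks ownership on idle records, which is what lets any
// control-plane replica later claim the worker via ClaimIdleWorker. With
// hypothetical values:
//
//	hot := configstore.WorkerRecord{WorkerID: 7, State: configstore.WorkerStateHot,
//		OwnerCPInstanceID: "cp-uid:boot-a", OrgID: "analytics"}
//	idle := configstore.WorkerRecord{WorkerID: 7, State: configstore.WorkerStateIdle,
//		OwnerCPInstanceID: "", OrgID: ""} // cleared: neutral, claimable by any replica
//	_, _ = hot, idle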
func (p *K8sWorkerPool) SetMaxWorkers(n int) { p.mu.Lock() @@ -1573,6 +1990,12 @@ func (p *K8sWorkerPool) SetWarmCapacityTarget(n int) { p.minWorkers = n } +func (p *K8sWorkerPool) WarmCapacityTarget() int { + p.mu.RLock() + defer p.mu.RUnlock() + return p.minWorkers +} + func boolPtr(b bool) *bool { return &b } func int64Ptr(i int64) *int64 { return &i } diff --git a/controlplane/k8s_pool_test.go b/controlplane/k8s_pool_test.go index ccf386f2..cd89a890 100644 --- a/controlplane/k8s_pool_test.go +++ b/controlplane/k8s_pool_test.go @@ -4,11 +4,15 @@ package controlplane import ( "context" + "errors" + "regexp" "strconv" "sync" "testing" "time" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + "github.com/posthog/duckgres/controlplane/configstore" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -16,6 +20,144 @@ import ( k8stesting "k8s.io/client-go/testing" ) +type captureRuntimeWorkerStore struct { + mu sync.Mutex + records []configstore.WorkerRecord + claimed *configstore.WorkerRecord + claimErr error + claimCalls int + claimOwnerCPID string + claimOrgID string + claimMaxOrgWorkers int + spawned *configstore.WorkerRecord + spawnErr error + spawnCalls int + spawnOwnerCPID string + spawnOrgID string + spawnOwnerEpoch int64 + spawnPodNamePrefix string + spawnMaxOrgWorkers int + spawnMaxGlobalWorks int + neutralSpawned *configstore.WorkerRecord + neutralSpawnErr error + neutralSpawnCalls int + neutralSpawnOwnerCPID string + neutralSpawnPodPrefix string + neutralSpawnTarget int + neutralSpawnMaxGlobal int + takenOver *configstore.WorkerRecord + takeOverErr error + takeOverWorkerID int + takeOverOwnerCPID string + takeOverOrgID string + takeOverExpectedEpoch int64 +} + +func (s *captureRuntimeWorkerStore) UpsertWorkerRecord(record *configstore.WorkerRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + s.records = append(s.records, *record) + return nil +} + +func (s *captureRuntimeWorkerStore) snapshot() []configstore.WorkerRecord { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]configstore.WorkerRecord, len(s.records)) + copy(out, s.records) + return out +} + +func (s *captureRuntimeWorkerStore) ClaimIdleWorker(ownerCPInstanceID, orgID string, maxOrgWorkers int) (*configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.claimCalls++ + s.claimOwnerCPID = ownerCPInstanceID + s.claimOrgID = orgID + s.claimMaxOrgWorkers = maxOrgWorkers + if s.claimErr != nil { + return nil, s.claimErr + } + if s.claimed == nil { + return nil, nil + } + claimed := *s.claimed + return &claimed, nil +} + +func (s *captureRuntimeWorkerStore) CreateSpawningWorkerSlot(ownerCPInstanceID, orgID string, ownerEpoch int64, podNamePrefix string, maxOrgWorkers, maxGlobalWorkers int) (*configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.spawnCalls++ + s.spawnOwnerCPID = ownerCPInstanceID + s.spawnOrgID = orgID + s.spawnOwnerEpoch = ownerEpoch + s.spawnPodNamePrefix = podNamePrefix + s.spawnMaxOrgWorkers = maxOrgWorkers + s.spawnMaxGlobalWorks = maxGlobalWorkers + if s.spawnErr != nil { + return nil, s.spawnErr + } + if s.spawned == nil { + return nil, nil + } + spawned := *s.spawned + return &spawned, nil +} + +func (s *captureRuntimeWorkerStore) CreateNeutralWarmWorkerSlot(ownerCPInstanceID, podNamePrefix string, targetWarmWorkers, maxGlobalWorkers int) (*configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.neutralSpawnCalls++ + s.neutralSpawnOwnerCPID = ownerCPInstanceID + 
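// As in the other stubs above, every argument is recorded verbatim under the
// mutex so test bodies can assert on the slot parameters without racing
// background spawns.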
s.neutralSpawnPodPrefix = podNamePrefix + s.neutralSpawnTarget = targetWarmWorkers + s.neutralSpawnMaxGlobal = maxGlobalWorkers + if s.neutralSpawnErr != nil { + return nil, s.neutralSpawnErr + } + if s.neutralSpawned == nil { + return nil, nil + } + spawned := *s.neutralSpawned + return &spawned, nil +} + +func (s *captureRuntimeWorkerStore) GetWorkerRecord(workerID int) (*configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.claimed != nil && s.claimed.WorkerID == workerID { + record := *s.claimed + return &record, nil + } + if s.spawned != nil && s.spawned.WorkerID == workerID { + record := *s.spawned + return &record, nil + } + if s.takenOver != nil && s.takenOver.WorkerID == workerID { + record := *s.takenOver + return &record, nil + } + return nil, nil +} + +func (s *captureRuntimeWorkerStore) TakeOverWorker(workerID int, ownerCPInstanceID, orgID string, expectedOwnerEpoch int64) (*configstore.WorkerRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.takeOverWorkerID = workerID + s.takeOverOwnerCPID = ownerCPInstanceID + s.takeOverOrgID = orgID + s.takeOverExpectedEpoch = expectedOwnerEpoch + if s.takeOverErr != nil { + return nil, s.takeOverErr + } + if s.takenOver == nil { + return nil, nil + } + record := *s.takenOver + return &record, nil +} + func newTestK8sPool(t *testing.T, maxWorkers int) (*K8sWorkerPool, *fake.Clientset) { t.Helper() cs := fake.NewSimpleClientset() @@ -34,19 +176,20 @@ func newTestK8sPool(t *testing.T, maxWorkers int) (*K8sWorkerPool, *fake.Clients } pool := &K8sWorkerPool{ - workers: make(map[int]*ManagedWorker), - maxWorkers: maxWorkers, - idleTimeout: 5 * time.Minute, - shutdownCh: make(chan struct{}), - stopInform: make(chan struct{}), - clientset: cs, - namespace: "default", - cpID: "test-cp", - cpUID: "cp-uid-123", - workerImage: "duckgres:test", - workerPort: 8816, - secretName: "test-secret", - spawnSem: make(chan struct{}, 1), + workers: make(map[int]*ManagedWorker), + maxWorkers: maxWorkers, + idleTimeout: 5 * time.Minute, + shutdownCh: make(chan struct{}), + stopInform: make(chan struct{}), + clientset: cs, + namespace: "default", + cpID: "test-cp", + cpInstanceID: "cp-uid-123:boot-abc", + cpUID: "cp-uid-123", + workerImage: "duckgres:test", + workerPort: 8816, + secretName: "test-secret", + spawnSem: make(chan struct{}, 1), } return pool, cs @@ -71,6 +214,26 @@ func TestK8sPool_EnsureBearerTokenSecret_CreatesNew(t *testing.T) { } } +func TestK8sPool_EnsureBearerTokenSecret_DefaultsToSharedSecretName(t *testing.T) { + pool, cs := newTestK8sPool(t, 5) + pool.secretName = "" + + if err := pool.ensureBearerTokenSecret(context.Background()); err != nil { + t.Fatalf("ensureBearerTokenSecret failed: %v", err) + } + if pool.secretName != "duckgres-worker-token" { + t.Fatalf("expected shared default secret name duckgres-worker-token, got %q", pool.secretName) + } + + secret, err := cs.CoreV1().Secrets("default").Get(context.Background(), "duckgres-worker-token", metav1.GetOptions{}) + if err != nil { + t.Fatalf("shared default secret not found: %v", err) + } + if _, ok := secret.Data["bearer-token"]; !ok { + t.Fatal("shared default secret missing bearer-token key") + } +} + func TestK8sPool_EnsureBearerTokenSecret_ExistingIsPreserved(t *testing.T) { pool, cs := newTestK8sPool(t, 5) @@ -210,8 +373,7 @@ func TestK8sPoolActivateReservedWorkerTransitionsToHot(t *testing.T) { if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: 
time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState: %v", err) @@ -246,8 +408,7 @@ func TestK8sPoolActivateReservedWorkerRetiresOnFailure(t *testing.T) { if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState: %v", err) @@ -271,6 +432,45 @@ func TestK8sPoolActivateReservedWorkerRetiresOnFailure(t *testing.T) { } } +func TestK8sPoolReserveClaimedWorkerUnlocksPoolOnTransitionError(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + worker := &ManagedWorker{ID: 9, done: make(chan struct{})} + if err := worker.SetSharedState(SharedWorkerState{ + Lifecycle: WorkerLifecycleHot, + Assignment: &WorkerAssignment{ + OrgID: "analytics", + }, + }); err != nil { + t.Fatalf("SetSharedState: %v", err) + } + pool.workers[worker.ID] = worker + + _, err := pool.reserveClaimedWorker(context.Background(), &configstore.WorkerRecord{ + WorkerID: worker.ID, + OwnerCPInstanceID: "cp-2:boot-b", + OwnerEpoch: 7, + State: configstore.WorkerStateReserved, + }, &WorkerAssignment{ + OrgID: "billing", + }) + if err == nil { + t.Fatal("expected transition error") + } + + locked := make(chan struct{}) + go func() { + pool.mu.Lock() + pool.mu.Unlock() + close(locked) + }() + + select { + case <-locked: + case <-time.After(time.Second): + t.Fatal("expected reserveClaimedWorker to unlock pool mutex on error") + } +} + func TestK8sPoolReserveSharedWorkerSkipsUnhealthyIdleWorker(t *testing.T) { pool, _ := newTestK8sPool(t, 2) stale := &ManagedWorker{ID: 1, done: make(chan struct{})} @@ -300,8 +500,7 @@ func TestK8sPoolReserveSharedWorkerSkipsUnhealthyIdleWorker(t *testing.T) { } got, err := pool.ReserveSharedWorker(context.Background(), &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }) if err != nil { t.Fatalf("ReserveSharedWorker: %v", err) @@ -421,8 +620,7 @@ func TestK8sPoolSpawnMinWorkersCountsOnlyNeutralIdleWorkersAsWarmCapacity(t *tes if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState(reserved %d): %v", id, err) @@ -455,8 +653,7 @@ func TestK8sPoolFindIdleWorkerSkipsReservedSharedWorker(t *testing.T) { if err := reserved.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState(reserved): %v", err) @@ -472,9 +669,11 @@ func TestK8sPoolFindIdleWorkerSkipsReservedSharedWorker(t *testing.T) { } } -func TestK8sPoolReserveSharedWorkerReservesIdleWorkerAndReplenishesWarmCapacity(t *testing.T) { +func TestK8sPoolReserveSharedWorkerReservesIdleWorkerWithoutLocalReplenishmentInRuntimeMode(t *testing.T) { pool, _ := newTestK8sPool(t, 5) pool.minWorkers = 1 + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store idle := &ManagedWorker{ID: 7, done: make(chan struct{})} pool.workers[idle.ID] = idle pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { @@ -494,13 +693,11 @@ func TestK8sPoolReserveSharedWorkerReservesIdleWorkerAndReplenishesWarmCapacity( pool.mu.Unlock() } - leaseExpiry := time.Date(2026, time.March, 20, 16, 0, 0, 0, 
time.UTC) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() worker, err := pool.ReserveSharedWorker(ctx, &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: leaseExpiry, + OrgID: "analytics", }) if err != nil { t.Fatalf("ReserveSharedWorker: %v", err) @@ -516,14 +713,516 @@ func TestK8sPoolReserveSharedWorkerReservesIdleWorkerAndReplenishesWarmCapacity( if state.Assignment == nil || state.Assignment.OrgID != "analytics" { t.Fatalf("expected analytics assignment, got %#v", state.Assignment) } - if !state.Assignment.LeaseExpiresAt.Equal(leaseExpiry) { - t.Fatalf("expected lease expiry %v, got %v", leaseExpiry, state.Assignment.LeaseExpiresAt) + if worker.ownerEpoch != 1 { + t.Fatalf("expected owner epoch 1 after reservation, got %d", worker.ownerEpoch) + } + + records := store.snapshot() + if len(records) == 0 { + t.Fatal("expected reservation to persist a worker record") + } + last := records[len(records)-1] + if last.WorkerID != worker.ID { + t.Fatalf("expected worker_id %d, got %d", worker.ID, last.WorkerID) + } + if last.State != configstore.WorkerStateReserved { + t.Fatalf("expected reserved worker record, got %q", last.State) + } + if last.OwnerCPInstanceID != pool.cpInstanceID { + t.Fatalf("expected owner_cp_instance_id %q, got %q", pool.cpInstanceID, last.OwnerCPInstanceID) + } + if last.OwnerEpoch != 1 { + t.Fatalf("expected owner_epoch 1, got %d", last.OwnerEpoch) + } + if last.OrgID != "analytics" { + t.Fatalf("expected org_id analytics, got %q", last.OrgID) } select { - case <-replacementSpawned: - case <-time.After(time.Second): - t.Fatal("expected reserve to trigger warm-pool replenishment") + case id := <-replacementSpawned: + t.Fatalf("did not expect local warm-pool replenishment in runtime mode, got background spawn %d", id) + default: + } +} + +func TestK8sPoolReserveSharedWorkerClaimsRuntimeWorkerAndAdoptsPod(t *testing.T) { + pool, cs := newTestK8sPool(t, 5) + pool.minWorkers = 0 + store := &captureRuntimeWorkerStore{ + claimed: &configstore.WorkerRecord{ + WorkerID: 21, + PodName: "duckgres-worker-other-cp-21", + State: configstore.WorkerStateReserved, + OrgID: "analytics", + OwnerCPInstanceID: pool.cpInstanceID, + OwnerEpoch: 3, + }, + } + pool.runtimeStore = store + + _, err := cs.CoreV1().Pods("default").Create(context.Background(), &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "duckgres-worker-other-cp-21", + Namespace: "default", + Labels: map[string]string{ + "duckgres/control-plane": "other-cp", + "duckgres/worker-id": "21", + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + PodIP: "10.0.0.21", + }, + }, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("create adopted worker pod: %v", err) + } + + var connectedPodName string + var connectedPodIP string + pool.connectWorkerFunc = func(ctx context.Context, podName, podIP, bearerToken string) (*flightsql.Client, error) { + connectedPodName = podName + connectedPodIP = podIP + return nil, nil + } + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { + if worker == nil { + t.Fatal("expected claimed worker for liveness check") + } + if worker.ID != 21 { + t.Fatalf("expected claimed worker id 21, got %d", worker.ID) + } + return nil + } + pool.cachedToken = "shared-token" + + worker, err := pool.ReserveSharedWorker(context.Background(), &WorkerAssignment{ + OrgID: "analytics", + }) + if err != nil { + t.Fatalf("ReserveSharedWorker: %v", err) + } + if worker.ID != 21 { + t.Fatalf("expected claimed worker 21, got %d", 
worker.ID) + } + if worker.PodName() != "duckgres-worker-other-cp-21" { + t.Fatalf("expected tracked pod name duckgres-worker-other-cp-21, got %q", worker.PodName()) + } + if worker.OwnerEpoch() != 3 { + t.Fatalf("expected claimed owner epoch 3, got %d", worker.OwnerEpoch()) + } + if worker.OwnerCPInstanceID() != pool.cpInstanceID { + t.Fatalf("expected owner cp instance id %q, got %q", pool.cpInstanceID, worker.OwnerCPInstanceID()) + } + if connectedPodName != "duckgres-worker-other-cp-21" || connectedPodIP != "10.0.0.21" { + t.Fatalf("expected connection to claimed pod, got name=%q ip=%q", connectedPodName, connectedPodIP) + } + if store.claimCalls != 1 { + t.Fatalf("expected one claim call, got %d", store.claimCalls) + } + if store.claimOwnerCPID != pool.cpInstanceID { + t.Fatalf("expected claim owner cp instance id %q, got %q", pool.cpInstanceID, store.claimOwnerCPID) + } + if store.claimOrgID != "analytics" { + t.Fatalf("expected claim org analytics, got %q", store.claimOrgID) + } + if store.claimMaxOrgWorkers != 0 { + t.Fatalf("expected default max org workers 0, got %d", store.claimMaxOrgWorkers) + } + + state := worker.SharedState() + if state.Lifecycle != WorkerLifecycleReserved { + t.Fatalf("expected reserved lifecycle, got %q", state.Lifecycle) + } + if state.Assignment == nil || state.Assignment.OrgID != "analytics" { + t.Fatalf("expected analytics assignment, got %#v", state.Assignment) + } +} + +func TestK8sPoolReserveSharedWorkerFallsBackWhenRuntimeClaimReturnsNil(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store + idle := &ManagedWorker{ID: 8, done: make(chan struct{})} + pool.workers[idle.ID] = idle + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { + return nil + } + + worker, err := pool.ReserveSharedWorker(context.Background(), &WorkerAssignment{ + OrgID: "analytics", + }) + if err != nil { + t.Fatalf("ReserveSharedWorker: %v", err) + } + if worker.ID != idle.ID { + t.Fatalf("expected fallback idle worker %d, got %d", idle.ID, worker.ID) + } + if store.claimCalls != 1 { + t.Fatalf("expected one claim attempt before fallback, got %d", store.claimCalls) + } +} + +func TestK8sPoolReserveSharedWorkerPassesOrgCapToRuntimeClaim(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store + idle := &ManagedWorker{ID: 12, done: make(chan struct{})} + pool.workers[idle.ID] = idle + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { return nil } + + _, err := pool.ReserveSharedWorker(context.Background(), &WorkerAssignment{ + OrgID: "analytics", + MaxWorkers: 3, + }) + if err != nil { + t.Fatalf("ReserveSharedWorker: %v", err) + } + if store.claimMaxOrgWorkers != 3 { + t.Fatalf("expected claim max org workers 3, got %d", store.claimMaxOrgWorkers) + } +} + +func TestK8sPoolClaimSpecificWorkerTakesOverRuntimeWorker(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{ + takenOver: &configstore.WorkerRecord{ + WorkerID: 44, + PodName: "duckgres-worker-test-cp-44", + State: configstore.WorkerStateReserved, + OrgID: "analytics", + OwnerCPInstanceID: pool.cpInstanceID, + OwnerEpoch: 8, + }, + } + pool.runtimeStore = store + worker := &ManagedWorker{ID: 44, done: make(chan struct{})} + pool.workers[worker.ID] = worker + livenessChecked := false + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { + livenessChecked = true + return nil + } + + 
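// Take over worker 44 while asserting the caller last saw owner epoch 7; the
// capture store answers with the record at epoch 8, simulating the epoch bump
// a successful takeover performs.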
claimed, err := pool.claimSpecificWorker(context.Background(), 44, 7, &WorkerAssignment{ + OrgID: "analytics", + MaxWorkers: 3, + }) + if err != nil { + t.Fatalf("claimSpecificWorker: %v", err) + } + if claimed.ID != 44 { + t.Fatalf("expected claimed worker 44, got %d", claimed.ID) + } + if store.takeOverWorkerID != 44 { + t.Fatalf("expected takeover worker id 44, got %d", store.takeOverWorkerID) + } + if store.takeOverOwnerCPID != pool.cpInstanceID { + t.Fatalf("expected takeover owner cp id %q, got %q", pool.cpInstanceID, store.takeOverOwnerCPID) + } + if store.takeOverOrgID != "analytics" { + t.Fatalf("expected takeover org analytics, got %q", store.takeOverOrgID) + } + if store.takeOverExpectedEpoch != 7 { + t.Fatalf("expected takeover expected epoch 7, got %d", store.takeOverExpectedEpoch) + } + if claimed.OwnerEpoch() != 8 { + t.Fatalf("expected owner epoch 8, got %d", claimed.OwnerEpoch()) + } + state := claimed.SharedState() + if state.Lifecycle != WorkerLifecycleReserved { + t.Fatalf("expected reserved lifecycle, got %q", state.Lifecycle) + } + if state.Assignment == nil || state.Assignment.OrgID != "analytics" { + t.Fatalf("expected analytics assignment, got %#v", state.Assignment) + } + if !livenessChecked { + t.Fatal("expected claimSpecificWorker to recheck worker liveness") + } +} + +func TestK8sPoolClaimSpecificWorkerReturnsEpochMismatchError(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{ + takeOverErr: configstore.ErrWorkerOwnerEpochMismatch, + } + pool.runtimeStore = store + + claimed, err := pool.claimSpecificWorker(context.Background(), 44, 7, &WorkerAssignment{ + OrgID: "analytics", + MaxWorkers: 3, + }) + if err == nil { + t.Fatal("expected stale takeover to return an error") + } + if !errors.Is(err, configstore.ErrWorkerOwnerEpochMismatch) { + t.Fatalf("expected ErrWorkerOwnerEpochMismatch, got %v", err) + } + if claimed != nil { + t.Fatalf("expected no claimed worker, got %#v", claimed) + } +} + +func TestK8sPoolClaimSpecificWorkerRetiresUnhealthyWorker(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{ + takenOver: &configstore.WorkerRecord{ + WorkerID: 44, + PodName: "duckgres-worker-test-cp-44", + State: configstore.WorkerStateReserved, + OrgID: "analytics", + OwnerCPInstanceID: pool.cpInstanceID, + OwnerEpoch: 8, + }, + } + pool.runtimeStore = store + pool.workers[44] = &ManagedWorker{ID: 44, done: make(chan struct{})} + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { + return errors.New("dead worker") + } + + claimed, err := pool.claimSpecificWorker(context.Background(), 44, 7, &WorkerAssignment{ + OrgID: "analytics", + MaxWorkers: 3, + }) + if err == nil { + t.Fatal("expected unhealthy claimed worker to fail liveness recheck") + } + if claimed != nil { + t.Fatalf("expected no claimed worker, got %#v", claimed) + } + if _, ok := pool.Worker(44); ok { + t.Fatal("expected unhealthy worker to be retired from the pool") + } +} + +func TestK8sPoolReserveSharedWorkerCreatesRuntimeSpawningSlotWhenPoolIsCold(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{ + spawned: &configstore.WorkerRecord{ + WorkerID: 31, + PodName: "duckgres-worker-test-cp-31", + State: configstore.WorkerStateSpawning, + OrgID: "analytics", + OwnerCPInstanceID: pool.cpInstanceID, + OwnerEpoch: 1, + }, + } + pool.runtimeStore = store + pool.spawnWarmWorkerFunc = func(ctx context.Context, id int) error { + worker := &ManagedWorker{ID: id, podName: 
"duckgres-worker-test-cp-31", done: make(chan struct{})} + pool.workers[id] = worker + return nil + } + pool.healthCheckFunc = func(ctx context.Context, worker *ManagedWorker) error { + if worker == nil || worker.ID != 31 { + t.Fatalf("expected spawned runtime worker 31, got %#v", worker) + } + return nil + } + + worker, err := pool.ReserveSharedWorker(context.Background(), &WorkerAssignment{ + OrgID: "analytics", + MaxWorkers: 2, + }) + if err != nil { + t.Fatalf("ReserveSharedWorker: %v", err) + } + if worker.ID != 31 { + t.Fatalf("expected spawned worker 31, got %d", worker.ID) + } + if store.spawnCalls != 1 { + t.Fatalf("expected one spawning slot allocation, got %d", store.spawnCalls) + } + if store.spawnOwnerCPID != pool.cpInstanceID { + t.Fatalf("expected spawn owner cp-instance %q, got %q", pool.cpInstanceID, store.spawnOwnerCPID) + } + if store.spawnOrgID != "analytics" { + t.Fatalf("expected spawn org analytics, got %q", store.spawnOrgID) + } + if store.spawnOwnerEpoch != 1 { + t.Fatalf("expected spawn owner epoch 1, got %d", store.spawnOwnerEpoch) + } + if store.spawnPodNamePrefix != "duckgres-worker-test-cp" { + t.Fatalf("expected pod name prefix duckgres-worker-test-cp, got %q", store.spawnPodNamePrefix) + } + if store.spawnMaxOrgWorkers != 2 { + t.Fatalf("expected max org workers 2, got %d", store.spawnMaxOrgWorkers) + } + if store.spawnMaxGlobalWorks != 5 { + t.Fatalf("expected max global workers 5, got %d", store.spawnMaxGlobalWorks) + } + if worker.OwnerEpoch() != 1 { + t.Fatalf("expected owner epoch 1, got %d", worker.OwnerEpoch()) + } + if worker.SharedState().Lifecycle != WorkerLifecycleReserved { + t.Fatalf("expected reserved lifecycle, got %q", worker.SharedState().Lifecycle) + } +} + +func TestK8sPoolSpawnWarmWorkerAllocatesRuntimeSlotWhenIDZero(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{ + neutralSpawned: &configstore.WorkerRecord{ + WorkerID: 41, + PodName: "duckgres-worker-test-cp-41", + State: configstore.WorkerStateSpawning, + OwnerCPInstanceID: pool.cpInstanceID, + }, + } + pool.runtimeStore = store + + var spawnedID int + pool.spawnWarmWorkerFunc = func(ctx context.Context, id int) error { + spawnedID = id + return nil + } + + if err := pool.spawnWarmWorker(context.Background(), 0); err != nil { + t.Fatalf("spawnWarmWorker: %v", err) + } + if spawnedID != 41 { + t.Fatalf("expected runtime-allocated worker id 41, got %d", spawnedID) + } + if store.neutralSpawnCalls != 1 { + t.Fatalf("expected one runtime neutral spawn slot allocation, got %d", store.neutralSpawnCalls) + } + if store.neutralSpawnPodPrefix != "duckgres-worker-test-cp" { + t.Fatalf("expected pod name prefix duckgres-worker-test-cp, got %q", store.neutralSpawnPodPrefix) + } +} + +func TestK8sPoolSpawnMinWorkersUsesRuntimeSlots(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store + + var mu sync.Mutex + var slots int + store.neutralSpawned = &configstore.WorkerRecord{ + WorkerID: 51, + PodName: "duckgres-worker-test-cp-51", + State: configstore.WorkerStateSpawning, + OwnerCPInstanceID: pool.cpInstanceID, + } + pool.spawnWarmWorkerFunc = func(ctx context.Context, id int) error { + mu.Lock() + defer mu.Unlock() + slots++ + if slots == 1 { + if id != 51 { + t.Fatalf("expected first runtime worker id 51, got %d", id) + } + store.neutralSpawned = &configstore.WorkerRecord{ + WorkerID: 52, + PodName: "duckgres-worker-test-cp-52", + State: configstore.WorkerStateSpawning, + OwnerCPInstanceID: 
pool.cpInstanceID, + } + return nil + } + if id != 52 { + t.Fatalf("expected second runtime worker id 52, got %d", id) + } + return nil + } + + if err := pool.SpawnMinWorkers(2); err != nil { + t.Fatalf("SpawnMinWorkers: %v", err) + } + if store.neutralSpawnCalls != 2 { + t.Fatalf("expected two runtime neutral spawn slot allocations, got %d", store.neutralSpawnCalls) + } + if store.neutralSpawnTarget != 2 { + t.Fatalf("expected neutral warm target 2, got %d", store.neutralSpawnTarget) + } +} + +func TestK8sPoolActivateReservedWorkerPersistsActivatingThenHotWorkerRecord(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store + worker := &ManagedWorker{ID: 9, done: make(chan struct{}), ownerEpoch: 4} + worker.SetOwnerCPInstanceID(pool.cpInstanceID) + if err := worker.SetSharedState(SharedWorkerState{ + Lifecycle: WorkerLifecycleReserved, + Assignment: &WorkerAssignment{ + OrgID: "analytics", + }, + }); err != nil { + t.Fatalf("SetSharedState: %v", err) + } + pool.workers[worker.ID] = worker + pool.activateTenantFunc = func(ctx context.Context, got *ManagedWorker, payload TenantActivationPayload) error { + return nil + } + + if err := pool.ActivateReservedWorker(context.Background(), worker, TenantActivationPayload{ + OrgID: "analytics", + }); err != nil { + t.Fatalf("ActivateReservedWorker: %v", err) + } + + records := store.snapshot() + if len(records) != 2 { + t.Fatalf("expected 2 persisted records, got %d", len(records)) + } + if records[0].State != configstore.WorkerStateActivating { + t.Fatalf("expected activating record first, got %q", records[0].State) + } + if records[1].State != configstore.WorkerStateHot { + t.Fatalf("expected hot record second, got %q", records[1].State) + } + for i, record := range records { + if record.OwnerEpoch != 4 { + t.Fatalf("record %d expected owner epoch 4, got %d", i, record.OwnerEpoch) + } + if record.OwnerCPInstanceID != pool.cpInstanceID { + t.Fatalf("record %d expected owner_cp_instance_id %q, got %q", i, pool.cpInstanceID, record.OwnerCPInstanceID) + } + if record.OrgID != "analytics" { + t.Fatalf("record %d expected org_id analytics, got %q", i, record.OrgID) + } + } +} + +func TestK8sPoolRetireWorkerPersistsRetiredWorkerRecord(t *testing.T) { + pool, _ := newTestK8sPool(t, 5) + store := &captureRuntimeWorkerStore{} + pool.runtimeStore = store + worker := &ManagedWorker{ID: 5, done: make(chan struct{}), ownerEpoch: 2} + worker.SetOwnerCPInstanceID(pool.cpInstanceID) + if err := worker.SetSharedState(SharedWorkerState{ + Lifecycle: WorkerLifecycleHot, + Assignment: &WorkerAssignment{ + OrgID: "analytics", + }, + }); err != nil { + t.Fatalf("SetSharedState: %v", err) + } + pool.workers[worker.ID] = worker + + pool.RetireWorker(worker.ID) + + records := store.snapshot() + if len(records) == 0 { + t.Fatal("expected retirement to persist a worker record") + } + last := records[len(records)-1] + if last.State != configstore.WorkerStateRetired { + t.Fatalf("expected retired worker record, got %q", last.State) + } + if last.OwnerEpoch != 2 { + t.Fatalf("expected owner epoch 2, got %d", last.OwnerEpoch) + } + if last.OwnerCPInstanceID != pool.cpInstanceID { + t.Fatalf("expected owner_cp_instance_id %q, got %q", pool.cpInstanceID, last.OwnerCPInstanceID) + } + if last.OrgID != "analytics" { + t.Fatalf("expected org_id analytics, got %q", last.OrgID) + } + if last.RetireReason != RetireReasonNormal { + t.Fatalf("expected retire reason %q, got %q", RetireReasonNormal, last.RetireReason) } } @@ 
-573,8 +1272,7 @@ func TestK8sPoolReserveSharedWorkerSpawnsWhenPoolIsCold(t *testing.T) { defer cancel() worker, err := pool.ReserveSharedWorker(ctx, &WorkerAssignment{ - OrgID: "billing", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "billing", }) if err != nil { t.Fatalf("ReserveSharedWorker: %v", err) @@ -599,8 +1297,7 @@ func TestK8sPoolIdleReaperSkipsReservedSharedWorker(t *testing.T) { if err := reserved.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState(reserved): %v", err) @@ -694,15 +1391,18 @@ func assertSpawnedWorkerPod(t *testing.T, pod *corev1.Pod) { if pod.Labels["duckgres/control-plane"] != "test-cp" { t.Fatalf("expected control-plane label test-cp, got %s", pod.Labels["duckgres/control-plane"]) } + if pod.Labels["duckgres/cp-instance-id"] != "cp-uid-123-boot-abc" { + t.Fatalf("expected cp-instance-id label cp-uid-123-boot-abc, got %s", pod.Labels["duckgres/cp-instance-id"]) + } + if pod.Labels["duckgres/owner-epoch"] != "0" { + t.Fatalf("expected owner-epoch label 0, got %s", pod.Labels["duckgres/owner-epoch"]) + } if _, ok := pod.Labels["duckgres/org"]; ok { t.Fatalf("expected shared warm worker startup to stay org-neutral, got labels %#v", pod.Labels) } - if len(pod.OwnerReferences) != 1 { - t.Fatalf("expected 1 owner reference, got %d", len(pod.OwnerReferences)) - } - if pod.OwnerReferences[0].Name != "test-cp" { - t.Fatalf("expected owner ref to test-cp, got %s", pod.OwnerReferences[0].Name) + if len(pod.OwnerReferences) != 0 { + t.Fatalf("expected no owner references, got %d", len(pod.OwnerReferences)) } if pod.Spec.SecurityContext == nil || pod.Spec.SecurityContext.RunAsNonRoot == nil || !*pod.Spec.SecurityContext.RunAsNonRoot { @@ -741,13 +1441,32 @@ func assertSpawnedWorkerPod(t *testing.T, pod *corev1.Pod) { } } +func TestControlPlaneIDLabelValue_StaysKubernetesSafe(t *testing.T) { + t.Parallel() + + label := controlPlaneIDLabelValue("duckgres-control-plane-7fb9dd69c6-dcgzw:14cd8dd9eb353e609c7a4387a594a418") + if len(label) > 63 { + t.Fatalf("expected label length <= 63, got %d (%q)", len(label), label) + } + matched, err := regexp.MatchString(`^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$`, label) + if err != nil { + t.Fatalf("failed to compile label regex: %v", err) + } + if !matched { + t.Fatalf("expected Kubernetes-safe label, got %q", label) + } + if label == "duckgres-control-plane-7fb9dd69c6-dcgzw:14cd8dd9eb353e609c7a4387a594a418" { + t.Fatalf("expected sanitized label, got original %q", label) + } +} + func TestK8sPool_ShutdownAll(t *testing.T) { pool, cs := newTestK8sPool(t, 5) // Add some workers for i := 0; i < 3; i++ { done := make(chan struct{}) - pool.workers[i] = &ManagedWorker{ID: i, done: done} + pool.workers[i] = &ManagedWorker{ID: i, podName: "duckgres-worker-test-cp-" + strconv.Itoa(i), done: done} // Create corresponding pods _, _ = cs.CoreV1().Pods("default").Create(context.Background(), &corev1.Pod{ @@ -772,6 +1491,35 @@ func TestK8sPool_ShutdownAll(t *testing.T) { } } +func TestK8sPoolRetireWorkerUsesTrackedPodName(t *testing.T) { + pool, cs := newTestK8sPool(t, 5) + + done := make(chan struct{}) + worker := &ManagedWorker{ + ID: 11, + podName: "duckgres-worker-other-cp-11", + done: done, + } + pool.workers[worker.ID] = worker + + var deletedPodName string + cs.PrependReactor("delete", "pods", func(action k8stesting.Action) (bool, runtime.Object, error) 
{ + deleteAction, ok := action.(k8stesting.DeleteAction) + if !ok { + return false, nil, nil + } + deletedPodName = deleteAction.GetName() + return false, nil, nil + }) + + pool.RetireWorker(worker.ID) + time.Sleep(100 * time.Millisecond) + + if deletedPodName != "duckgres-worker-other-cp-11" { + t.Fatalf("expected retire to delete tracked pod name duckgres-worker-other-cp-11, got %q", deletedPodName) + } +} + func TestK8sPool_OnPodTerminated(t *testing.T) { pool, _ := newTestK8sPool(t, 5) diff --git a/controlplane/leader_loop.go b/controlplane/leader_loop.go new file mode 100644 index 00000000..456ae737 --- /dev/null +++ b/controlplane/leader_loop.go @@ -0,0 +1,46 @@ +package controlplane + +import ( + "context" + "sync" +) + +type leaderOnlyLoop struct { + run func(context.Context) + + mu sync.Mutex + cancel context.CancelFunc +} + +func newLeaderOnlyLoop(run func(context.Context)) *leaderOnlyLoop { + return &leaderOnlyLoop{run: run} +} + +func (l *leaderOnlyLoop) onStartedLeading(ctx context.Context) { + if l == nil || l.run == nil { + return + } + + l.mu.Lock() + if l.cancel != nil { + l.cancel() + } + leaderCtx, cancel := context.WithCancel(ctx) + l.cancel = cancel + l.mu.Unlock() + + go l.run(leaderCtx) +} + +func (l *leaderOnlyLoop) onStoppedLeading() { + if l == nil { + return + } + l.mu.Lock() + cancel := l.cancel + l.cancel = nil + l.mu.Unlock() + if cancel != nil { + cancel() + } +} diff --git a/controlplane/leader_loop_test.go b/controlplane/leader_loop_test.go new file mode 100644 index 00000000..dfe864c6 --- /dev/null +++ b/controlplane/leader_loop_test.go @@ -0,0 +1,34 @@ +package controlplane + +import ( + "context" + "testing" + "time" +) + +func TestLeaderOnlyLoopStartsAndStopsBackgroundRun(t *testing.T) { + started := make(chan struct{}, 1) + stopped := make(chan struct{}, 1) + loop := newLeaderOnlyLoop(func(ctx context.Context) { + started <- struct{}{} + <-ctx.Done() + stopped <- struct{}{} + }) + + rootCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + loop.onStartedLeading(rootCtx) + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("leader loop did not start") + } + + loop.onStoppedLeading() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("leader loop did not stop after leadership loss") + } +} diff --git a/controlplane/migration_lock_test.go b/controlplane/migration_lock_test.go index 3222ecce..9730d5b1 100644 --- a/controlplane/migration_lock_test.go +++ b/controlplane/migration_lock_test.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "testing" - "time" "github.com/posthog/duckgres/controlplane/configstore" corev1 "k8s.io/api/core/v1" @@ -91,7 +90,7 @@ func newTestWorker(t *testing.T, id int, orgID string) *ManagedWorker { worker := &ManagedWorker{ID: id} if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, - Assignment: &WorkerAssignment{OrgID: orgID, LeaseExpiresAt: time.Now().Add(time.Hour)}, + Assignment: &WorkerAssignment{OrgID: orgID}, }); err != nil { t.Fatalf("SetSharedState: %v", err) } diff --git a/controlplane/multitenant.go b/controlplane/multitenant.go index a82bb139..ec07394a 100644 --- a/controlplane/multitenant.go +++ b/controlplane/multitenant.go @@ -10,6 +10,7 @@ import ( "fmt" "log/slog" "net/http" + "os" "time" "github.com/gin-gonic/gin" @@ -34,6 +35,14 @@ func (a *orgRouterAdapter) StackForUser(username string) (WorkerPool, *SessionMa return stack.Pool, stack.Sessions, stack.Rebalancer, true } +func (a *orgRouterAdapter) StackForOrg(orgID 
string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) { + stack, ok := a.router.StackForOrg(orgID) + if !ok { + return nil, nil, nil, false + } + return stack.Pool, stack.Sessions, stack.Rebalancer, true +} + func (a *orgRouterAdapter) IsMigratingForUser(username string) (bool, string) { return a.router.IsMigratingForUser(username) } @@ -120,7 +129,7 @@ func SetupMultiTenant( srv *server.Server, memBudget uint64, maxWorkers int, -) (ConfigStoreInterface, OrgRouterInterface, *http.Server, error) { +) (ConfigStoreInterface, OrgRouterInterface, *http.Server, *ControlPlaneRuntimeTracker, *JanitorLeaderManager, error) { pollInterval := cfg.ConfigPollInterval if pollInterval <= 0 { pollInterval = 30 * time.Second @@ -128,27 +137,51 @@ func SetupMultiTenant( store, err := configstore.NewConfigStore(cfg.ConfigStoreConn, pollInterval) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, nil, err + } + + namespace, err := resolveK8sNamespace(cfg.K8s.WorkerNamespace) + if err != nil { + return nil, nil, nil, nil, nil, err } + cpID := cfg.K8s.ControlPlaneID + if cpID == "" { + cpID = os.Getenv("POD_NAME") + } + if cpID == "" { + cpID, _ = os.Hostname() + } + podUID := os.Getenv("POD_UID") + if podUID == "" { + podUID = cpID + } + bootID := make([]byte, 16) + if _, err := rand.Read(bootID); err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("generate control plane boot id: %w", err) + } + bootIDHex := hex.EncodeToString(bootID) + cpInstanceID := makeControlPlaneInstanceID(podUID, bootIDHex) + baseCfg := K8sWorkerPoolConfig{ - Namespace: cfg.K8s.WorkerNamespace, - CPID: cfg.K8s.ControlPlaneID, - WorkerImage: cfg.K8s.WorkerImage, - WorkerPort: cfg.K8s.WorkerPort, - SecretName: cfg.K8s.WorkerSecret, - ConfigMap: cfg.K8s.WorkerConfigMap, - MaxWorkers: maxWorkers, - IdleTimeout: cfg.WorkerIdleTimeout, - ConfigPath: cfg.ConfigPath, - ImagePullPolicy: cfg.K8s.ImagePullPolicy, - ServiceAccount: cfg.K8s.ServiceAccount, - WorkerCPURequest: cfg.K8s.WorkerCPURequest, - WorkerMemoryRequest: cfg.K8s.WorkerMemoryRequest, - WorkerNodeSelector: parseNodeSelector(cfg.K8s.WorkerNodeSelector), + Namespace: namespace, + CPID: cpID, + CPInstanceID: cpInstanceID, + WorkerImage: cfg.K8s.WorkerImage, + WorkerPort: cfg.K8s.WorkerPort, + SecretName: cfg.K8s.WorkerSecret, + ConfigMap: cfg.K8s.WorkerConfigMap, + MaxWorkers: maxWorkers, + IdleTimeout: cfg.WorkerIdleTimeout, + ConfigPath: cfg.ConfigPath, + ImagePullPolicy: cfg.K8s.ImagePullPolicy, + ServiceAccount: cfg.K8s.ServiceAccount, + WorkerCPURequest: cfg.K8s.WorkerCPURequest, + WorkerMemoryRequest: cfg.K8s.WorkerMemoryRequest, + WorkerNodeSelector: parseNodeSelector(cfg.K8s.WorkerNodeSelector), WorkerTolerationKey: cfg.K8s.WorkerTolerationKey, WorkerTolerationValue: cfg.K8s.WorkerTolerationValue, - WorkerExclusiveNode: cfg.K8s.WorkerExclusiveNode, + WorkerExclusiveNode: cfg.K8s.WorkerExclusiveNode, } // Initialize STS broker for credential brokering (best-effort) @@ -174,10 +207,37 @@ func SetupMultiTenant( router, err := NewOrgRouter(store, baseCfg, cfg, srv, stsBroker, resolveDucklingStatus) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, nil, err } adpt := &orgRouterAdapter{router: router} + runtimeTracker := NewControlPlaneRuntimeTracker( + store, + cpInstanceID, + cpID, + podUID, + bootIDHex, + 5*time.Second, + ) + janitor := NewControlPlaneJanitor(store, 5*time.Second, 20*time.Second) + janitor.maxDrainTimeout = cfg.HandoverDrainTimeout + janitor.retireWorker = func(record configstore.WorkerRecord, reason 
string) { + router.sharedPool.retireClaimedWorker(&record, reason) + } + janitor.reconcileWarmCapacity = func() { + target := router.sharedPool.WarmCapacityTarget() + if target <= 0 { + return + } + observeOrgWorkerSpawn("shared") + if err := router.sharedPool.SpawnMinWorkers(target); err != nil { + slog.Warn("Janitor failed to reconcile shared warm capacity.", "target", target, "error", err) + } + } + janitorLeader, err := NewJanitorLeaderManager(namespace, cpInstanceID, janitor) + if err != nil { + return nil, nil, nil, nil, nil, err + } // Start provisioning controller (best-effort — K8s API may not be available locally) provCtrl, err := provisioner.NewController(store, 10*time.Second) @@ -198,7 +258,7 @@ func SetupMultiTenant( if internalSecret == "" { tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { - return nil, nil, nil, fmt.Errorf("generate internal secret: %w", err) + return nil, nil, nil, nil, nil, fmt.Errorf("generate internal secret: %w", err) } internalSecret = hex.EncodeToString(tokenBytes) slog.Info("Generated internal secret (pass via --internal-secret or DUCKGRES_INTERNAL_SECRET to set explicitly).", "secret", internalSecret) @@ -211,9 +271,7 @@ func SetupMultiTenant( engine.Use(gin.Recovery()) // Health endpoint (unauthenticated, used by K8s probes) - engine.GET("/health", func(c *gin.Context) { - c.String(http.StatusOK, "ok") - }) + engine.GET("/health", newHealthHandler(runtimeTracker.Draining)) // Authenticated API api := engine.Group("/api/v1", admin.APIAuthMiddleware(internalSecret)) @@ -234,7 +292,28 @@ func SetupMultiTenant( } }() - return store, adpt, apiServer, nil + return store, adpt, apiServer, runtimeTracker, janitorLeader, nil +} + +func resolveK8sNamespace(namespace string) (string, error) { + if namespace != "" { + return namespace, nil + } + ns, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace") + if err != nil { + return "", fmt.Errorf("k8s namespace not set and auto-detection failed: %w", err) + } + return string(ns), nil +} + +func newHealthHandler(isDraining func() bool) gin.HandlerFunc { + return func(c *gin.Context) { + if isDraining != nil && isDraining() { + c.String(http.StatusServiceUnavailable, "draining") + return + } + c.String(http.StatusOK, "ok") + } } // parseNodeSelector parses a JSON string into a map[string]string. 
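A minimal wiring sketch for the drain flow that `newHealthHandler` enables (illustrative only, not part of this patch: the signal hookup, the `drainsketch` package name, and `awaitDrain` are assumptions; only `MarkDraining` comes from the runtime tracker in this diff). On SIGTERM the replica persists its draining state, `/health` flips to 503 so readiness probes fail, and existing sessions get the configured drain window before a forced exit:

```go
package drainsketch

import (
	"context"
	"log/slog"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// drainTracker is the subset of ControlPlaneRuntimeTracker this sketch needs.
type drainTracker interface {
	MarkDraining() error
}

// awaitDrain blocks until SIGTERM, marks the replica draining so /health
// starts returning 503 and readiness probes fail, then waits out the
// handover drain timeout before the caller force-shuts remaining sessions.
func awaitDrain(ctx context.Context, tracker drainTracker, drainTimeout time.Duration) {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	select {
	case <-ctx.Done():
		return
	case <-sigCh:
	}

	if err := tracker.MarkDraining(); err != nil {
		slog.Warn("Failed to persist draining state.", "error", err)
	}

	// Existing pgwire and Flight sessions keep running while new ones are
	// rejected; give them the configured drain window before forcing exit.
	select {
	case <-ctx.Done():
	case <-time.After(drainTimeout):
	}
}
```

This mirrors the planned-shutdown behavior described in the rollout runbook later in this diff.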
@@ -250,3 +329,13 @@ func parseNodeSelector(s string) map[string]string { } return m } + +func makeControlPlaneInstanceID(podUID, bootIDHex string) string { + if podUID == "" { + podUID = "cp" + } + if len(bootIDHex) > 16 { + bootIDHex = bootIDHex[:16] + } + return controlPlaneIDLabelValue(podUID + "-" + bootIDHex) +} diff --git a/controlplane/multitenant_cp_instance_id_test.go b/controlplane/multitenant_cp_instance_id_test.go new file mode 100644 index 00000000..38e55787 --- /dev/null +++ b/controlplane/multitenant_cp_instance_id_test.go @@ -0,0 +1,24 @@ +//go:build kubernetes + +package controlplane + +import "testing" + +func TestMakeControlPlaneInstanceID_StaysKubernetesSafe(t *testing.T) { + t.Parallel() + + id := makeControlPlaneInstanceID( + "duckgres-control-plane-7fb9dd69c6-dcgzw", + "14cd8dd9eb353e609c7a4387a594a418", + ) + + if len(id) > 63 { + t.Fatalf("expected cp_instance_id length <= 63, got %d (%q)", len(id), id) + } + if id == "duckgres-control-plane-7fb9dd69c6-dcgzw:14cd8dd9eb353e609c7a4387a594a418" { + t.Fatalf("expected shortened cp_instance_id, got original %q", id) + } + if id != "duckgres-control-plane-7fb9dd69c6-dcgzw-14cd8dd9eb353e60" { + t.Fatalf("unexpected cp_instance_id %q", id) + } +} diff --git a/controlplane/multitenant_stub.go b/controlplane/multitenant_stub.go index 9efb937f..b8fc5662 100644 --- a/controlplane/multitenant_stub.go +++ b/controlplane/multitenant_stub.go @@ -15,6 +15,6 @@ func SetupMultiTenant( srv *server.Server, memBudget uint64, maxWorkers int, -) (ConfigStoreInterface, OrgRouterInterface, *http.Server, error) { - return nil, nil, nil, fmt.Errorf("multi-tenant mode requires -tags kubernetes build") +) (ConfigStoreInterface, OrgRouterInterface, *http.Server, *ControlPlaneRuntimeTracker, *JanitorLeaderManager, error) { + return nil, nil, nil, nil, nil, fmt.Errorf("multi-tenant mode requires -tags kubernetes build") } diff --git a/controlplane/multitenant_test.go b/controlplane/multitenant_test.go new file mode 100644 index 00000000..ede3deb7 --- /dev/null +++ b/controlplane/multitenant_test.go @@ -0,0 +1,39 @@ +//go:build kubernetes + +package controlplane + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestHealthHandlerReturnsServiceUnavailableWhenDraining(t *testing.T) { + gin.SetMode(gin.ReleaseMode) + engine := gin.New() + engine.GET("/health", newHealthHandler(func() bool { return true })) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rec := httptest.NewRecorder() + engine.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("expected 503 while draining, got %d", rec.Code) + } +} + +func TestHealthHandlerReturnsOKWhenNotDraining(t *testing.T) { + gin.SetMode(gin.ReleaseMode) + engine := gin.New() + engine.GET("/health", newHealthHandler(func() bool { return false })) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rec := httptest.NewRecorder() + engine.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 while healthy, got %d", rec.Code) + } +} diff --git a/controlplane/org_activation_test.go b/controlplane/org_activation_test.go index 4c5dcfe5..ba6f849b 100644 --- a/controlplane/org_activation_test.go +++ b/controlplane/org_activation_test.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "testing" - "time" "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/controlplane/provisioner" @@ -73,10 +72,7 @@ func 
TestSharedWorkerActivatorBuildsActivationRequestFromManagedWarehouse(t *tes }, } - assignment := &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Date(2026, time.March, 22, 12, 0, 0, 0, time.UTC), - } + assignment := &WorkerAssignment{OrgID: "analytics"} req, err := activator.BuildActivationRequest(context.Background(), org, assignment) if err != nil { @@ -86,9 +82,6 @@ func TestSharedWorkerActivatorBuildsActivationRequestFromManagedWarehouse(t *tes if req.OrgID != "analytics" { t.Fatalf("expected org analytics, got %q", req.OrgID) } - if !req.LeaseExpiresAt.Equal(assignment.LeaseExpiresAt) { - t.Fatalf("expected lease expiry %v, got %v", assignment.LeaseExpiresAt, req.LeaseExpiresAt) - } if got := req.DuckLake.MetadataStore; got != "postgres:host=metadata.example.internal port=5432 user=ducklake_user password=metadata-password dbname=ducklake_metadata" { t.Fatalf("unexpected metadata store dsn: %q", got) } @@ -110,8 +103,7 @@ func TestSharedWorkerActivatorRequiresManagedWarehouse(t *testing.T) { } _, err := activator.BuildActivationRequest(context.Background(), &configstore.OrgConfig{Name: "analytics"}, &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }) if err == nil { t.Fatal("expected missing warehouse to fail") @@ -141,12 +133,10 @@ func TestSharedWorkerActivatorActivateReservedWorkerUsesLatestResolvedOrgConfig( ) worker := &ManagedWorker{ID: 7, done: make(chan struct{})} - leaseExpiry := time.Date(2026, time.March, 22, 12, 0, 0, 0, time.UTC) if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: leaseExpiry, + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState: %v", err) @@ -218,9 +208,6 @@ func TestSharedWorkerActivatorActivateReservedWorkerUsesLatestResolvedOrgConfig( if captured.OrgID != "analytics" { t.Fatalf("expected org analytics, got %q", captured.OrgID) } - if !captured.LeaseExpiresAt.Equal(leaseExpiry) { - t.Fatalf("expected lease expiry %v, got %v", leaseExpiry, captured.LeaseExpiresAt) - } if len(captured.Usernames) != 1 || captured.Usernames[0] != "bob" { t.Fatalf("expected latest users to be captured, got %#v", captured.Usernames) } @@ -265,8 +252,8 @@ func TestSharedWorkerActivatorDucklingCRFallback(t *testing.T) { } org := &configstore.OrgConfig{ - Name: "test-org", - Users: map[string]string{"testuser": "hash"}, + Name: "test-org", + Users: map[string]string{"testuser": "hash"}, Warehouse: &configstore.ManagedWarehouseConfig{ // SecretRef intentionally empty — simulates Crossplane-provisioned duckling }, diff --git a/controlplane/org_reserved_pool.go b/controlplane/org_reserved_pool.go index 7378feaa..42324cbd 100644 --- a/controlplane/org_reserved_pool.go +++ b/controlplane/org_reserved_pool.go @@ -9,8 +9,6 @@ import ( "time" ) -const defaultSharedWorkerReservationLease = 24 * time.Hour - // OrgReservedPool presents one org's reserved slice of a shared K8s warm pool. // It preserves the existing WorkerPool contract for SessionManager while ensuring // workers are reserved to a single org for their lifetime and retired after use. 
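A usage sketch of that contract (hypothetical caller, not part of the patch; `shared` and `broker` are assumed to be wired by the org router):

```go
package controlplane

import "context"

// runOrgSession sketches the OrgReservedPool contract: a worker is reserved
// to exactly one org for its lifetime and retired after its last session.
func runOrgSession(ctx context.Context, shared *K8sWorkerPool, broker *STSBroker) error {
	pool := NewOrgReservedPool(shared, "analytics", 3, broker)

	// Reserves a shared worker for the org (claiming a runtime slot or
	// cold-spawning one if the pool is empty), activating tenant runtime
	// as needed.
	worker, err := pool.AcquireWorker(ctx)
	if err != nil {
		return err
	}
	// ReleaseWorker retires the worker once no sessions remain on it.
	defer pool.ReleaseWorker(worker.ID)

	// ... create sessions and run queries against the worker ...
	return nil
}
```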
@@ -18,18 +16,16 @@ type OrgReservedPool struct { shared *K8sWorkerPool orgID string maxWorkers int - leaseDuration time.Duration stsBroker *STSBroker activateReservedWorker func(context.Context, *ManagedWorker) error } func NewOrgReservedPool(shared *K8sWorkerPool, orgID string, maxWorkers int, stsBroker *STSBroker) *OrgReservedPool { pool := &OrgReservedPool{ - shared: shared, - orgID: orgID, - maxWorkers: maxWorkers, - leaseDuration: defaultSharedWorkerReservationLease, - stsBroker: stsBroker, + shared: shared, + orgID: orgID, + maxWorkers: maxWorkers, + stsBroker: stsBroker, } pool.activateReservedWorker = pool.activateReservedWorkerDefault return pool @@ -65,8 +61,7 @@ func (p *OrgReservedPool) AcquireWorker(ctx context.Context) (*ManagedWorker, er p.shared.mu.Unlock() worker, err := p.shared.ReserveSharedWorker(ctx, &WorkerAssignment{ - OrgID: p.orgID, - LeaseExpiresAt: time.Now().Add(p.leaseDuration), + OrgID: p.orgID, }) if err != nil { return nil, err @@ -102,7 +97,7 @@ func (p *OrgReservedPool) AcquireWorker(ctx context.Context) (*ManagedWorker, er } func (p *OrgReservedPool) ReleaseWorker(id int) { - _ = p.RetireWorkerIfNoSessions(id) + p.shared.RetireWorkerIfNoSessions(id) } func (p *OrgReservedPool) RetireWorker(id int) { @@ -263,6 +258,21 @@ func (p *OrgReservedPool) activateWorkerForOrg(ctx context.Context, worker *Mana } } +func (p *OrgReservedPool) ReconnectFlightWorker(ctx context.Context, workerID int, ownerEpoch int64) (*ManagedWorker, error) { + worker, err := p.shared.claimSpecificWorker(ctx, workerID, ownerEpoch, &WorkerAssignment{ + OrgID: p.orgID, + MaxWorkers: p.maxWorkers, + }) + if err != nil { + return nil, err + } + if err := p.activateWorkerForOrg(ctx, worker); err != nil { + p.shared.retireWorkerWithReason(worker.ID, RetireReasonActivationFailure) + return nil, err + } + return worker, nil +} + func (p *OrgReservedPool) activateReservedWorkerDefault(_ context.Context, _ *ManagedWorker) error { return fmt.Errorf("reserved worker activator is not configured for org %s", p.orgID) } diff --git a/controlplane/org_reserved_pool_test.go b/controlplane/org_reserved_pool_test.go index 0ec5f524..4f99c62c 100644 --- a/controlplane/org_reserved_pool_test.go +++ b/controlplane/org_reserved_pool_test.go @@ -53,8 +53,7 @@ func TestOrgReservedPoolAcquireSkipsOtherOrgsWorkers(t *testing.T) { if err := other.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "billing", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "billing", }, }); err != nil { t.Fatalf("SetSharedState(other): %v", err) @@ -87,14 +86,15 @@ func TestOrgReservedPoolAcquireSkipsOtherOrgsWorkers(t *testing.T) { } } -func TestOrgReservedPoolReleaseWorkerRetiresOnLastSession(t *testing.T) { +func TestOrgReservedPoolReleaseWorkerRetiresWorkerOnLastSession(t *testing.T) { shared, _ := newTestK8sPool(t, 5) worker := &ManagedWorker{ID: 9, activeSessions: 1, done: make(chan struct{})} + worker.SetOwnerCPInstanceID(shared.cpInstanceID) + worker.SetOwnerEpoch(3) if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleHot, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState(worker): %v", err) @@ -104,7 +104,6 @@ func TestOrgReservedPoolReleaseWorkerRetiresOnLastSession(t *testing.T) { pool := NewOrgReservedPool(shared, "analytics", 1, nil) pool.ReleaseWorker(worker.ID) - time.Sleep(100 * time.Millisecond) if _, ok := 
shared.Worker(worker.ID); ok { t.Fatal("expected worker to be retired after last session release") } @@ -187,8 +186,7 @@ func TestOrgReservedPoolAcquireWaitsWhenSharedWarmWorkerBusyAtCapacity(t *testin if err := worker.SetSharedState(SharedWorkerState{ Lifecycle: WorkerLifecycleHot, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "analytics", }, }); err != nil { t.Fatalf("SetSharedState(worker): %v", err) diff --git a/controlplane/org_router.go b/controlplane/org_router.go index 5a9adf69..2f48ce16 100644 --- a/controlplane/org_router.go +++ b/controlplane/org_router.go @@ -61,6 +61,7 @@ func NewOrgRouter(store *configstore.ConfigStore, baseCfg K8sWorkerPoolConfig, g sharedCfg.WorkerIDGenerator = func() int { return int(tr.nextWorkerID.Add(1)) } + sharedCfg.RuntimeStore = store sharedPoolIface, err := CreateK8sPool(sharedCfg) if err != nil { @@ -184,6 +185,11 @@ func (tr *OrgRouter) StackForUser(username string) (*OrgStack, bool) { return nil, false } + return tr.StackForOrg(orgID) +} + +// StackForOrg resolves an org id to its org stack. +func (tr *OrgRouter) StackForOrg(orgID string) (*OrgStack, bool) { tr.mu.RLock() stack, ok := tr.orgs[orgID] tr.mu.RUnlock() @@ -350,12 +356,6 @@ func (tr *OrgRouter) reconcileWarmCapacity(snap *configstore.Snapshot) { } tr.sharedPool.SetWarmCapacityTarget(target) - if target > 0 { - observeOrgWorkerSpawn("shared") - if err := tr.sharedPool.SpawnMinWorkers(target); err != nil { - slog.Warn("Failed to reconcile shared warm capacity.", "target", target, "error", err) - } - } } func (tr *OrgRouter) onSharedWorkerCrash(workerID int) { diff --git a/controlplane/org_router_test.go b/controlplane/org_router_test.go index 6ad03e2b..c7f6fe37 100644 --- a/controlplane/org_router_test.go +++ b/controlplane/org_router_test.go @@ -195,9 +195,6 @@ func TestOrgRouterCreateOrgStackActivatesUsingLatestSnapshotThroughSharedWorkerA if got := captured.DuckLake.MetadataStore; got != "postgres:host=new-metadata.internal port=5432 user=ducklake_user password=new-password dbname=ducklake_metadata" { t.Fatalf("expected latest warehouse runtime from router snapshot, got %q", got) } - if captured.LeaseExpiresAt.Before(time.Now()) { - t.Fatalf("expected lease expiry to be set, got %v", captured.LeaseExpiresAt) - } } func newStringSecret(namespace, name, key, value string) *corev1.Secret { diff --git a/controlplane/org_router_test_helpers_test.go b/controlplane/org_router_test_helpers_test.go index 99816f13..7cc318b3 100644 --- a/controlplane/org_router_test_helpers_test.go +++ b/controlplane/org_router_test_helpers_test.go @@ -11,5 +11,9 @@ func (m *mockOrgRouter) StackForUser(_ string) (WorkerPool, *SessionManager, *Me return nil, m.sessions, m.rebalancer, m.ok } +func (m *mockOrgRouter) StackForOrg(_ string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) { + return nil, m.sessions, m.rebalancer, m.ok +} + func (m *mockOrgRouter) IsMigratingForUser(_ string) (bool, string) { return false, "" } func (m *mockOrgRouter) ShutdownAll() {} diff --git a/controlplane/runtime_tracker.go b/controlplane/runtime_tracker.go new file mode 100644 index 00000000..efeb4442 --- /dev/null +++ b/controlplane/runtime_tracker.go @@ -0,0 +1,129 @@ +package controlplane + +import ( + "context" + "log/slog" + "sync" + "sync/atomic" + "time" + + "github.com/posthog/duckgres/controlplane/configstore" +) + +type runtimeInstanceStore interface { + UpsertControlPlaneInstance(instance *configstore.ControlPlaneInstance) error +} + +type 
ControlPlaneRuntimeTracker struct { + store runtimeInstanceStore + id string + podName string + podUID string + bootID string + heartbeatInterval time.Duration + now func() time.Time + + draining atomic.Bool + + mu sync.Mutex + started bool + startedAt time.Time + drainingAt *time.Time +} + +func NewControlPlaneRuntimeTracker(store runtimeInstanceStore, id, podName, podUID, bootID string, heartbeatInterval time.Duration) *ControlPlaneRuntimeTracker { + if heartbeatInterval <= 0 { + heartbeatInterval = 5 * time.Second + } + return &ControlPlaneRuntimeTracker{ + store: store, + id: id, + podName: podName, + podUID: podUID, + bootID: bootID, + heartbeatInterval: heartbeatInterval, + now: time.Now, + } +} + +func (t *ControlPlaneRuntimeTracker) Start(ctx context.Context) error { + t.mu.Lock() + if t.started { + t.mu.Unlock() + return nil + } + now := t.now() + t.started = true + t.startedAt = now + t.mu.Unlock() + + if err := t.upsertAt(now); err != nil { + return err + } + + go t.heartbeatLoop(ctx) + return nil +} + +func (t *ControlPlaneRuntimeTracker) heartbeatLoop(ctx context.Context) { + ticker := time.NewTicker(t.heartbeatInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := t.upsertAt(t.now()); err != nil { + slog.Warn("Runtime tracker heartbeat failed.", "cp_instance_id", t.id, "error", err) + } + } + } +} + +func (t *ControlPlaneRuntimeTracker) MarkDraining() error { + now := t.now() + t.mu.Lock() + if t.drainingAt == nil { + drainingAt := now + t.drainingAt = &drainingAt + } + if !t.started { + t.started = true + t.startedAt = now + } + t.mu.Unlock() + t.draining.Store(true) + return t.upsertAt(now) +} + +func (t *ControlPlaneRuntimeTracker) Draining() bool { + return t.draining.Load() +} + +func (t *ControlPlaneRuntimeTracker) upsertAt(now time.Time) error { + state := configstore.ControlPlaneInstanceStateActive + + t.mu.Lock() + startedAt := t.startedAt + var drainingAt *time.Time + if t.draining.Load() { + state = configstore.ControlPlaneInstanceStateDraining + if t.drainingAt != nil { + drainingCopy := *t.drainingAt + drainingAt = &drainingCopy + } + } + t.mu.Unlock() + + return t.store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: t.id, + PodName: t.podName, + PodUID: t.podUID, + BootID: t.bootID, + State: state, + StartedAt: startedAt, + LastHeartbeatAt: now, + DrainingAt: drainingAt, + }) +} diff --git a/controlplane/runtime_tracker_test.go b/controlplane/runtime_tracker_test.go new file mode 100644 index 00000000..0ebb704b --- /dev/null +++ b/controlplane/runtime_tracker_test.go @@ -0,0 +1,100 @@ +package controlplane + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/posthog/duckgres/controlplane/configstore" +) + +type fakeRuntimeInstanceStore struct { + mu sync.Mutex + records []configstore.ControlPlaneInstance + upsertCh chan struct{} +} + +func (f *fakeRuntimeInstanceStore) UpsertControlPlaneInstance(instance *configstore.ControlPlaneInstance) error { + f.mu.Lock() + f.records = append(f.records, *instance) + f.mu.Unlock() + if f.upsertCh != nil { + select { + case f.upsertCh <- struct{}{}: + default: + } + } + return nil +} + +func (f *fakeRuntimeInstanceStore) snapshot() []configstore.ControlPlaneInstance { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]configstore.ControlPlaneInstance, len(f.records)) + copy(out, f.records) + return out +} + +func (f *fakeRuntimeInstanceStore) waitForUpserts(t *testing.T, count int, timeout time.Duration) { + t.Helper() + for i := 
0; i < count; i++ { + select { + case <-f.upsertCh: + case <-time.After(timeout): + t.Fatalf("timed out waiting for %d runtime upserts; saw %d", count, len(f.snapshot())) + } + } +} + +func TestControlPlaneRuntimeTrackerStartHeartbeats(t *testing.T) { + store := &fakeRuntimeInstanceStore{upsertCh: make(chan struct{}, 8)} + tracker := NewControlPlaneRuntimeTracker(store, "cp-1:boot-a", "duckgres-0", "pod-uid-1", "boot-a", 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := tracker.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + + store.waitForUpserts(t, 2, time.Second) + + records := store.snapshot() + if len(records) < 2 { + t.Fatalf("expected at least 2 upserts, got %d", len(records)) + } + if records[0].State != configstore.ControlPlaneInstanceStateActive { + t.Fatalf("expected initial state active, got %q", records[0].State) + } + if records[0].ID != "cp-1:boot-a" { + t.Fatalf("expected id cp-1:boot-a, got %q", records[0].ID) + } +} + +func TestControlPlaneRuntimeTrackerMarkDraining(t *testing.T) { + store := &fakeRuntimeInstanceStore{} + tracker := NewControlPlaneRuntimeTracker(store, "cp-1:boot-a", "duckgres-0", "pod-uid-1", "boot-a", time.Hour) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := tracker.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + if err := tracker.MarkDraining(); err != nil { + t.Fatalf("MarkDraining: %v", err) + } + if !tracker.Draining() { + t.Fatal("expected tracker to report draining") + } + + records := store.snapshot() + last := records[len(records)-1] + if last.State != configstore.ControlPlaneInstanceStateDraining { + t.Fatalf("expected draining state, got %q", last.State) + } + if last.DrainingAt == nil { + t.Fatal("expected draining_at to be set") + } +} diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go index d6a3f8df..86bf53e8 100644 --- a/controlplane/session_mgr.go +++ b/controlplane/session_mgr.go @@ -44,6 +44,10 @@ type SessionManager struct { nextPID atomic.Int32 } +type flightReconnectPool interface { + ReconnectFlightWorker(ctx context.Context, workerID int, ownerEpoch int64) (*ManagedWorker, error) +} + // NewSessionManager creates a new session manager. 
func NewSessionManager(pool WorkerPool, rebalancer *MemoryRebalancer) *SessionManager { sm := &SessionManager{ @@ -80,17 +84,46 @@ func (sm *SessionManager) CreateSession(ctx context.Context, username string, pi } slog.Debug("Worker acquired.", "pid", pid, "worker", worker.ID, "user", username, "duration", time.Since(acquireStart)) + return sm.createSessionOnWorker(ctx, username, pid, memoryLimit, threads, worker, "postgres", true) +} + +func (sm *SessionManager) resolveSessionLimits(memoryLimit string, threads int) (string, int) { + if sm.rebalancer == nil { + return memoryLimit, threads + } + if memoryLimit == "" { + memoryLimit = sm.rebalancer.MemoryLimit() + } + if threads <= 0 { + threads = sm.rebalancer.PerSessionThreads() + } + return memoryLimit, threads +} + +func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username string, workerID int, ownerEpoch int64) (int32, *server.FlightExecutor, error) { + reconnector, ok := sm.pool.(flightReconnectPool) + if !ok { + return 0, nil, fmt.Errorf("worker pool does not support flight reconnect") + } + worker, err := reconnector.ReconnectFlightWorker(ctx, workerID, ownerEpoch) + if err != nil { + return 0, nil, fmt.Errorf("reconnect worker %d: %w", workerID, err) + } + return sm.createSessionOnWorker(ctx, username, 0, "", 0, worker, "flight", false) +} + +func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool) (int32, *server.FlightExecutor, error) { createStart := time.Now() sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, threads) if err != nil { - // Clean up the worker we just spawned (but not if it was a pre-warmed idle worker - // that has sessions from other concurrent requests). - sm.pool.RetireWorkerIfNoSessions(worker.ID) + if retireOnFailure { + sm.pool.RetireWorkerIfNoSessions(worker.ID) + } return 0, nil, fmt.Errorf("create session on worker %d: %w", worker.ID, err) } - // Create FlightExecutor sharing the worker's existing gRPC connection executor := server.NewFlightExecutorFromClient(worker.client, sessionToken) + executor.SetControlMetadata(worker.ID, worker.OwnerCPInstanceID(), worker.OwnerEpoch()) if pid == 0 { pid = sm.nextPID.Add(1) @@ -99,7 +132,7 @@ func (sm *SessionManager) CreateSession(ctx context.Context, username string, pi session := &ManagedSession{ PID: pid, WorkerID: worker.ID, - Protocol: "postgres", + Protocol: protocol, SessionToken: sessionToken, Executor: executor, } @@ -110,28 +143,12 @@ func (sm *SessionManager) CreateSession(ctx context.Context, username string, pi sm.mu.Unlock() slog.Debug("Session created on worker.", "pid", pid, "worker", worker.ID, "user", username, "create_duration", time.Since(createStart)) - - // Update other sessions if rebalancing is enabled. if sm.rebalancer != nil { sm.rebalancer.RequestRebalance() } - return pid, executor, nil } -func (sm *SessionManager) resolveSessionLimits(memoryLimit string, threads int) (string, int) { - if sm.rebalancer == nil { - return memoryLimit, threads - } - if memoryLimit == "" { - memoryLimit = sm.rebalancer.MemoryLimit() - } - if threads <= 0 { - threads = sm.rebalancer.PerSessionThreads() - } - return memoryLimit, threads -} - // DestroySession destroys a session, retires its dedicated worker, and rebalances // memory/thread limits across remaining sessions. 
func (sm *SessionManager) DestroySession(pid int32) { diff --git a/controlplane/shared_worker_activator.go b/controlplane/shared_worker_activator.go index 83591a3a..281aba12 100644 --- a/controlplane/shared_worker_activator.go +++ b/controlplane/shared_worker_activator.go @@ -11,7 +11,6 @@ import ( "slices" "strings" "sync" - "time" "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/controlplane/provisioner" @@ -38,10 +37,9 @@ type SharedWorkerActivator struct { } type TenantActivationPayload struct { - OrgID string `json:"org_id"` - Usernames []string `json:"usernames,omitempty"` - LeaseExpiresAt time.Time `json:"lease_expires_at,omitempty"` - DuckLake server.DuckLakeConfig `json:"ducklake"` + OrgID string `json:"org_id"` + Usernames []string `json:"usernames,omitempty"` + DuckLake server.DuckLakeConfig `json:"ducklake"` } func NewSharedWorkerActivator(shared *K8sWorkerPool, stsBroker *STSBroker, resolveOrgConfig func(string) (*configstore.OrgConfig, error)) *SharedWorkerActivator { @@ -111,9 +109,13 @@ func (a *SharedWorkerActivator) ActivateReservedWorker(ctx context.Context, work return a.activateReservedWorker(ctx, worker, payload) } return worker.ActivateTenant(ctx, server.WorkerActivationPayload{ - OrgID: payload.OrgID, - LeaseExpiresAt: payload.LeaseExpiresAt, - DuckLake: payload.DuckLake, + WorkerControlMetadata: server.WorkerControlMetadata{ + WorkerID: worker.ID, + OwnerEpoch: worker.OwnerEpoch(), + CPInstanceID: worker.OwnerCPInstanceID(), + }, + OrgID: payload.OrgID, + DuckLake: payload.DuckLake, }) } @@ -168,10 +170,9 @@ func (a *SharedWorkerActivator) BuildActivationRequest(ctx context.Context, org slices.Sort(usernames) return TenantActivationPayload{ - OrgID: assignment.OrgID, - Usernames: usernames, - LeaseExpiresAt: assignment.LeaseExpiresAt, - DuckLake: dl, + OrgID: assignment.OrgID, + Usernames: usernames, + DuckLake: dl, }, nil } diff --git a/controlplane/warm_pool_metrics.go b/controlplane/warm_pool_metrics.go index bb439fae..9d35415c 100644 --- a/controlplane/warm_pool_metrics.go +++ b/controlplane/warm_pool_metrics.go @@ -66,7 +66,7 @@ var hotWorkerSessionsHistogram = promauto.NewHistogram(prometheus.HistogramOpts{ const ( RetireReasonNormal = "normal" RetireReasonActivationFailure = "activation_failure" - RetireReasonLeaseExpiry = "lease_expiry" + RetireReasonOrphaned = "orphaned" RetireReasonCrash = "crash" RetireReasonShutdown = "shutdown" RetireReasonIdleTimeout = "idle_timeout" diff --git a/controlplane/warm_pool_metrics_test.go b/controlplane/warm_pool_metrics_test.go index 50b546b6..eb08ba67 100644 --- a/controlplane/warm_pool_metrics_test.go +++ b/controlplane/warm_pool_metrics_test.go @@ -20,11 +20,11 @@ func TestObserveWarmPoolLifecycleGauges(t *testing.T) { workers := map[int]*ManagedWorker{ 1: makeTestWorker(WorkerLifecycleIdle, nil), 2: makeTestWorker(WorkerLifecycleIdle, nil), - 3: makeTestWorker(WorkerLifecycleReserved, &WorkerAssignment{OrgID: "org-1", LeaseExpiresAt: time.Now().Add(time.Hour)}), - 4: makeTestWorker(WorkerLifecycleActivating, &WorkerAssignment{OrgID: "org-1", LeaseExpiresAt: time.Now().Add(time.Hour)}), - 5: makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-2", LeaseExpiresAt: time.Now().Add(time.Hour)}), - 6: makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-2", LeaseExpiresAt: time.Now().Add(time.Hour)}), - 7: makeTestWorker(WorkerLifecycleDraining, &WorkerAssignment{OrgID: "org-3", LeaseExpiresAt: time.Now().Add(time.Hour)}), + 3: 
makeTestWorker(WorkerLifecycleReserved, &WorkerAssignment{OrgID: "org-1"}), + 4: makeTestWorker(WorkerLifecycleActivating, &WorkerAssignment{OrgID: "org-1"}), + 5: makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-2"}), + 6: makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-2"}), + 7: makeTestWorker(WorkerLifecycleDraining, &WorkerAssignment{OrgID: "org-3"}), } observeWarmPoolLifecycleGauges(workers) @@ -73,7 +73,7 @@ func TestMarkWorkerRetiredLocked_RecordsHotWorkerSessions(t *testing.T) { resetMetrics() pool, _ := newTestK8sPool(t, 5) - w := makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-1", LeaseExpiresAt: time.Now().Add(time.Hour)}) + w := makeTestWorker(WorkerLifecycleHot, &WorkerAssignment{OrgID: "org-1"}) w.peakSessions = 5 pool.workers[1] = w @@ -98,8 +98,7 @@ func TestReservedAtTracking(t *testing.T) { defer cancel() _, err := pool.ReserveSharedWorker(ctx, &WorkerAssignment{ - OrgID: "org-1", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "org-1", }) if err != nil { t.Fatalf("ReserveSharedWorker failed: %v", err) @@ -156,8 +155,7 @@ func TestActivateWorkerForOrgUpdatesActivatingGauge(t *testing.T) { observeWarmPoolLifecycleGauges(map[int]*ManagedWorker{}) pool, _ := newTestK8sPool(t, 5) worker := makeTestWorker(WorkerLifecycleReserved, &WorkerAssignment{ - OrgID: "org-1", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "org-1", }) worker.reservedAt = time.Now() pool.workers[1] = worker @@ -179,8 +177,7 @@ func TestActivateWorkerForOrgRecordsActivationDurationWhenWorkerAlreadyHot(t *te observeWarmPoolLifecycleGauges(map[int]*ManagedWorker{}) pool, _ := newTestK8sPool(t, 5) worker := makeTestWorker(WorkerLifecycleReserved, &WorkerAssignment{ - OrgID: "org-1", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "org-1", }) worker.reservedAt = time.Now().Add(-2 * time.Second) pool.workers[1] = worker @@ -220,8 +217,7 @@ func TestReapStuckActivatingWorkers(t *testing.T) { pool.workers[1] = idle stuck := makeTestWorker(WorkerLifecycleActivating, &WorkerAssignment{ - OrgID: "org-1", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "org-1", }) stuck.reservedAt = time.Now().Add(-time.Minute) // reserved 1 minute ago pool.workers[2] = stuck @@ -255,8 +251,7 @@ func TestReapStuckActivatingWorkers_RecentlyReservedNotReaped(t *testing.T) { pool.activatingTimeout = 2 * time.Minute w := makeTestWorker(WorkerLifecycleActivating, &WorkerAssignment{ - OrgID: "org-1", - LeaseExpiresAt: time.Now().Add(time.Hour), + OrgID: "org-1", }) w.reservedAt = time.Now() // just reserved pool.workers[1] = w diff --git a/controlplane/worker_mgr.go b/controlplane/worker_mgr.go index 21e01136..14bee01e 100644 --- a/controlplane/worker_mgr.go +++ b/controlplane/worker_mgr.go @@ -26,6 +26,7 @@ import ( // ManagedWorker represents a duckdb-service worker process. 
type ManagedWorker struct { ID int + podName string cmd *exec.Cmd socketPath string bearerToken string @@ -40,6 +41,8 @@ type ManagedWorker struct { sharedState SharedWorkerState reservedAt time.Time //nolint:unused // only set in kubernetes warm-pool reservation path peakSessions int // High-water mark of concurrent sessions (for retirement metrics) + ownerEpoch int64 + ownerCPInstanceID string } // SharedState returns the additive shared warm-worker lifecycle metadata for @@ -58,6 +61,31 @@ func (w *ManagedWorker) SetSharedState(state SharedWorkerState) error { return nil } +func (w *ManagedWorker) OwnerEpoch() int64 { + return w.ownerEpoch +} + +func (w *ManagedWorker) SetOwnerEpoch(epoch int64) { + w.ownerEpoch = epoch +} + +func (w *ManagedWorker) IncrementOwnerEpoch() int64 { + w.ownerEpoch++ + return w.ownerEpoch +} + +func (w *ManagedWorker) OwnerCPInstanceID() string { + return w.ownerCPInstanceID +} + +func (w *ManagedWorker) SetOwnerCPInstanceID(cpInstanceID string) { + w.ownerCPInstanceID = cpInstanceID +} + +func (w *ManagedWorker) PodName() string { + return w.podName +} + // preboundSocket is a Unix socket pre-bound at startup while the socket // directory is verified writable. This avoids EROFS errors that can occur // later under systemd's ProtectSystem=strict when the RuntimeDirectory @@ -492,10 +520,16 @@ func (r *healthCheckResult) toSessionProgress() map[string]*SessionProgress { // Returns the parsed result (including per-session progress) and an error if // the worker is unreachable. func doHealthCheck(ctx context.Context, client *flightsql.Client) (*healthCheckResult, error) { + return doHealthCheckWithMetadata(ctx, client, server.WorkerHealthCheckPayload{}) +} + +func doHealthCheckWithMetadata(ctx context.Context, client *flightsql.Client, payload server.WorkerHealthCheckPayload) (*healthCheckResult, error) { + body, _ := json.Marshal(payload) + // Use the underlying flight client for custom actions. // flightsql.Client.Client is a flight.Client interface which embeds // FlightServiceClient, giving us access to DoAction directly. 
- stream, err := client.Client.DoAction(ctx, &flight.Action{Type: "HealthCheck"}) + stream, err := client.Client.DoAction(ctx, &flight.Action{Type: "HealthCheck", Body: body}) if err != nil { return nil, fmt.Errorf("health check action: %w", err) } @@ -1153,10 +1187,15 @@ func recoverWorkerPanic(err *error) { func (w *ManagedWorker) CreateSession(ctx context.Context, username, memoryLimit string, threads int) (token string, err error) { defer recoverWorkerPanic(&err) - body, _ := json.Marshal(map[string]interface{}{ - "username": username, - "memory_limit": memoryLimit, - "threads": threads, + body, _ := json.Marshal(server.WorkerCreateSessionPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + WorkerID: w.ID, + OwnerEpoch: w.OwnerEpoch(), + CPInstanceID: w.OwnerCPInstanceID(), + }, + Username: username, + MemoryLimit: memoryLimit, + Threads: threads, }) stream, err := w.client.Client.DoAction(ctx, &flight.Action{ @@ -1209,7 +1248,14 @@ func (w *ManagedWorker) ActivateTenant(ctx context.Context, payload server.Worke func (w *ManagedWorker) DestroySession(ctx context.Context, sessionToken string) (err error) { defer recoverWorkerPanic(&err) - body, _ := json.Marshal(map[string]string{"session_token": sessionToken}) + body, _ := json.Marshal(server.WorkerDestroySessionPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + WorkerID: w.ID, + OwnerEpoch: w.OwnerEpoch(), + CPInstanceID: w.OwnerCPInstanceID(), + }, + SessionToken: sessionToken, + }) stream, err := w.client.Client.DoAction(ctx, &flight.Action{ Type: "DestroySession", diff --git a/controlplane/worker_mgr_process_test.go b/controlplane/worker_mgr_process_test.go index 09d391db..310a1ba6 100644 --- a/controlplane/worker_mgr_process_test.go +++ b/controlplane/worker_mgr_process_test.go @@ -1145,12 +1145,10 @@ func TestManagedWorkerSharedStateNormalizesZeroValue(t *testing.T) { } func TestManagedWorkerSetSharedStateClonesAssignment(t *testing.T) { - leaseExpiry := time.Date(2026, time.March, 20, 16, 0, 0, 0, time.UTC) input := SharedWorkerState{ Lifecycle: WorkerLifecycleReserved, Assignment: &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: leaseExpiry, + OrgID: "analytics", }, } @@ -1160,7 +1158,6 @@ func TestManagedWorkerSetSharedStateClonesAssignment(t *testing.T) { } input.Assignment.OrgID = "mutated" - input.Assignment.LeaseExpiresAt = leaseExpiry.Add(time.Hour) got := w.SharedState() if got.Assignment == nil { @@ -1172,9 +1169,6 @@ func TestManagedWorkerSetSharedStateClonesAssignment(t *testing.T) { if got.Assignment.OrgID != "analytics" { t.Fatalf("expected stored org ID analytics, got %q", got.Assignment.OrgID) } - if !got.Assignment.LeaseExpiresAt.Equal(leaseExpiry) { - t.Fatalf("expected stored lease expiry %v, got %v", leaseExpiry, got.Assignment.LeaseExpiresAt) - } got.Assignment.OrgID = "leaked" fresh := w.SharedState() diff --git a/controlplane/worker_pool.go b/controlplane/worker_pool.go index cc3a6ebd..05e0cd46 100644 --- a/controlplane/worker_pool.go +++ b/controlplane/worker_pool.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "time" + + "github.com/posthog/duckgres/controlplane/configstore" ) // WorkerPool abstracts the lifecycle and scheduling of Flight SQL workers. @@ -46,25 +48,36 @@ type WorkerPool interface { // K8sWorkerPoolConfig holds the configuration for creating a K8sWorkerPool. 
type K8sWorkerPoolConfig struct { - Namespace string - CPID string // Control plane pod name, used in labels - WorkerImage string - WorkerPort int - SecretName string // K8s Secret name containing bearer token - ConfigMap string // ConfigMap name for duckgres.yaml - MaxWorkers int - IdleTimeout time.Duration - ConfigPath string // Path inside worker pod where config is mounted - ImagePullPolicy string // Image pull policy for worker pods (e.g., "Never", "IfNotPresent", "Always") - ServiceAccount string // ServiceAccount name for worker pods (default: "default") - WorkerCPURequest string // CPU request for worker pods (e.g., "500m"). Empty = BestEffort. - WorkerMemoryRequest string // Memory request for worker pods (e.g., "1Gi"). Empty = BestEffort. + Namespace string + CPID string // Control plane pod name, used in labels + CPInstanceID string // Durable control-plane instance ID (pod UID + boot ID, K8s label-safe) + WorkerImage string + WorkerPort int + SecretName string // K8s Secret name containing bearer token + ConfigMap string // ConfigMap name for duckgres.yaml + MaxWorkers int + IdleTimeout time.Duration + ConfigPath string // Path inside worker pod where config is mounted + ImagePullPolicy string // Image pull policy for worker pods (e.g., "Never", "IfNotPresent", "Always") + ServiceAccount string // ServiceAccount name for worker pods (default: "default") + WorkerCPURequest string // CPU request for worker pods (e.g., "500m"). Empty = BestEffort. + WorkerMemoryRequest string // Memory request for worker pods (e.g., "1Gi"). Empty = BestEffort. WorkerNodeSelector map[string]string // Node selector for worker pods. Nil = no selector. WorkerTolerationKey string // Taint key for worker pod NoSchedule toleration. Empty = no toleration. WorkerTolerationValue string // Taint value for worker pod NoSchedule toleration. - WorkerExclusiveNode bool // One worker per node via pod anti-affinity. - OrgID string // Org ID for pod labels (multi-tenant mode) - WorkerIDGenerator func() int // Shared ID generator across orgs (nil = internal counter) + WorkerExclusiveNode bool // One worker per node via pod anti-affinity. + OrgID string // Org ID for pod labels (multi-tenant mode) + WorkerIDGenerator func() int // Shared ID generator across orgs (nil = internal counter) + RuntimeStore RuntimeWorkerStore +} + +type RuntimeWorkerStore interface { + UpsertWorkerRecord(record *configstore.WorkerRecord) error + ClaimIdleWorker(ownerCPInstanceID, orgID string, maxOrgWorkers int) (*configstore.WorkerRecord, error) + CreateSpawningWorkerSlot(ownerCPInstanceID, orgID string, ownerEpoch int64, podNamePrefix string, maxOrgWorkers, maxGlobalWorkers int) (*configstore.WorkerRecord, error) + CreateNeutralWarmWorkerSlot(ownerCPInstanceID, podNamePrefix string, targetWarmWorkers, maxGlobalWorkers int) (*configstore.WorkerRecord, error) + GetWorkerRecord(workerID int) (*configstore.WorkerRecord, error) + TakeOverWorker(workerID int, ownerCPInstanceID, orgID string, expectedOwnerEpoch int64) (*configstore.WorkerRecord, error) } // K8sPoolFactory creates a K8sWorkerPool. Registered at init time by the diff --git a/controlplane/worker_state.go b/controlplane/worker_state.go index a9c129b1..eb64c5d0 100644 --- a/controlplane/worker_state.go +++ b/controlplane/worker_state.go @@ -1,9 +1,6 @@ package controlplane -import ( - "fmt" - "time" -) +import "fmt" // WorkerLifecycleState models the shared warm-worker lifecycle introduced for // late tenant binding.
The current production worker pools do not act on this @@ -22,8 +19,8 @@ const ( // WorkerAssignment carries tenant-specific metadata once a shared worker has // been reserved for an org. type WorkerAssignment struct { - OrgID string - LeaseExpiresAt time.Time + OrgID string + MaxWorkers int } // SharedWorkerState holds the additive lifecycle/assignment model for shared @@ -164,9 +161,6 @@ func validateWorkerAssignment(assignment *WorkerAssignment) error { if assignment.OrgID == "" { return fmt.Errorf("missing org ID") } - if assignment.LeaseExpiresAt.IsZero() { - return fmt.Errorf("missing lease expiry") - } return nil } diff --git a/controlplane/worker_state_test.go b/controlplane/worker_state_test.go index 552998df..d8b00779 100644 --- a/controlplane/worker_state_test.go +++ b/controlplane/worker_state_test.go @@ -4,7 +4,6 @@ package controlplane import ( "testing" - "time" ) func TestSharedWorkerStateZeroValueDefaultsToIdle(t *testing.T) { @@ -22,11 +21,8 @@ func TestSharedWorkerStateZeroValueDefaultsToIdle(t *testing.T) { } func TestSharedWorkerStateTransitionLifecycle(t *testing.T) { - leaseExpiry := time.Date(2026, time.March, 20, 16, 0, 0, 0, time.UTC) - state, err := (SharedWorkerState{}).Transition(WorkerLifecycleReserved, &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: leaseExpiry, + OrgID: "analytics", }) if err != nil { t.Fatalf("reserve worker: %v", err) @@ -37,9 +33,6 @@ func TestSharedWorkerStateTransitionLifecycle(t *testing.T) { if state.Assignment == nil || state.Assignment.OrgID != "analytics" { t.Fatalf("expected analytics assignment, got %#v", state.Assignment) } - if !state.Assignment.LeaseExpiresAt.Equal(leaseExpiry) { - t.Fatalf("expected lease expiry %v, got %v", leaseExpiry, state.Assignment.LeaseExpiresAt) - } for _, next := range []WorkerLifecycleState{ WorkerLifecycleActivating, @@ -66,24 +59,20 @@ func TestSharedWorkerStateTransitionRejectsMissingOrInvalidAssignment(t *testing t.Fatal("expected reserve transition without assignment to fail") } - if _, err := (SharedWorkerState{}).Transition(WorkerLifecycleReserved, &WorkerAssignment{ - LeaseExpiresAt: time.Now().Add(time.Hour), - }); err == nil { + if _, err := (SharedWorkerState{}).Transition(WorkerLifecycleReserved, &WorkerAssignment{}); err == nil { t.Fatal("expected reserve transition without org ID to fail") } if _, err := (SharedWorkerState{}).Transition(WorkerLifecycleReserved, &WorkerAssignment{ OrgID: "analytics", - }); err == nil { - t.Fatal("expected reserve transition without lease expiry to fail") + }); err != nil { + t.Fatalf("expected reserve transition with org ID only to succeed: %v", err) } } func TestSharedWorkerStateTransitionRejectsInvalidLifecycleMoves(t *testing.T) { - leaseExpiry := time.Date(2026, time.March, 20, 16, 0, 0, 0, time.UTC) state, err := (SharedWorkerState{}).Transition(WorkerLifecycleReserved, &WorkerAssignment{ - OrgID: "analytics", - LeaseExpiresAt: leaseExpiry, + OrgID: "analytics", }) if err != nil { t.Fatalf("reserve worker: %v", err) @@ -99,8 +88,7 @@ func TestSharedWorkerStateTransitionRejectsInvalidLifecycleMoves(t *testing.T) { } if _, err := state.Transition(WorkerLifecycleHot, &WorkerAssignment{ - OrgID: "billing", - LeaseExpiresAt: leaseExpiry.Add(time.Hour), + OrgID: "billing", }); err == nil { t.Fatal("expected activating -> hot transition to reject assignment changes") } diff --git a/docs/runbooks/control-plane-rollout.md b/docs/runbooks/control-plane-rollout.md new file mode 100644 index 00000000..cb552887 --- /dev/null +++ 
b/docs/runbooks/control-plane-rollout.md @@ -0,0 +1,69 @@ +# Runbook: Control-Plane Rolling Rollout + +## Goal + +Replace Duckgres control-plane replicas during a planned deployment while keeping most existing sessions alive. + +## Requirements + +- The Deployment uses rolling replacement with overlap: + - `maxUnavailable: 0` + - `maxSurge: 1` +- `terminationGracePeriodSeconds` is at least the configured drain timeout +- The control plane sets `--handover-drain-timeout 15m` (or another explicit value appropriate for the cluster) + +## Expected behavior + +1. The old replica receives `SIGTERM`. +2. It marks itself `draining` in runtime state and fails `/health`. +3. New pgwire sessions are rejected on the draining replica. +4. New Flight bootstrap sessions are rejected on the draining replica. +5. Existing pgwire connections and existing Flight sessions continue until they finish or the drain timeout expires. +6. When the timeout expires, the replica force-shuts down remaining sessions and workers. + +Unplanned control-plane failure is different: + +- live pgwire connections are lost immediately +- durable Flight reconnect may recover only when the worker survives and the session token is still valid + +## Rollout procedure + +1. Start the rollout. + ```bash + kubectl -n duckgres rollout restart deploy/duckgres-control-plane + kubectl -n duckgres rollout status deploy/duckgres-control-plane + ``` + +2. Watch old and new pods during overlap. + ```bash + kubectl -n duckgres get pods -l app=duckgres-control-plane -w + ``` + +3. Verify the old pod becomes unready before it exits. + ```bash + kubectl -n duckgres get pods -l app=duckgres-control-plane + kubectl -n duckgres logs <old-pod> + ``` + +4. Confirm the new pod is serving traffic and warm capacity recovers. + - `duckgres_warm_workers` returns to target + - `duckgres_hot_workers` does not drop unexpectedly + - client reconnect errors do not spike + +## If a rollout stalls + +- Check whether the old pod is still draining active sessions: + ```bash + kubectl -n duckgres logs <old-pod> | rg "drain|draining|shutdown" + ``` +- Check whether the pod termination grace period is shorter than the configured drain timeout. +- Check whether long-lived idle clients are holding pgwire connections open. +- Check whether Flight sessions are still active and not timing out. + +## If the timeout is too short + +- Increase both: + - `--handover-drain-timeout` + - `terminationGracePeriodSeconds` + +Keep those values aligned. If the pod is killed before the drain timeout elapses, Kubernetes will cut the drain short. diff --git a/docs/runbooks/replenish-capacity.md b/docs/runbooks/replenish-capacity.md index ebdb7ba1..db69f9a3 100644 --- a/docs/runbooks/replenish-capacity.md +++ b/docs/runbooks/replenish-capacity.md @@ -3,7 +3,7 @@ ## When to use - `duckgres_warm_workers` is 0 and sessions are queuing -- After a mass retirement event (rolling update, crash storm) +- After a mass retirement event (planned control-plane rollout, crash storm) - Scaling up for anticipated traffic ## Background diff --git a/duckdbservice/activation.go b/duckdbservice/activation.go index c08cb9a6..807e6e46 100644 --- a/duckdbservice/activation.go +++ b/duckdbservice/activation.go @@ -6,7 +6,6 @@ import ( "os" "reflect" "strings" - "time" "github.com/posthog/duckgres/server" ) @@ -14,9 +13,9 @@ import ( // ActivationPayload carries the tenant-specific runtime that is delivered to a // neutral shared warm worker over the control-plane RPC channel.
type ActivationPayload struct { - OrgID string `json:"org_id"` - LeaseExpiresAt time.Time `json:"lease_expires_at"` - DuckLake server.DuckLakeConfig `json:"ducklake"` + server.WorkerControlMetadata + OrgID string `json:"org_id"` + DuckLake server.DuckLakeConfig `json:"ducklake"` } type activatedTenantRuntime struct { @@ -41,16 +40,41 @@ func (p *SessionPool) activateTenant(payload ActivationPayload) error { if strings.TrimSpace(payload.OrgID) == "" { return fmt.Errorf("org_id is required") } + if payload.OwnerEpoch < 0 { + return fmt.Errorf("owner_epoch must be non-negative") + } p.mu.RLock() current := p.activation + currentOwnerEpoch := p.ownerEpoch + currentOwnerCPInstanceID := p.ownerCPInstanceID + currentWorkerID := p.workerID p.mu.RUnlock() + if currentWorkerID > 0 && payload.WorkerID != currentWorkerID { + return fmt.Errorf("stale worker_id %d (current %d)", payload.WorkerID, currentWorkerID) + } if current != nil { + if !sameTenantActivationRuntime(current.payload, payload) { + return fmt.Errorf("worker already activated for org %q", current.payload.OrgID) + } if reflect.DeepEqual(current.payload, payload) { return nil } + if payload.OwnerEpoch <= currentOwnerEpoch { + return fmt.Errorf("same-tenant takeover requires newer owner epoch %d (current %d)", payload.OwnerEpoch, currentOwnerEpoch) + } + if p.reuseExistingActivation(payload) { + return nil + } return fmt.Errorf("worker already activated for org %q", current.payload.OrgID) } + if currentOwnerCPInstanceID == "" { + if payload.OwnerEpoch <= currentOwnerEpoch { + return fmt.Errorf("stale owner epoch %d (current %d)", payload.OwnerEpoch, currentOwnerEpoch) + } + } else if payload.OwnerEpoch <= currentOwnerEpoch { + return fmt.Errorf("stale owner epoch %d (current %d)", payload.OwnerEpoch, currentOwnerEpoch) + } cfg := p.cfg cfg.DuckLake = payload.DuckLake @@ -79,8 +103,14 @@ func (p *SessionPool) activateTenant(payload ActivationPayload) error { p.mu.Lock() defer p.mu.Unlock() + if p.workerID > 0 && payload.WorkerID != p.workerID { + return fmt.Errorf("stale worker_id %d (current %d)", payload.WorkerID, p.workerID) + } + if payload.OwnerEpoch <= p.ownerEpoch { + return fmt.Errorf("stale owner epoch %d (current %d)", payload.OwnerEpoch, p.ownerEpoch) + } if p.activation != nil { - if reflect.DeepEqual(p.activation.payload, payload) { + if sameTenantActivationRuntime(p.activation.payload, payload) && reflect.DeepEqual(p.activation.payload, payload) { return nil } return fmt.Errorf("worker already activated for org %q", p.activation.payload.OrgID) @@ -90,6 +120,65 @@ func (p *SessionPool) activateTenant(payload ActivationPayload) error { payload: payload, db: db, } + p.ownerEpoch = payload.OwnerEpoch + p.ownerCPInstanceID = payload.CPInstanceID + p.workerID = payload.WorkerID + return nil +} + +func (p *SessionPool) reuseExistingActivation(payload ActivationPayload) bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.reuseExistingActivationLocked(payload) +} + +func (p *SessionPool) reuseExistingActivationLocked(payload ActivationPayload) bool { + if p.activation == nil { + return false + } + current := p.activation.payload + if !sameTenantActivationRuntime(current, payload) { + return false + } + if !reflect.DeepEqual(current, payload) && payload.OwnerEpoch <= current.OwnerEpoch { + return false + } + p.activation.payload = payload + p.ownerEpoch = payload.OwnerEpoch + p.ownerCPInstanceID = payload.CPInstanceID + p.workerID = payload.WorkerID + return true +} + +func sameTenantActivationRuntime(current, next ActivationPayload) 
bool { + return current.OrgID == next.OrgID && reflect.DeepEqual(current.DuckLake, next.DuckLake) +} + +func (p *SessionPool) validateControlMetadata(meta server.WorkerControlMetadata) error { + if !p.sharedWarmMode { + return nil + } + if meta.OwnerEpoch < 0 { + return fmt.Errorf("owner_epoch must be non-negative") + } + + p.mu.RLock() + defer p.mu.RUnlock() + if p.workerID > 0 && meta.WorkerID != 0 && meta.WorkerID != p.workerID { + return fmt.Errorf("stale worker_id %d (current %d)", meta.WorkerID, p.workerID) + } + if p.activation == nil && p.ownerCPInstanceID == "" { + return nil + } + if meta.OwnerEpoch != p.ownerEpoch { + return fmt.Errorf("stale owner epoch %d (current %d)", meta.OwnerEpoch, p.ownerEpoch) + } + if p.ownerCPInstanceID != "" && meta.CPInstanceID != p.ownerCPInstanceID { + return fmt.Errorf("stale cp_instance_id %q (current %q)", meta.CPInstanceID, p.ownerCPInstanceID) + } + if p.workerID > 0 && meta.WorkerID != p.workerID { + return fmt.Errorf("stale worker_id %d (current %d)", meta.WorkerID, p.workerID) + } return nil } diff --git a/duckdbservice/activation_test.go b/duckdbservice/activation_test.go index b5f58c5e..69fd4852 100644 --- a/duckdbservice/activation_test.go +++ b/duckdbservice/activation_test.go @@ -44,6 +44,11 @@ func TestSessionPoolActivateTenantConfiguresTenantRuntime(t *testing.T) { }() err := pool.activateTenant(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 1, + CPInstanceID: "cp-live:boot-a", + WorkerID: 17, + }, OrgID: "analytics", DuckLake: server.DuckLakeConfig{ MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", @@ -86,6 +91,11 @@ func TestSessionPoolActivateTenantRejectsSecondActivation(t *testing.T) { } if err := pool.activateTenant(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 1, + CPInstanceID: "cp-live:boot-a", + WorkerID: 17, + }, OrgID: "analytics", DuckLake: server.DuckLakeConfig{ MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", @@ -95,6 +105,11 @@ func TestSessionPoolActivateTenantRejectsSecondActivation(t *testing.T) { } if err := pool.activateTenant(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 2, + CPInstanceID: "cp-live:boot-a", + WorkerID: 17, + }, OrgID: "billing", DuckLake: server.DuckLakeConfig{ MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", @@ -103,3 +118,185 @@ func TestSessionPoolActivateTenantRejectsSecondActivation(t *testing.T) { t.Fatal("expected second activation to fail") } } + +func TestSessionPoolActivateTenantRejectsStaleOwnerEpoch(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + duckLakeSem: make(chan struct{}, 1), + cfg: server.Config{Users: map[string]string{"postgres": "postgres"}}, + startTime: time.Now(), + warmupDone: make(chan struct{}), + sharedWarmMode: true, + } + close(pool.warmupDone) + + pool.createDBConnection = func(cfg server.Config, sem chan struct{}, username string, startTime time.Time, version string) (*sql.DB, error) { + return sql.Open("duckdb", "") + } + pool.activateDBConnection = func(db *sql.DB, cfg server.Config, sem chan struct{}, username string) error { + return nil + } + + if err := pool.activateTenant(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{OwnerEpoch: 2}, + OrgID: "analytics", + DuckLake: 
server.DuckLakeConfig{ + MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", + }, + }); err != nil { + t.Fatalf("first ActivateTenant: %v", err) + } + + if err := pool.activateTenant(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{OwnerEpoch: 1}, + OrgID: "analytics", + DuckLake: server.DuckLakeConfig{ + MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", + }, + }); err == nil { + t.Fatal("expected stale owner epoch to be rejected") + } +} + +func TestSessionPoolActivateTenantAllowsSameOrgTakeover(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + duckLakeSem: make(chan struct{}, 1), + cfg: server.Config{Users: map[string]string{"postgres": "postgres"}}, + startTime: time.Now(), + warmupDone: make(chan struct{}), + sharedWarmMode: true, + } + close(pool.warmupDone) + + var activateCalls int + pool.createDBConnection = func(cfg server.Config, sem chan struct{}, username string, startTime time.Time, version string) (*sql.DB, error) { + return sql.Open("duckdb", "") + } + pool.activateDBConnection = func(db *sql.DB, cfg server.Config, sem chan struct{}, username string) error { + activateCalls++ + return nil + } + + first := ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 2, + CPInstanceID: "cp-old:boot-a", + WorkerID: 17, + }, + OrgID: "analytics", + DuckLake: server.DuckLakeConfig{ + MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", + ObjectStore: "s3://analytics/warehouse/", + }, + } + if err := pool.activateTenant(first); err != nil { + t.Fatalf("first ActivateTenant: %v", err) + } + + second := first + second.OwnerEpoch = 3 + second.CPInstanceID = "cp-new:boot-b" + if err := pool.activateTenant(second); err != nil { + t.Fatalf("takeover ActivateTenant: %v", err) + } + + if activateCalls != 1 { + t.Fatalf("expected same-tenant takeover to reuse existing activation, got %d activation calls", activateCalls) + } + current := pool.currentActivation() + if current == nil { + t.Fatal("expected activation to remain present") + } + if current.payload.OwnerEpoch != 3 { + t.Fatalf("expected owner epoch 3, got %d", current.payload.OwnerEpoch) + } + if current.payload.CPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected cp instance id cp-new:boot-b, got %q", current.payload.CPInstanceID) + } + if pool.ownerEpoch != 3 { + t.Fatalf("expected pool owner epoch 3, got %d", pool.ownerEpoch) + } + if pool.ownerCPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected pool owner cp instance id cp-new:boot-b, got %q", pool.ownerCPInstanceID) + } +} + +func TestSessionPoolActivateTenantRejectsSameEpochOwnerChange(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + duckLakeSem: make(chan struct{}, 1), + cfg: server.Config{Users: map[string]string{"postgres": "postgres"}}, + startTime: time.Now(), + warmupDone: make(chan struct{}), + sharedWarmMode: true, + } + close(pool.warmupDone) + + pool.createDBConnection = func(cfg server.Config, sem chan struct{}, username string, startTime time.Time, version string) (*sql.DB, error) { + return sql.Open("duckdb", "") + } + pool.activateDBConnection = func(db *sql.DB, cfg server.Config, sem chan struct{}, username string) error { + return nil + } + + first := ActivationPayload{ + WorkerControlMetadata: 
server.WorkerControlMetadata{ + OwnerEpoch: 2, + CPInstanceID: "cp-old:boot-a", + WorkerID: 17, + }, + OrgID: "analytics", + DuckLake: server.DuckLakeConfig{ + MetadataStore: "postgres:host=metadata.internal port=5432 user=ducklake password=secret dbname=ducklake", + }, + } + if err := pool.activateTenant(first); err != nil { + t.Fatalf("first ActivateTenant: %v", err) + } + + second := first + second.CPInstanceID = "cp-new:boot-b" + if err := pool.activateTenant(second); err == nil { + t.Fatal("expected same-epoch owner change to be rejected") + } +} + +func TestSessionPoolValidateControlMetadataRejectsMismatchedCPInstanceID(t *testing.T) { + pool := &SessionPool{ + sharedWarmMode: true, + ownerEpoch: 4, + ownerCPInstanceID: "cp-live:boot-a", + workerID: 17, + } + + err := pool.validateControlMetadata(server.WorkerControlMetadata{ + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-other:boot-b", + }) + if err == nil { + t.Fatal("expected mismatched cp_instance_id to be rejected") + } +} + +func TestSessionPoolValidateControlMetadataRejectsMismatchedWorkerID(t *testing.T) { + pool := &SessionPool{ + sharedWarmMode: true, + ownerEpoch: 4, + ownerCPInstanceID: "cp-live:boot-a", + workerID: 17, + } + + err := pool.validateControlMetadata(server.WorkerControlMetadata{ + WorkerID: 18, + OwnerEpoch: 4, + CPInstanceID: "cp-live:boot-a", + }) + if err == nil { + t.Fatal("expected mismatched worker_id to be rejected") + } +} diff --git a/duckdbservice/flight_handler.go b/duckdbservice/flight_handler.go index 7c02b969..8e070974 100644 --- a/duckdbservice/flight_handler.go +++ b/duckdbservice/flight_handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log/slog" + "strconv" "strings" "time" @@ -16,6 +17,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow-go/v18/arrow/memory" bindings "github.com/duckdb/duckdb-go-bindings" + "github.com/posthog/duckgres/server" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -46,6 +48,35 @@ func (h *FlightSQLHandler) sessionFromContext(ctx context.Context) (*Session, er if !ok { return nil, status.Error(codes.Unauthenticated, "session not found") } + if h.pool.sharedWarmMode { + epochs := md.Get("x-duckgres-owner-epoch") + if len(epochs) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing x-duckgres-owner-epoch header") + } + ownerEpoch, err := strconv.ParseInt(epochs[0], 10, 64) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid x-duckgres-owner-epoch header") + } + workerIDs := md.Get("x-duckgres-worker-id") + if len(workerIDs) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing x-duckgres-worker-id header") + } + workerID, err := strconv.Atoi(workerIDs[0]) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid x-duckgres-worker-id header") + } + cpInstanceIDs := md.Get("x-duckgres-cp-instance-id") + if len(cpInstanceIDs) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing x-duckgres-cp-instance-id header") + } + if err := h.pool.validateControlMetadata(server.WorkerControlMetadata{ + WorkerID: workerID, + OwnerEpoch: ownerEpoch, + CPInstanceID: cpInstanceIDs[0], + }); err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + } session.lastUsed.Store(time.Now().UnixNano()) return session, nil @@ -54,17 +85,16 @@ func (h *FlightSQLHandler) sessionFromContext(ctx context.Context) (*Session, er // Custom action handlers 
(called via customActionServer.DoAction) func (h *FlightSQLHandler) doCreateSession(body []byte, stream flight.FlightService_DoActionServer) error { - var req struct { - Username string `json:"username"` - MemoryLimit string `json:"memory_limit"` - Threads int `json:"threads"` - } + var req server.WorkerCreateSessionPayload if err := json.Unmarshal(body, &req); err != nil { return status.Errorf(codes.InvalidArgument, "invalid CreateSession request: %v", err) } if req.Username == "" { return status.Error(codes.InvalidArgument, "username is required") } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } if h.pool.sharedWarmMode { if _, err := h.pool.currentSessionConfig(); err != nil { @@ -111,15 +141,16 @@ func (h *FlightSQLHandler) doActivateTenant(body []byte, stream flight.FlightSer } func (h *FlightSQLHandler) doDestroySession(body []byte, stream flight.FlightService_DoActionServer) error { - var req struct { - SessionToken string `json:"session_token"` - } + var req server.WorkerDestroySessionPayload if err := json.Unmarshal(body, &req); err != nil { return status.Errorf(codes.InvalidArgument, "invalid DestroySession request: %v", err) } if req.SessionToken == "" { return status.Error(codes.InvalidArgument, "session_token is required") } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } if err := h.pool.DestroySession(req.SessionToken); err != nil { return status.Errorf(codes.NotFound, "%v", err) @@ -129,7 +160,15 @@ func (h *FlightSQLHandler) doDestroySession(body []byte, stream flight.FlightSer return stream.Send(&flight.Result{Body: resp}) } -func (h *FlightSQLHandler) doHealthCheck(stream flight.FlightService_DoActionServer) error { +func (h *FlightSQLHandler) doHealthCheck(body []byte, stream flight.FlightService_DoActionServer) error { + var req server.WorkerHealthCheckPayload + if err := json.Unmarshal(body, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid HealthCheck request: %v", err) + } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + // Block until warmup (extension loading + DuckLake attachment) completes. 
// Without this, the control plane's waitForWorkerTCP health check passes // as soon as the gRPC server starts, and clients get routed to a worker diff --git a/duckdbservice/flight_handler_test.go b/duckdbservice/flight_handler_test.go index 5bf1bc1e..b3a37a59 100644 --- a/duckdbservice/flight_handler_test.go +++ b/duckdbservice/flight_handler_test.go @@ -1,6 +1,7 @@ package duckdbservice import ( + "context" "database/sql" "encoding/json" "testing" @@ -42,7 +43,7 @@ func TestHealthCheckBlocksUntilWarmup(t *testing.T) { // Health check in a goroutine — should block until warmup completes done := make(chan error, 1) go func() { - done <- handler.doHealthCheck(stream) + done <- handler.doHealthCheck([]byte(`{}`), stream) }() // Verify it hasn't returned after 100ms @@ -94,7 +95,7 @@ func TestHealthCheckReturnsImmediatelyAfterWarmup(t *testing.T) { done := make(chan error, 1) go func() { - done <- handler.doHealthCheck(stream) + done <- handler.doHealthCheck([]byte(`{}`), stream) }() select { @@ -107,6 +108,28 @@ func TestHealthCheckReturnsImmediatelyAfterWarmup(t *testing.T) { } } +func TestHealthCheckRejectsMissingOwnerEpochInSharedWarmMode(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + sharedWarmMode: true, + ownerEpoch: 5, + ownerCPInstanceID: "cp-live:boot-a", + workerID: 17, + } + close(pool.warmupDone) + + handler := &FlightSQLHandler{pool: pool} + stream := &mockDoActionStream{} + + err := handler.doHealthCheck([]byte(`{}`), stream) + if status.Code(err) != codes.FailedPrecondition { + t.Fatalf("expected FailedPrecondition, got %v", err) + } +} + func TestCreateSessionRequiresActivationForSharedWarmMode(t *testing.T) { pool := &SessionPool{ sessions: make(map[string]*Session), @@ -157,6 +180,11 @@ func TestActivateTenantRejectsDifferentTenantAfterActivation(t *testing.T) { stream := &mockDoActionStream{} firstBody, err := json.Marshal(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 1, + CPInstanceID: "cp-live:boot-a", + WorkerID: 17, + }, OrgID: "analytics", }) if err != nil { @@ -167,6 +195,11 @@ func TestActivateTenantRejectsDifferentTenantAfterActivation(t *testing.T) { } secondBody, err := json.Marshal(ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 2, + CPInstanceID: "cp-live:boot-a", + WorkerID: 17, + }, OrgID: "billing", }) if err != nil { @@ -178,3 +211,96 @@ func TestActivateTenantRejectsDifferentTenantAfterActivation(t *testing.T) { t.Fatalf("expected FailedPrecondition, got %v", err) } } + +func TestCreateSessionRejectsStaleOwnerEpochInSharedWarmMode(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + cfg: server.Config{}, + sharedWarmMode: true, + ownerEpoch: 4, + activation: &activatedTenantRuntime{ + payload: ActivationPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{OwnerEpoch: 4}, + OrgID: "analytics", + }, + }, + } + close(pool.warmupDone) + + handler := &FlightSQLHandler{pool: pool} + stream := &mockDoActionStream{} + + body, err := json.Marshal(server.WorkerCreateSessionPayload{ + WorkerControlMetadata: server.WorkerControlMetadata{ + OwnerEpoch: 3, + }, + Username: "alice", + }) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + + err = handler.doCreateSession(body, stream) + if status.Code(err) != 
codes.FailedPrecondition { + t.Fatalf("expected FailedPrecondition, got %v", err) + } +} + +func TestSessionFromContextRejectsStaleOwnerEpoch(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + sharedWarmMode: true, + ownerEpoch: 5, + ownerCPInstanceID: "cp-live:boot-a", + workerID: 17, + } + close(pool.warmupDone) + pool.sessions["session-1"] = &Session{ID: "session-1"} + + handler := &FlightSQLHandler{pool: pool} + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( + "x-duckgres-session", "session-1", + "x-duckgres-owner-epoch", "4", + "x-duckgres-worker-id", "17", + "x-duckgres-cp-instance-id", "cp-live:boot-a", + )) + + _, err := handler.sessionFromContext(ctx) + if status.Code(err) != codes.FailedPrecondition { + t.Fatalf("expected FailedPrecondition, got %v", err) + } +} + +func TestSessionFromContextRejectsMismatchedControlIdentity(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + sharedWarmMode: true, + ownerEpoch: 5, + ownerCPInstanceID: "cp-live:boot-a", + workerID: 17, + } + close(pool.warmupDone) + pool.sessions["session-1"] = &Session{ID: "session-1"} + + handler := &FlightSQLHandler{pool: pool} + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( + "x-duckgres-session", "session-1", + "x-duckgres-owner-epoch", "5", + "x-duckgres-worker-id", "18", + "x-duckgres-cp-instance-id", "cp-other:boot-b", + )) + + _, err := handler.sessionFromContext(ctx) + if status.Code(err) != codes.FailedPrecondition { + t.Fatalf("expected FailedPrecondition, got %v", err) + } +} diff --git a/duckdbservice/service.go b/duckdbservice/service.go index 5b58ab96..ffb73df2 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -49,6 +49,9 @@ type SessionPool struct { sharedWarmMode bool activation *activatedTenantRuntime + ownerEpoch int64 + ownerCPInstanceID string + workerID int activateTenantFunc func(ActivationPayload) error createDBConnection func(server.Config, chan struct{}, string, time.Time, string) (*sql.DB, error) activateDBConnection func(*sql.DB, server.Config, chan struct{}, string) error @@ -636,7 +639,7 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe case "DestroySession": return s.handler.doDestroySession(cmd.Body, stream) case "HealthCheck": - return s.handler.doHealthCheck(stream) + return s.handler.doHealthCheck(cmd.Body, stream) default: // Fall through to standard flightsql action router (BeginTransaction, etc.) 
return s.FlightServer.DoAction(cmd, stream) diff --git a/justfile b/justfile index 5b0839a8..cc4a7453 100644 --- a/justfile +++ b/justfile @@ -50,7 +50,11 @@ build-k8s-image tag="duckgres:test": [group('dev')] kind-cluster-reset: kind delete cluster --name "${DUCKGRES_KIND_CLUSTER_NAME:-duckgres}" || true - kind create cluster --name "${DUCKGRES_KIND_CLUSTER_NAME:-duckgres}" --wait 120s + if [ -n "${DUCKGRES_KIND_NODE_IMAGE:-}" ]; then \ + kind create cluster --name "${DUCKGRES_KIND_CLUSTER_NAME:-duckgres}" --image "${DUCKGRES_KIND_NODE_IMAGE}" --wait 120s; \ + else \ + kind create cluster --name "${DUCKGRES_KIND_CLUSTER_NAME:-duckgres}" --wait 120s; \ + fi kind export kubeconfig --name "${DUCKGRES_KIND_CLUSTER_NAME:-duckgres}" --kubeconfig "${DUCKGRES_KIND_KUBECONFIG:-/tmp/duckgres-kind-kubeconfig}" # Delete the local kind cluster used by the portable K8s integration flow diff --git a/k8s/README.md b/k8s/README.md index bb91be1f..dee83316 100644 --- a/k8s/README.md +++ b/k8s/README.md @@ -26,15 +26,17 @@ This directory contains **development/reference manifests** for running duckgres │ │ Bearer auth │ │ Bearer auth │ │ Bearer auth │ │ │ └──────────────┘ └──────────────┘ └──────────────┘ │ │ │ +│ Runtime coordination lives in config-store Postgres │ +│ (`cp_instances`, `worker_records`, Flight sessions) │ +│ │ │ Worker pods have: │ -│ - Owner references → CP pod (GC on CP deletion) │ │ - SecurityContext: non-root (UID 1000) │ │ - Bearer token from K8s Secret │ │ - ConfigMap mount for shared config │ └─────────────────────────────────────────────────────┘ ``` -The control plane handles TLS, authentication, PostgreSQL wire protocol, and SQL transpilation. Workers are thin DuckDB execution engines exposed via Arrow Flight SQL. Workers are spawned on demand and reaped when idle. +The control plane handles TLS, authentication, PostgreSQL wire protocol, and SQL transpilation. Workers are thin DuckDB execution engines exposed via Arrow Flight SQL. Workers are spawned on demand and reaped when idle. Planned rolling replacements mark old replicas draining and fail readiness before termination; unplanned control-plane failure still drops existing pgwire connections. 
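The drain sequence described above maps onto the new hooks on the Flight ingress. A minimal sketch of the planned-shutdown path, assuming the signal wiring and the `FlightIngress` value come from elsewhere in the control plane (`onSigterm` itself is hypothetical; `BeginDrain`, `WaitForZeroSessions`, and `Shutdown` are the methods this change adds):

```go
package example

import (
	"context"
	"log/slog"
	"time"

	"github.com/posthog/duckgres/server/flightsqlingress"
)

// onSigterm is an illustrative handler showing the intended drain order.
func onSigterm(fi *flightsqlingress.FlightIngress, drainTimeout time.Duration) {
	// Stop accepting new Flight bootstrap sessions; the replica also
	// marks itself draining and fails readiness at this point.
	fi.BeginDrain()

	ctx, cancel := context.WithTimeout(context.Background(), drainTimeout)
	defer cancel()

	// Block until every Flight session finishes or the drain window closes.
	if !fi.WaitForZeroSessions(ctx) {
		slog.Warn("Drain window elapsed with sessions still active; forcing shutdown.")
	}

	// Force-close whatever remains.
	fi.Shutdown()
}
```

Keeping `terminationGracePeriodSeconds` at least as long as the drain timeout (as the manifests below do) ensures Kubernetes does not SIGKILL the pod mid-drain.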
## Manifests @@ -47,7 +49,6 @@ The control plane handles TLS, authentication, PostgreSQL wire protocol, and SQL | `managed-warehouse-secrets.yaml` | Local secret payloads referenced by the seeded managed-warehouse contract | | `worker-identity.yaml` | Local worker ServiceAccount referenced by the seeded managed-warehouse contract | | `networkpolicy.yaml` | Restricts worker ingress to CP pods only | -| `control-plane-deployment.yaml` | Local multitenant CP Deployment + ClusterIP Service | | `control-plane-multitenant-local.yaml` | Optional OrbStack-oriented shared warm-worker control-plane manifest | | `kind/config-store.overlay.yaml` | Compose overlay that attaches local dependency containers to the external Docker `kind` network | | `kind/config-store.seed.sql` | Kind-oriented managed-warehouse seed for the shared warm-worker flow | @@ -62,6 +63,7 @@ Key flags for Kubernetes multitenant mode: |------|---------|-------------| | `--worker-backend remote` | - | Use K8s remote workers in config-store-backed multitenant mode | | `--config-store` | `DUCKGRES_CONFIG_STORE` | PostgreSQL config-store connection string required for remote mode | +| `--handover-drain-timeout` | `DUCKGRES_HANDOVER_DRAIN_TIMEOUT` | Max time to drain planned shutdowns/upgrades before forced exit (`15m` default in remote mode) | | `--k8s-worker-image` | `DUCKGRES_K8S_WORKER_IMAGE` | Docker image for worker pods | | `--k8s-worker-image-pull-policy` | `DUCKGRES_K8S_WORKER_IMAGE_PULL_POLICY` | Image pull policy (`Never`, `IfNotPresent`, `Always`) | | `--k8s-worker-secret` | `DUCKGRES_K8S_WORKER_SECRET` | K8s Secret name for bearer token | @@ -70,6 +72,15 @@ Key flags for Kubernetes multitenant mode: The bearer token secret is used to authenticate gRPC connections between the control plane and workers. If the secret exists but is empty, the CP auto-generates a random token and populates it. +For seamless planned deployments, use a rolling strategy with overlap and enough termination grace period for drain completion. The provided control-plane manifests now set: + +- `strategy.rollingUpdate.maxUnavailable: 0` +- `strategy.rollingUpdate.maxSurge: 1` +- `terminationGracePeriodSeconds: 900` +- `--handover-drain-timeout 15m` + +That gives the old replica time to fail readiness, stop taking new pgwire sessions, keep existing pgwire and Flight sessions alive during the drain window, and then force shutdown at the timeout boundary if sessions remain. + ## Local Development with kind The primary shared warm-worker workflow now uses [`kind`](https://kind.sigs.k8s.io/). Prerequisites: Docker, `kubectl`, `kind`, and `just`. 
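One note on the gRPC path before the manifest changes below: the bearer token is the transport credential, but ownership fencing rides on per-call headers that every control-plane call to a shared warm worker now carries and the worker validates (`x-duckgres-worker-id`, `x-duckgres-cp-instance-id`, `x-duckgres-owner-epoch`). A rough sketch of how a caller attaches them, mirroring the executor's `withSession` in this change (the standalone helper is illustrative, not the real API):

```go
package example

import (
	"context"
	"strconv"

	"google.golang.org/grpc/metadata"
)

// attachControlHeaders is a hypothetical stand-in for FlightExecutor.withSession;
// the header names match the ones the worker-side handler validates.
func attachControlHeaders(ctx context.Context, sessionToken string, workerID int, cpInstanceID string, ownerEpoch int64) context.Context {
	return metadata.AppendToOutgoingContext(ctx,
		"x-duckgres-session", sessionToken,
		"x-duckgres-worker-id", strconv.Itoa(workerID),
		"x-duckgres-cp-instance-id", cpInstanceID,
		"x-duckgres-owner-epoch", strconv.FormatInt(ownerEpoch, 10),
	)
}
```

A worker that has been taken over by a newer control-plane instance rejects calls carrying a stale epoch with `FailedPrecondition`, so an old replica fails fast instead of driving a worker it no longer owns.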
diff --git a/k8s/control-plane-deployment.yaml b/k8s/control-plane-deployment.yaml deleted file mode 100644 index e2ab3c44..00000000 --- a/k8s/control-plane-deployment.yaml +++ /dev/null @@ -1,107 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: duckgres-control-plane - namespace: duckgres - labels: - app: duckgres-control-plane -spec: - replicas: 1 - selector: - matchLabels: - app: duckgres-control-plane - template: - metadata: - labels: - app: duckgres-control-plane - spec: - serviceAccountName: duckgres-control-plane - securityContext: - runAsNonRoot: true - runAsUser: 1000 - containers: - - name: control-plane - # Replace with your duckgres image built with -tags kubernetes - image: duckgres:latest - imagePullPolicy: IfNotPresent - args: - - "--mode" - - "control-plane" - - "--worker-backend" - - "remote" - - "--config-store" - - "postgres://duckgres:duckgres@host.docker.internal:5434/duckgres_config?sslmode=disable" - - "--config-poll-interval" - - "2s" - - "--k8s-worker-image" - - "duckgres:latest" - - "--k8s-worker-image-pull-policy" - - "Never" - - "--k8s-worker-secret" - - "duckgres-worker-token" - - "--k8s-worker-configmap" - - "duckgres-config" - - "--k8s-shared-warm-target" - - "1" - - "--cert" - - "/certs/server.crt" - - "--key" - - "/certs/server.key" - - "--config" - - "/etc/duckgres/duckgres.yaml" - ports: - - name: pg - containerPort: 5432 - protocol: TCP - - name: admin - containerPort: 9090 - protocol: TCP - readinessProbe: - httpGet: - path: /health - port: admin - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 15 - volumeMounts: - - name: config - mountPath: /etc/duckgres - readOnly: true - - name: certs - mountPath: /certs - - name: data - mountPath: /data - securityContext: - allowPrivilegeEscalation: false - resources: - requests: - cpu: "100m" - memory: "128Mi" - limits: - cpu: "500m" - memory: "256Mi" - volumes: - - name: config - configMap: - name: duckgres-config - - name: certs - emptyDir: {} - - name: data - emptyDir: {} ---- -apiVersion: v1 -kind: Service -metadata: - name: duckgres - namespace: duckgres - labels: - app: duckgres-control-plane -spec: - type: ClusterIP - ports: - - name: pg - port: 5432 - targetPort: pg - protocol: TCP - selector: - app: duckgres-control-plane diff --git a/k8s/control-plane-multitenant-local.yaml b/k8s/control-plane-multitenant-local.yaml index 16f4cee0..1954707f 100644 --- a/k8s/control-plane-multitenant-local.yaml +++ b/k8s/control-plane-multitenant-local.yaml @@ -7,6 +7,11 @@ metadata: app: duckgres-control-plane spec: replicas: 1 + strategy: + type: RollingUpdate + rollingUpdate: + maxUnavailable: 0 + maxSurge: 1 selector: matchLabels: app: duckgres-control-plane @@ -15,6 +20,7 @@ spec: labels: app: duckgres-control-plane spec: + terminationGracePeriodSeconds: 900 serviceAccountName: duckgres-control-plane securityContext: runAsNonRoot: true @@ -23,6 +29,15 @@ spec: - name: control-plane image: duckgres:test imagePullPolicy: IfNotPresent + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid args: - "--mode" - "control-plane" @@ -32,6 +47,8 @@ spec: - "postgres://duckgres:duckgres@host.docker.internal:5434/duckgres_config?sslmode=disable" - "--config-poll-interval" - "2s" + - "--handover-drain-timeout" + - "15m" - "--k8s-worker-image" - "duckgres:test" - "--k8s-worker-image-pull-policy" diff --git a/k8s/kind/control-plane.yaml b/k8s/kind/control-plane.yaml index 2a1334c8..2d393eda 100644 --- 
a/k8s/kind/control-plane.yaml +++ b/k8s/kind/control-plane.yaml @@ -7,6 +7,11 @@ metadata: app: duckgres-control-plane spec: replicas: 1 + strategy: + type: RollingUpdate + rollingUpdate: + maxUnavailable: 0 + maxSurge: 1 selector: matchLabels: app: duckgres-control-plane @@ -15,6 +20,7 @@ spec: labels: app: duckgres-control-plane spec: + terminationGracePeriodSeconds: 900 serviceAccountName: duckgres-control-plane securityContext: runAsNonRoot: true @@ -23,6 +29,15 @@ spec: - name: control-plane image: duckgres:test imagePullPolicy: IfNotPresent + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid args: - "--mode" - "control-plane" @@ -32,6 +47,8 @@ spec: - "postgres://duckgres:duckgres@duckgres-config-store:5432/duckgres_config?sslmode=disable" - "--config-poll-interval" - "2s" + - "--handover-drain-timeout" + - "15m" - "--k8s-worker-image" - "duckgres:test" - "--k8s-worker-image-pull-policy" diff --git a/k8s/rbac.yaml b/k8s/rbac.yaml index 01be55d9..a9485697 100644 --- a/k8s/rbac.yaml +++ b/k8s/rbac.yaml @@ -14,6 +14,10 @@ rules: - apiGroups: [""] resources: ["pods"] verbs: ["create", "delete", "get", "list", "watch"] + # Coordinate leader election for janitor work + - apiGroups: ["coordination.k8s.io"] + resources: ["leases"] + verbs: ["create", "delete", "get", "list", "patch", "update", "watch"] # Read bearer token secret; create if auto-generating; update if key missing - apiGroups: [""] resources: ["secrets"] diff --git a/k8s_manifest_test.go b/k8s_manifest_test.go new file mode 100644 index 00000000..12137257 --- /dev/null +++ b/k8s_manifest_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "gopkg.in/yaml.v3" +) + +type deploymentManifest struct { + Kind string `yaml:"kind"` + Spec struct { + Template struct { + Spec struct { + Containers []struct { + Name string `yaml:"name"` + ReadinessProbe struct { + HTTPGet struct { + Path string `yaml:"path"` + Port any `yaml:"port"` + } `yaml:"httpGet"` + } `yaml:"readinessProbe"` + } `yaml:"containers"` + } `yaml:"spec"` + } `yaml:"template"` + } `yaml:"spec"` +} + +func TestLiveControlPlaneManifestsReadinessProbeTargetsAPIHealthEndpoint(t *testing.T) { + paths := []string{ + filepath.Join("k8s", "control-plane-multitenant-local.yaml"), + filepath.Join("k8s", "kind", "control-plane.yaml"), + } + + for _, path := range paths { + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%s): %v", path, err) + } + + var doc deploymentManifest + if err := yaml.Unmarshal(data, &doc); err != nil { + t.Fatalf("Unmarshal(%s): %v", path, err) + } + if doc.Kind != "Deployment" { + t.Fatalf("%s: expected first manifest document to be a Deployment, got %q", path, doc.Kind) + } + if len(doc.Spec.Template.Spec.Containers) == 0 { + t.Fatalf("%s: expected at least one container in deployment manifest", path) + } + probe := doc.Spec.Template.Spec.Containers[0].ReadinessProbe.HTTPGet + if probe.Path != "/health" { + t.Fatalf("%s: expected readiness probe path /health, got %q", path, probe.Path) + } + if port, ok := probe.Port.(string); !ok || port != "api" { + t.Fatalf("%s: expected readiness probe port api, got %#v", path, probe.Port) + } + } +} diff --git a/main.go b/main.go index 2e07d82b..af7c48f5 100644 --- a/main.go +++ b/main.go @@ -223,7 +223,7 @@ func main() { processMaxWorkers := flag.Int("process-max-workers", 0, "Max process workers, 0=auto-derived (control-plane mode) (env: 
DUCKGRES_PROCESS_MAX_WORKERS)") workerQueueTimeout := flag.String("worker-queue-timeout", "", "How long to wait for an available worker slot (e.g., '5m') (env: DUCKGRES_WORKER_QUEUE_TIMEOUT)") workerIdleTimeout := flag.String("worker-idle-timeout", "", "How long to keep an idle worker alive (e.g., '5m') (env: DUCKGRES_WORKER_IDLE_TIMEOUT)") - handoverDrainTimeout := flag.String("handover-drain-timeout", "", "How long to wait for connections to drain during handover (default: '24h') (env: DUCKGRES_HANDOVER_DRAIN_TIMEOUT)") + handoverDrainTimeout := flag.String("handover-drain-timeout", "", "How long to wait for planned shutdowns/upgrades to drain before forcing exit (default: '24h' in process mode, '15m' in remote mode) (env: DUCKGRES_HANDOVER_DRAIN_TIMEOUT)") socketDir := flag.String("socket-dir", "/var/run/duckgres", "Unix socket directory (control-plane mode)") workerBackend := flag.String("worker-backend", "", "Worker backend: process (default) or remote for config-store-backed K8s multitenant mode (env: DUCKGRES_WORKER_BACKEND)") k8sWorkerImage := flag.String("k8s-worker-image", "", "Container image for K8s worker pods (env: DUCKGRES_K8S_WORKER_IMAGE)") @@ -285,7 +285,7 @@ func main() { fmt.Fprintf(os.Stderr, " DUCKGRES_PROCESS_MIN_WORKERS Pre-warm worker count for process workers (control-plane mode)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_PROCESS_MAX_WORKERS Max process workers (control-plane mode)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_WORKER_QUEUE_TIMEOUT Worker queue timeout (default: 5m)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_HANDOVER_DRAIN_TIMEOUT Handover drain timeout (default: 24h)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_HANDOVER_DRAIN_TIMEOUT Planned shutdown/upgrade drain timeout (default: 24h in process mode, 15m in remote mode)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_DOMAIN Domain for ACME/Let's Encrypt certificate\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_EMAIL Contact email for Let's Encrypt notifications\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_CACHE_DIR Directory for ACME certificate cache\n") diff --git a/server/flight_executor.go b/server/flight_executor.go index 5c8bf900..e1a7157f 100644 --- a/server/flight_executor.go +++ b/server/flight_executor.go @@ -11,6 +11,7 @@ import ( "sync" "sync/atomic" "time" + "strconv" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" @@ -44,6 +45,9 @@ type OrderedMapValue struct { type FlightExecutor struct { client *flightsql.Client sessionToken string + workerID int + cpInstanceID string + ownerEpoch int64 alloc memory.Allocator ownsClient bool // if true, Close() closes the client @@ -80,6 +84,7 @@ func NewFlightExecutor(addr, bearerToken, sessionToken string) (*FlightExecutor, return &FlightExecutor{ client: client, sessionToken: sessionToken, + ownerEpoch: 0, alloc: memory.DefaultAllocator, ownsClient: true, ctx: ctx, @@ -95,6 +100,7 @@ func NewFlightExecutorFromClient(client *flightsql.Client, sessionToken string) return &FlightExecutor{ client: client, sessionToken: sessionToken, + ownerEpoch: 0, alloc: memory.DefaultAllocator, ownsClient: false, ctx: ctx, @@ -115,7 +121,23 @@ func (e *FlightExecutor) IsDead() bool { // withSession adds the session token to the gRPC context. 
func (e *FlightExecutor) withSession(ctx context.Context) context.Context { - return metadata.AppendToOutgoingContext(ctx, "x-duckgres-session", e.sessionToken) + return metadata.AppendToOutgoingContext( + ctx, + "x-duckgres-session", e.sessionToken, + "x-duckgres-worker-id", strconv.Itoa(e.workerID), + "x-duckgres-cp-instance-id", e.cpInstanceID, + "x-duckgres-owner-epoch", strconv.FormatInt(e.ownerEpoch, 10), + ) +} + +func (e *FlightExecutor) SetOwnerEpoch(ownerEpoch int64) { + e.ownerEpoch = ownerEpoch +} + +func (e *FlightExecutor) SetControlMetadata(workerID int, cpInstanceID string, ownerEpoch int64) { + e.workerID = workerID + e.cpInstanceID = cpInstanceID + e.ownerEpoch = ownerEpoch } // recoverClientPanic converts a nil-pointer panic from a closed Flight SQL diff --git a/server/flight_executor_test.go b/server/flight_executor_test.go index 8f3a2339..a60ddeae 100644 --- a/server/flight_executor_test.go +++ b/server/flight_executor_test.go @@ -1,97 +1,32 @@ package server import ( - "encoding/hex" + "context" "testing" - "github.com/apache/arrow-go/v18/arrow" + "google.golang.org/grpc/metadata" ) -func TestEmptyRowSetColumnTypes(t *testing.T) { - // Before fix: emptyRowSet returns nil ColumnTypes - e := &emptyRowSet{} - cols, _ := e.Columns() - colTypes, _ := e.ColumnTypes() +func TestFlightExecutorWithSessionAddsOwnerEpochHeader(t *testing.T) { + exec := NewFlightExecutorFromClient(nil, "session-1") + exec.SetOwnerEpoch(7) + exec.SetControlMetadata(17, "cp-live:boot-a", 7) - if cols != nil { - t.Errorf("emptyRowSet.Columns() = %v, want nil", cols) + ctx := exec.withSession(context.Background()) + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("expected outgoing metadata") } - if colTypes != nil { - t.Errorf("emptyRowSet.ColumnTypes() = %v, want nil", colTypes) + if got := md.Get("x-duckgres-session"); len(got) != 1 || got[0] != "session-1" { + t.Fatalf("unexpected session metadata: %#v", got) } - - // Simulate BLOB detection with nil colTypes (before fix) - var blobColIndices []int - for i, ct := range colTypes { - if ct.DatabaseTypeName() == "BLOB" { - blobColIndices = append(blobColIndices, i) - } - } - if len(blobColIndices) > 0 { - t.Error("should not detect BLOB with nil colTypes") - } - t.Log("BEFORE fix: emptyRowSet returns nil colTypes → BLOB detection finds nothing → falls through to broken CSV path") -} - -func TestEmptySchemaRowSetColumnTypes(t *testing.T) { - // After fix: emptySchemaRowSet preserves schema - schema := arrow.NewSchema([]arrow.Field{ - {Name: "id", Type: arrow.BinaryTypes.String}, - {Name: "data", Type: arrow.BinaryTypes.Binary}, - {Name: "count", Type: arrow.PrimitiveTypes.Int32}, - }, nil) - - e := &emptySchemaRowSet{schema: schema} - cols, err := e.Columns() - if err != nil { - t.Fatal(err) - } - colTypes, err := e.ColumnTypes() - if err != nil { - t.Fatal(err) - } - - // Verify column names - if len(cols) != 3 || cols[0] != "id" || cols[1] != "data" || cols[2] != "count" { - t.Errorf("Columns() = %v, want [id data count]", cols) + if got := md.Get("x-duckgres-owner-epoch"); len(got) != 1 || got[0] != "7" { + t.Fatalf("unexpected owner epoch metadata: %#v", got) } - - // Verify column types - expectedTypes := []string{"VARCHAR", "BLOB", "INTEGER"} - for i, ct := range colTypes { - if ct.DatabaseTypeName() != expectedTypes[i] { - t.Errorf("colTypes[%d].DatabaseTypeName() = %q, want %q", i, ct.DatabaseTypeName(), expectedTypes[i]) - } - } - - // Simulate BLOB detection with real colTypes (after fix) - var blobColIndices []int - for i, 
ct := range colTypes { - if ct.DatabaseTypeName() == "BLOB" { - blobColIndices = append(blobColIndices, i) - } + if got := md.Get("x-duckgres-worker-id"); len(got) != 1 || got[0] != "17" { + t.Fatalf("unexpected worker id metadata: %#v", got) } - if len(blobColIndices) != 1 || blobColIndices[0] != 1 { - t.Errorf("BLOB detection = %v, want [1]", blobColIndices) - } - t.Log("AFTER fix: emptySchemaRowSet returns real colTypes → BLOB detected at index 1 → triggers CSV-with-BLOB fallback") -} - -func TestFormatArgValue_Blob(t *testing.T) { - data := []byte{0xDE, 0xAD, 0xBE, 0xEF} - got := formatArgValue(data) - want := `'\x` + hex.EncodeToString(data) + `'::BLOB` - if got != want { - t.Errorf("formatArgValue([]byte) = %q, want %q", got, want) - } -} - -func TestInterpolateArgs_BlobParam(t *testing.T) { - data := []byte{0x01, 0x02, 0x03} - query := "INSERT INTO t (col) VALUES ($1)" - got := interpolateArgs(query, []any{data}) - want := `INSERT INTO t (col) VALUES ('\x010203'::BLOB)` - if got != want { - t.Errorf("interpolateArgs blob = %q, want %q", got, want) + if got := md.Get("x-duckgres-cp-instance-id"); len(got) != 1 || got[0] != "cp-live:boot-a" { + t.Fatalf("unexpected cp instance metadata: %#v", got) } } diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index 40ecdd9e..2c1accff 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -7,6 +7,7 @@ import ( "database/sql" "encoding/base64" "encoding/hex" + "errors" "fmt" "log/slog" "net" @@ -39,6 +40,15 @@ const ( defaultFlightSessionHeaderKey = "x-duckgres-session" ) +var ErrDurableReconnectTerminal = errors.New("durable reconnect terminal") + +func MarkDurableReconnectTerminal(err error) error { + if err == nil { + return nil + } + return fmt.Errorf("%w: %w", ErrDurableReconnectTerminal, err) +} + const ( ReapTriggerPeriodic = "periodic" ReapTriggerForced = "forced" @@ -57,6 +67,53 @@ type SessionProvider interface { DestroySession(int32) } +type DurableSessionState string + +const ( + DurableSessionStateActive DurableSessionState = "active" + DurableSessionStateClosed DurableSessionState = "closed" + DurableSessionStateExpired DurableSessionState = "expired" +) + +type DurableSessionMetadata struct { + Username string + OrgID string + WorkerID int + OwnerEpoch int64 + CPInstanceID string +} + +type DurableSessionRecord struct { + SessionToken string + Username string + OrgID string + WorkerID int + OwnerEpoch int64 + CPInstanceID string + State DurableSessionState + ExpiresAt time.Time + LastSeenAt time.Time +} + +type DurableSessionStore interface { + UpsertSession(record DurableSessionRecord) error + GetSession(sessionToken string) (*DurableSessionRecord, error) + TouchSession(sessionToken string, lastSeenAt time.Time) error + CloseSession(sessionToken string, closedAt time.Time) error +} + +type sessionMetadataProvider interface { + DurableSessionMetadata(pid int32, username string) (DurableSessionMetadata, error) +} + +type sessionReconnector interface { + ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) +} + +type durableSessionStoreProvider interface { + DurableSessionStore() DurableSessionStore +} + // CredentialValidator abstracts username/password authentication. 
type CredentialValidator interface { ValidateCredentials(username, password string) bool @@ -174,6 +231,20 @@ func (fi *FlightIngress) Start() { }() } +func (fi *FlightIngress) BeginDrain() { + if fi == nil || fi.sessionStore == nil { + return + } + fi.sessionStore.SetDraining(true) +} + +func (fi *FlightIngress) WaitForZeroSessions(ctx context.Context) bool { + if fi == nil || fi.sessionStore == nil { + return true + } + return fi.sessionStore.WaitForZeroSessions(ctx) +} + // Shutdown stops accepting new Flight connections and cleans up sessions. func (fi *FlightIngress) Shutdown() { if fi == nil { @@ -181,6 +252,7 @@ func (fi *FlightIngress) Shutdown() { } fi.shutdownOnce.Do(func() { fi.shutdownState.Store(true) + fi.BeginDrain() if fi.listener != nil { _ = fi.listener.Close() } @@ -242,7 +314,16 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContextWithTokenMetadata(ctx c } if sessionToken := incomingSessionToken(md); sessionToken != "" { - s, ok := h.sessions.GetByToken(sessionToken) + var authenticatedUsername string + if hasAuthorizationHeader(md) { + username, err := h.authenticateBasicCredentials(md, remoteAddr) + if err != nil { + return nil, err + } + authenticatedUsername = username + } + + s, ok := h.sessions.GetByTokenContext(ctx, sessionToken) if !ok { server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) observeFlightIngressSessionOutcome("token_invalid") @@ -251,12 +332,8 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContextWithTokenMetadata(ctx c // When Basic auth is included alongside a bearer session token, enforce // principal consistency. Token-only auth is allowed after bootstrap. - if hasAuthorizationHeader(md) { - username, err := h.authenticateBasicCredentials(md, remoteAddr) - if err != nil { - return nil, err - } - if username != s.username { + if authenticatedUsername != "" { + if authenticatedUsername != s.username { server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) observeFlightIngressSessionOutcome("auth_failed") return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user") @@ -1017,6 +1094,7 @@ type flightClientSession struct { lastUsed atomic.Int64 // tokenIssuedAt stores when this token was issued; used for absolute token TTL. 
tokenIssuedAt atomic.Int64 + expiresAt atomic.Int64 counter atomic.Uint64 streams atomic.Int32 @@ -1172,6 +1250,9 @@ type flightAuthSessionStore struct { createSessionFn func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) destroySessionFn func(int32) + metadataProvider sessionMetadataProvider + reconnector sessionReconnector + durableStore DurableSessionStore mu sync.RWMutex sessions map[string]*flightClientSession // session token -> session @@ -1180,6 +1261,8 @@ type flightAuthSessionStore struct { stopOnce sync.Once stopCh chan struct{} doneCh chan struct{} + + draining atomic.Bool } type lockedRowSet struct { @@ -1203,6 +1286,18 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, createFn = provider.CreateSession destroyFn = provider.DestroySession } + var metadataProvider sessionMetadataProvider + if p, ok := provider.(sessionMetadataProvider); ok { + metadataProvider = p + } + var reconnector sessionReconnector + if p, ok := provider.(sessionReconnector); ok { + reconnector = p + } + var durableStore DurableSessionStore + if p, ok := provider.(durableSessionStoreProvider); ok { + durableStore = p.DurableSessionStore() + } s := &flightAuthSessionStore{ provider: provider, @@ -1214,6 +1309,9 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, hooks: opts.Hooks, createSessionFn: createFn, destroySessionFn: destroyFn, + metadataProvider: metadataProvider, + reconnector: reconnector, + durableStore: durableStore, sessions: make(map[string]*flightClientSession), byKey: make(map[string]string), stopCh: make(chan struct{}), @@ -1224,6 +1322,9 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, } func (s *flightAuthSessionStore) Create(ctx context.Context, username string) (*flightClientSession, error) { + if s.Draining() { + return nil, fmt.Errorf("flight ingress is draining") + } bootstrapNonce, err := generateSessionIdentityToken() if err != nil { return nil, fmt.Errorf("generate bootstrap nonce: %w", err) @@ -1313,16 +1414,25 @@ func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username created.token = token } s.sessions[created.token] = created - created.tokenIssuedAt.Store(time.Now().UnixNano()) + now := time.Now() + created.tokenIssuedAt.Store(now.UnixNano()) + if s.tokenTTL > 0 { + created.expiresAt.Store(now.Add(s.tokenTTL).UnixNano()) + } s.byKey[key] = created.token sessionCount := len(s.sessions) s.mu.Unlock() s.notifySessionCountChanged(sessionCount) + s.persistSession(created, username) return created, nil } func (s *flightAuthSessionStore) GetByToken(token string) (*flightClientSession, bool) { + return s.GetByTokenContext(context.Background(), token) +} + +func (s *flightAuthSessionStore) GetByTokenContext(ctx context.Context, token string) (*flightClientSession, bool) { token = strings.TrimSpace(token) if token == "" { return nil, false @@ -1341,6 +1451,25 @@ func (s *flightAuthSessionStore) GetByToken(token string) (*flightClientSession, session, ok = s.sessions[token] if !ok { s.mu.Unlock() + return s.reconnectByToken(ctx, token) + } + + expiresAtRaw := session.expiresAt.Load() + if expiresAtRaw > 0 && time.Now().After(time.Unix(0, expiresAtRaw)) { + delete(s.sessions, token) + s.removeByKeyForTokenLocked(token) + expiredSession = session + postExpireCount = len(s.sessions) + destroyFn := s.destroySessionFn + s.mu.Unlock() + + if destroyFn != nil { + destroyFn(expiredSession.pid) + } + if s.durableStore != nil { + _ = 
s.durableStore.CloseSession(token, time.Now()) + } + s.notifySessionCountChanged(postExpireCount) return nil, false } @@ -1358,11 +1487,17 @@ func (s *flightAuthSessionStore) GetByToken(token string) (*flightClientSession, if destroyFn != nil { destroyFn(expiredSession.pid) } + if s.durableStore != nil { + _ = s.durableStore.CloseSession(token, time.Now()) + } s.notifySessionCountChanged(postExpireCount) return nil, false } } s.mu.Unlock() + if s.durableStore != nil { + _ = s.durableStore.TouchSession(token, time.Now()) + } return session, true } @@ -1395,6 +1530,9 @@ func (s *flightAuthSessionStore) CloseByToken(token string) bool { if destroyFn != nil { destroyFn(session.pid) } + if s.durableStore != nil { + _ = s.durableStore.CloseSession(token, time.Now()) + } s.notifySessionCountChanged(sessionCount) return true } @@ -1462,10 +1600,57 @@ func (s *flightAuthSessionStore) Close() { for _, cs := range sessions { s.destroySessionFn(cs.pid) + if s.durableStore != nil { + _ = s.durableStore.CloseSession(cs.token, time.Now()) + } } }) } +func (s *flightAuthSessionStore) SetDraining(draining bool) { + if s == nil { + return + } + s.draining.Store(draining) +} + +func (s *flightAuthSessionStore) Draining() bool { + if s == nil { + return false + } + return s.draining.Load() +} + +func (s *flightAuthSessionStore) ActiveSessionCount() int { + if s == nil { + return 0 + } + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.sessions) +} + +func (s *flightAuthSessionStore) WaitForZeroSessions(ctx context.Context) bool { + if s == nil { + return true + } + if s.ActiveSessionCount() == 0 { + return true + } + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return s.ActiveSessionCount() == 0 + case <-ticker.C: + if s.ActiveSessionCount() == 0 { + return true + } + } + } +} + func (s *flightAuthSessionStore) ReapIdleNow() int { return s.reapIdle(time.Now(), ReapTriggerForced) } @@ -1501,6 +1686,9 @@ func (s *flightAuthSessionStore) reapIdle(now time.Time, trigger string) int { for _, cs := range stale { s.destroySessionFn(cs.pid) + if s.durableStore != nil { + _ = s.durableStore.CloseSession(cs.token, now) + } } reaped := len(stale) if reaped > 0 { @@ -1510,6 +1698,76 @@ func (s *flightAuthSessionStore) reapIdle(now time.Time, trigger string) int { return reaped } +func (s *flightAuthSessionStore) persistSession(session *flightClientSession, username string) { + if s == nil || s.durableStore == nil || s.metadataProvider == nil || session == nil { + return + } + meta, err := s.metadataProvider.DurableSessionMetadata(session.pid, username) + if err != nil { + slog.Warn("Persisting durable Flight session metadata failed.", "pid", session.pid, "error", err) + return + } + record := DurableSessionRecord{ + SessionToken: session.token, + Username: username, + OrgID: meta.OrgID, + WorkerID: meta.WorkerID, + OwnerEpoch: meta.OwnerEpoch, + CPInstanceID: meta.CPInstanceID, + State: DurableSessionStateActive, + ExpiresAt: time.Unix(0, session.expiresAt.Load()), + LastSeenAt: time.Now(), + } + if err := s.durableStore.UpsertSession(record); err != nil { + slog.Warn("Persisting durable Flight session record failed.", "pid", session.pid, "error", err) + } +} + +func (s *flightAuthSessionStore) reconnectByToken(ctx context.Context, token string) (*flightClientSession, bool) { + if s == nil || s.durableStore == nil || s.reconnector == nil { + return nil, false + } + record, err := s.durableStore.GetSession(token) + if err != nil { + slog.Warn("Loading 
durable Flight session record failed.", "token", token, "error", err)
+		return nil, false
+	}
+	if record == nil {
+		return nil, false
+	}
+	if record.State != DurableSessionStateActive {
+		return nil, false
+	}
+	if !record.ExpiresAt.IsZero() && time.Now().After(record.ExpiresAt) {
+		_ = s.durableStore.CloseSession(token, time.Now())
+		return nil, false
+	}
+	pid, executor, err := s.reconnector.ReconnectSession(ctx, *record)
+	if err != nil {
+		slog.Warn("Reconnecting durable Flight session failed.", "token", token, "error", err)
+		if errors.Is(err, ErrDurableReconnectTerminal) {
+			_ = s.durableStore.CloseSession(token, time.Now())
+		}
+		return nil, false
+	}
+
+	session := newFlightClientSession(pid, record.Username, executor)
+	session.token = token
+	session.tokenIssuedAt.Store(time.Now().UnixNano())
+	if !record.ExpiresAt.IsZero() {
+		session.expiresAt.Store(record.ExpiresAt.UnixNano())
+	}
+
+	s.mu.Lock()
+	s.ensureMapsLocked()
+	s.sessions[token] = session
+	sessionCount := len(s.sessions)
+	s.mu.Unlock()
+	s.notifySessionCountChanged(sessionCount)
+	s.persistSession(session, record.Username)
+	return session, true
+}
+
 func (s *flightAuthSessionStore) removeByKeyForTokenLocked(token string) {
 	for key, mappedToken := range s.byKey {
 		if mappedToken == token {
diff --git a/server/flightsqlingress/ingress_test.go b/server/flightsqlingress/ingress_test.go
index 0b2a3fd6..23a74eac 100644
--- a/server/flightsqlingress/ingress_test.go
+++ b/server/flightsqlingress/ingress_test.go
@@ -29,6 +29,94 @@ type testExecResult struct {
 	err error
 }
 
+type captureDurableSessionStore struct {
+	mu      sync.Mutex
+	records map[string]DurableSessionRecord
+	closed  []string
+	touched []string
+}
+
+func (s *captureDurableSessionStore) UpsertSession(record DurableSessionRecord) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.records == nil {
+		s.records = make(map[string]DurableSessionRecord)
+	}
+	s.records[record.SessionToken] = record
+	return nil
+}
+
+func (s *captureDurableSessionStore) GetSession(sessionToken string) (*DurableSessionRecord, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	record, ok := s.records[sessionToken]
+	if !ok {
+		return nil, nil
+	}
+	copy := record
+	return &copy, nil
+}
+
+func (s *captureDurableSessionStore) TouchSession(sessionToken string, lastSeenAt time.Time) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	record, ok := s.records[sessionToken]
+	if !ok {
+		return nil
+	}
+	record.LastSeenAt = lastSeenAt
+	s.records[sessionToken] = record
+	s.touched = append(s.touched, sessionToken)
+	return nil
+}
+
+func (s *captureDurableSessionStore) CloseSession(sessionToken string, closedAt time.Time) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	record, ok := s.records[sessionToken]
+	if !ok {
+		return nil
+	}
+	record.State = DurableSessionStateClosed
+	record.LastSeenAt = closedAt
+	s.records[sessionToken] = record
+	s.closed = append(s.closed, sessionToken)
+	return nil
+}
+
+type testDurableSessionProvider struct {
+	createSessionFn    func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error)
+	destroySessionFn   func(int32)
+	metadataFn         func(pid int32, username string) (DurableSessionMetadata, error)
+	reconnectSessionFn func(context.Context, DurableSessionRecord) (int32, *server.FlightExecutor, error)
+	durableStore       DurableSessionStore
+}
+
+func (p *testDurableSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
+	return p.createSessionFn(ctx, username, 
pid, memoryLimit, threads) +} + +func (p *testDurableSessionProvider) DestroySession(pid int32) { + if p.destroySessionFn != nil { + p.destroySessionFn(pid) + } +} + +func (p *testDurableSessionProvider) DurableSessionMetadata(pid int32, username string) (DurableSessionMetadata, error) { + if p.metadataFn == nil { + return DurableSessionMetadata{}, fmt.Errorf("durable session metadata is not configured") + } + return p.metadataFn(pid, username) +} + +func (p *testDurableSessionProvider) ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + return p.reconnectSessionFn(ctx, record) +} + +func (p *testDurableSessionProvider) DurableSessionStore() DurableSessionStore { + return p.durableStore +} + type testServerTransportStream struct { header metadata.MD trailer metadata.MD @@ -652,6 +740,359 @@ func TestFlightSessionTokenLifecycleIssueValidateRevokeExpiryMatrix(t *testing.T }) } +func TestFlightAuthSessionStorePersistsDurableSessionRecordOnCreate(t *testing.T) { + durable := &captureDurableSessionStore{} + provider := &testDurableSessionProvider{ + durableStore: durable, + createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + return 4321, nil, nil + }, + metadataFn: func(pid int32, username string) (DurableSessionMetadata, error) { + return DurableSessionMetadata{ + Username: username, + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 3, + CPInstanceID: "cp-new:boot-a", + }, nil + }, + } + store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{}) + + session, err := store.Create(context.Background(), "postgres") + if err != nil { + t.Fatalf("Create: %v", err) + } + + record, err := durable.GetSession(session.token) + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if record == nil { + t.Fatal("expected durable session record to be persisted") + } + if record.Username != "postgres" { + t.Fatalf("expected username postgres, got %q", record.Username) + } + if record.OrgID != "analytics" { + t.Fatalf("expected org analytics, got %q", record.OrgID) + } + if record.WorkerID != 17 { + t.Fatalf("expected worker id 17, got %d", record.WorkerID) + } + if record.OwnerEpoch != 3 { + t.Fatalf("expected owner epoch 3, got %d", record.OwnerEpoch) + } + if record.CPInstanceID != "cp-new:boot-a" { + t.Fatalf("expected cp_instance_id cp-new:boot-a, got %q", record.CPInstanceID) + } + if record.State != DurableSessionStateActive { + t.Fatalf("expected active durable session state, got %q", record.State) + } + if record.ExpiresAt.IsZero() { + t.Fatal("expected durable session expiry to be set") + } +} + +func TestFlightAuthSessionStoreReconnectsDurableSessionByToken(t *testing.T) { + durable := &captureDurableSessionStore{ + records: map[string]DurableSessionRecord{ + "durable-token": { + SessionToken: "durable-token", + Username: "postgres", + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-a", + State: DurableSessionStateActive, + ExpiresAt: time.Now().Add(time.Hour), + LastSeenAt: time.Now().Add(-time.Minute), + }, + }, + } + var reconnected DurableSessionRecord + provider := &testDurableSessionProvider{ + durableStore: durable, + metadataFn: func(pid int32, username string) (DurableSessionMetadata, error) { + return DurableSessionMetadata{ + Username: username, + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-a", + }, nil + }, + createSessionFn: func(context.Context, 
string, int32, string, int) (int32, *server.FlightExecutor, error) { + return 0, nil, fmt.Errorf("unexpected create path") + }, + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnected = record + return 9876, nil, nil + }, + } + store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{}) + + session, ok := store.GetByTokenContext(context.Background(), "durable-token") + if !ok { + t.Fatal("expected durable token reconnect to succeed") + } + if session.pid != 9876 { + t.Fatalf("expected reconnected pid 9876, got %d", session.pid) + } + if reconnected.SessionToken != "durable-token" { + t.Fatalf("expected reconnect to receive durable-token, got %q", reconnected.SessionToken) + } + if reconnected.WorkerID != 17 || reconnected.OwnerEpoch != 4 { + t.Fatalf("unexpected reconnect record %+v", reconnected) + } + record, err := durable.GetSession("durable-token") + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if record == nil { + t.Fatal("expected durable session record to remain present") + } + if record.State != DurableSessionStateActive { + t.Fatalf("expected durable session to remain active, got %q", record.State) + } +} + +func TestFlightAuthSessionStoreRejectsClosedDurableSessionToken(t *testing.T) { + durable := &captureDurableSessionStore{ + records: map[string]DurableSessionRecord{ + "closed-token": { + SessionToken: "closed-token", + Username: "postgres", + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-a", + State: DurableSessionStateClosed, + ExpiresAt: time.Now().Add(time.Hour), + LastSeenAt: time.Now().Add(-time.Minute), + }, + }, + } + reconnectCalls := 0 + provider := &testDurableSessionProvider{ + durableStore: durable, + createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + return 0, nil, fmt.Errorf("unexpected create path") + }, + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectCalls++ + return 9876, nil, nil + }, + } + store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{}) + + if session, ok := store.GetByTokenContext(context.Background(), "closed-token"); ok || session != nil { + t.Fatal("expected closed durable token reconnect to fail") + } + if reconnectCalls != 0 { + t.Fatalf("expected reconnect path to be skipped, got %d calls", reconnectCalls) + } +} + +func TestFlightAuthSessionStoreReconnectRefreshesDurableSessionMetadata(t *testing.T) { + durable := &captureDurableSessionStore{ + records: map[string]DurableSessionRecord{ + "durable-token": { + SessionToken: "durable-token", + Username: "postgres", + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-a", + State: DurableSessionStateActive, + ExpiresAt: time.Now().Add(time.Hour), + LastSeenAt: time.Now().Add(-time.Minute), + }, + }, + } + provider := &testDurableSessionProvider{ + durableStore: durable, + metadataFn: func(pid int32, username string) (DurableSessionMetadata, error) { + if pid != 9876 { + return DurableSessionMetadata{}, fmt.Errorf("unexpected pid %d", pid) + } + return DurableSessionMetadata{ + Username: username, + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 5, + CPInstanceID: "cp-new:boot-b", + }, nil + }, + createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + 
return 0, nil, fmt.Errorf("unexpected create path") + }, + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + return 9876, nil, nil + }, + } + store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{}) + + session, ok := store.GetByTokenContext(context.Background(), "durable-token") + if !ok { + t.Fatal("expected durable token reconnect to succeed") + } + if session.pid != 9876 { + t.Fatalf("expected reconnected pid 9876, got %d", session.pid) + } + + record, err := durable.GetSession("durable-token") + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if record == nil { + t.Fatal("expected durable session record to be present") + } + if record.OwnerEpoch != 5 { + t.Fatalf("expected refreshed owner epoch 5, got %d", record.OwnerEpoch) + } + if record.CPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected refreshed cp_instance_id cp-new:boot-b, got %q", record.CPInstanceID) + } +} + +func TestFlightAuthSessionStoreReconnectFailureUpdatesDurableSessionState(t *testing.T) { + tests := []struct { + name string + reconnectErr error + wantState DurableSessionState + wantReconnectCall int + }{ + { + name: "terminal stale ownership closes durable session", + reconnectErr: MarkDurableReconnectTerminal(errors.New("stale owner")), + wantState: DurableSessionStateClosed, + wantReconnectCall: 1, + }, + { + name: "transient reconnect failure leaves durable session active", + reconnectErr: context.DeadlineExceeded, + wantState: DurableSessionStateActive, + wantReconnectCall: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + durable := &captureDurableSessionStore{ + records: map[string]DurableSessionRecord{ + "durable-token": { + SessionToken: "durable-token", + Username: "postgres", + OrgID: "analytics", + WorkerID: 17, + OwnerEpoch: 4, + CPInstanceID: "cp-old:boot-a", + State: DurableSessionStateActive, + ExpiresAt: time.Now().Add(time.Hour), + LastSeenAt: time.Now().Add(-time.Minute), + }, + }, + } + reconnectCalls := 0 + provider := &testDurableSessionProvider{ + durableStore: durable, + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectCalls++ + return 0, nil, tt.reconnectErr + }, + } + store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Minute, time.Hour, 0, Options{}) + + if session, ok := store.GetByTokenContext(context.Background(), "durable-token"); ok || session != nil { + t.Fatal("expected durable token reconnect to fail") + } + record, err := durable.GetSession("durable-token") + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if record == nil { + t.Fatal("expected durable session record to remain present") + } + if record.State != tt.wantState { + t.Fatalf("expected durable session state %q, got %q", tt.wantState, record.State) + } + + if session, ok := store.GetByTokenContext(context.Background(), "durable-token"); ok || session != nil { + t.Fatal("expected second durable token lookup to fail") + } + if reconnectCalls != tt.wantReconnectCall { + t.Fatalf("expected %d reconnect attempts, got %d", tt.wantReconnectCall, reconnectCalls) + } + }) + } +} + +func TestFlightAuthSessionStoreRejectsNewSessionsWhileDraining(t *testing.T) { + provider := &testDurableSessionProvider{ + createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { + return 321, nil, 
nil
+		},
+	}
+	store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Hour, time.Hour, 0, Options{})
+	defer store.Close()
+
+	existing, err := store.Create(context.Background(), "postgres")
+	if err != nil {
+		t.Fatalf("Create(initial): %v", err)
+	}
+
+	store.SetDraining(true)
+	if _, err := store.Create(context.Background(), "postgres"); err == nil {
+		t.Fatal("expected Create to reject new sessions while draining")
+	}
+
+	reused, ok := store.GetByToken(existing.token)
+	if !ok {
+		t.Fatal("expected existing token to remain usable while draining")
+	}
+	if reused.pid != existing.pid {
+		t.Fatalf("expected reused pid %d, got %d", existing.pid, reused.pid)
+	}
+}
+
+func TestFlightAuthSessionStoreWaitForZeroSessions(t *testing.T) {
+	provider := &testDurableSessionProvider{
+		createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
+			return 654, nil, nil
+		},
+	}
+	store := newFlightAuthSessionStore(provider, time.Minute, time.Hour, time.Hour, time.Hour, 0, Options{})
+	defer store.Close()
+
+	session, err := store.Create(context.Background(), "postgres")
+	if err != nil {
+		t.Fatalf("Create: %v", err)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+
+	done := make(chan bool, 1)
+	go func() {
+		done <- store.WaitForZeroSessions(ctx)
+	}()
+
+	time.Sleep(25 * time.Millisecond)
+	if closed := store.CloseByToken(session.token); !closed {
+		t.Fatal("expected CloseByToken to close the created session")
+	}
+
+	select {
+	case ok := <-done:
+		if !ok {
+			t.Fatal("expected WaitForZeroSessions to report success")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("WaitForZeroSessions did not return")
+	}
+}
+
 func TestCloseSessionRevokesTokenAndDestroysWorker(t *testing.T) {
 	s := newFlightClientSession(1234, "postgres", nil)
 	s.token = "issued-token"
diff --git a/server/worker_activation.go b/server/worker_activation.go
index fb484332..9b11e5ad 100644
--- a/server/worker_activation.go
+++ b/server/worker_activation.go
@@ -1,11 +1,9 @@
 package server
 
-import "time"
-
 // WorkerActivationPayload is the tenant runtime material delivered to a shared
 // warm worker over the control-plane RPC path.
 type WorkerActivationPayload struct {
-	OrgID          string         `json:"org_id"`
-	LeaseExpiresAt time.Time      `json:"lease_expires_at"`
-	DuckLake       DuckLakeConfig `json:"ducklake"`
+	WorkerControlMetadata
+	OrgID    string         `json:"org_id"`
+	DuckLake DuckLakeConfig `json:"ducklake"`
 }
diff --git a/server/worker_control.go b/server/worker_control.go
new file mode 100644
index 00000000..ba967bbd
--- /dev/null
+++ b/server/worker_control.go
@@ -0,0 +1,31 @@
+package server
+
+// WorkerControlMetadata identifies the logical worker owner on a
+// control-plane-to-worker request.
+type WorkerControlMetadata struct {
+	WorkerID     int    `json:"worker_id"`
+	OwnerEpoch   int64  `json:"owner_epoch"`
+	CPInstanceID string `json:"cp_instance_id,omitempty"`
+}
+
+// WorkerCreateSessionPayload is the control-plane request body for creating a
+// worker-local session.
+type WorkerCreateSessionPayload struct {
+	WorkerControlMetadata
+	Username    string `json:"username"`
+	MemoryLimit string `json:"memory_limit"`
+	Threads     int    `json:"threads"`
+}
+
+// WorkerDestroySessionPayload is the control-plane request body for destroying
+// a worker-local session.
+type WorkerDestroySessionPayload struct { + WorkerControlMetadata + SessionToken string `json:"session_token"` +} + +// WorkerHealthCheckPayload is the control-plane request body for health-checking +// a worker. +type WorkerHealthCheckPayload struct { + WorkerControlMetadata +} diff --git a/tests/configstore/runtime_store_postgres_test.go b/tests/configstore/runtime_store_postgres_test.go new file mode 100644 index 00000000..3e2067e6 --- /dev/null +++ b/tests/configstore/runtime_store_postgres_test.go @@ -0,0 +1,674 @@ +//go:build linux || darwin + +package configstore_test + +import ( + "errors" + "strconv" + "testing" + "time" + + "github.com/posthog/duckgres/controlplane/configstore" +) + +func TestRuntimeStorePostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + runtimeSchema := store.RuntimeSchema() + if runtimeSchema == "" { + t.Fatal("expected runtime schema to be configured") + } + + for _, table := range []string{"cp_instances", "worker_records", "flight_session_records"} { + var count int64 + if err := store.DB().Raw( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ?", + runtimeSchema, + table, + ).Scan(&count).Error; err != nil { + t.Fatalf("lookup %s.%s: %v", runtimeSchema, table, err) + } + if count != 1 { + t.Fatalf("expected runtime table %s.%s to exist", runtimeSchema, table) + } + } + + startedAt := time.Date(2026, time.March, 26, 12, 0, 0, 0, time.UTC) + heartbeatAt := startedAt.Add(5 * time.Second) + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-1:boot-a", + PodName: "duckgres-abc", + PodUID: "pod-uid-1", + BootID: "boot-a", + State: configstore.ControlPlaneInstanceStateActive, + StartedAt: startedAt, + LastHeartbeatAt: heartbeatAt, + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance: %v", err) + } + + cp, err := store.GetControlPlaneInstance("cp-1:boot-a") + if err != nil { + t.Fatalf("GetControlPlaneInstance: %v", err) + } + if cp.PodName != "duckgres-abc" { + t.Fatalf("expected pod name duckgres-abc, got %q", cp.PodName) + } + if !cp.LastHeartbeatAt.Equal(heartbeatAt) { + t.Fatalf("expected heartbeat %v, got %v", heartbeatAt, cp.LastHeartbeatAt) + } + + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 42, + PodName: "duckgres-worker-42", + State: configstore.WorkerStateIdle, + OwnerCPInstanceID: "cp-1:boot-a", + OwnerEpoch: 7, + LastHeartbeatAt: heartbeatAt, + }); err != nil { + t.Fatalf("UpsertWorkerRecord: %v", err) + } + + worker, err := store.GetWorkerRecord(42) + if err != nil { + t.Fatalf("GetWorkerRecord: %v", err) + } + if worker.State != configstore.WorkerStateIdle { + t.Fatalf("expected worker state idle, got %q", worker.State) + } + if worker.OwnerEpoch != 7 { + t.Fatalf("expected owner epoch 7, got %d", worker.OwnerEpoch) + } + + sessionExpiry := startedAt.Add(5 * time.Minute) + if err := store.UpsertFlightSessionRecord(&configstore.FlightSessionRecord{ + SessionToken: "flight-token-1", + Username: "postgres", + OrgID: "analytics", + WorkerID: 42, + OwnerEpoch: 7, + State: configstore.FlightSessionStateActive, + ExpiresAt: sessionExpiry, + LastSeenAt: heartbeatAt, + }); err != nil { + t.Fatalf("UpsertFlightSessionRecord: %v", err) + } + + session, err := store.GetFlightSessionRecord("flight-token-1") + if err != nil { + t.Fatalf("GetFlightSessionRecord: %v", err) + } + if session.WorkerID != 42 { + t.Fatalf("expected worker id 42, got %d", session.WorkerID) + } + if session.Username != "postgres" { + t.Fatalf("expected username 
postgres, got %q", session.Username) + } + if session.State != configstore.FlightSessionStateActive { + t.Fatalf("expected session state active, got %q", session.State) + } +} + +func TestClaimIdleWorkerPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + startedAt := time.Date(2026, time.March, 26, 13, 0, 0, 0, time.UTC) + heartbeatAt := startedAt.Add(5 * time.Second) + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 7, + PodName: "duckgres-worker-7", + State: configstore.WorkerStateIdle, + OwnerCPInstanceID: "cp-old:boot-a", + OwnerEpoch: 2, + LastHeartbeatAt: heartbeatAt, + }); err != nil { + t.Fatalf("UpsertWorkerRecord: %v", err) + } + + claimed, err := store.ClaimIdleWorker("cp-new:boot-b", "analytics", 0) + if err != nil { + t.Fatalf("ClaimIdleWorker: %v", err) + } + if claimed == nil { + t.Fatal("expected idle worker claim to succeed") + } + if claimed.WorkerID != 7 { + t.Fatalf("expected worker id 7, got %d", claimed.WorkerID) + } + if claimed.State != configstore.WorkerStateReserved { + t.Fatalf("expected reserved state, got %q", claimed.State) + } + if claimed.OwnerCPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected owner cp-instance cp-new:boot-b, got %q", claimed.OwnerCPInstanceID) + } + if claimed.OwnerEpoch != 3 { + t.Fatalf("expected owner epoch 3, got %d", claimed.OwnerEpoch) + } + if claimed.OrgID != "analytics" { + t.Fatalf("expected org analytics, got %q", claimed.OrgID) + } + persisted, err := store.GetWorkerRecord(7) + if err != nil { + t.Fatalf("GetWorkerRecord: %v", err) + } + if persisted.State != configstore.WorkerStateReserved { + t.Fatalf("expected persisted reserved state, got %q", persisted.State) + } + if persisted.OwnerEpoch != 3 { + t.Fatalf("expected persisted owner epoch 3, got %d", persisted.OwnerEpoch) + } +} + +func TestClaimIdleWorkerReturnsNilWhenNoIdleWorkerExists(t *testing.T) { + store := newIsolatedConfigStore(t) + + claimed, err := store.ClaimIdleWorker("cp-new:boot-b", "analytics", 0) + if err != nil { + t.Fatalf("ClaimIdleWorker: %v", err) + } + if claimed != nil { + t.Fatalf("expected no claim, got %#v", claimed) + } +} + +func TestClaimIdleWorkerRespectsOrgCapPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + now := time.Date(2026, time.March, 26, 13, 30, 0, 0, time.UTC) + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 7, + PodName: "duckgres-worker-7", + State: configstore.WorkerStateIdle, + OwnerCPInstanceID: "cp-old:boot-a", + OwnerEpoch: 2, + LastHeartbeatAt: now, + }); err != nil { + t.Fatalf("UpsertWorkerRecord(idle): %v", err) + } + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 8, + PodName: "duckgres-worker-8", + State: configstore.WorkerStateHot, + OrgID: "analytics", + OwnerCPInstanceID: "cp-old:boot-a", + OwnerEpoch: 4, + LastHeartbeatAt: now, + }); err != nil { + t.Fatalf("UpsertWorkerRecord(hot): %v", err) + } + + claimed, err := store.ClaimIdleWorker("cp-new:boot-b", "analytics", 1) + if err != nil { + t.Fatalf("ClaimIdleWorker: %v", err) + } + if claimed != nil { + t.Fatalf("expected org cap to block claim, got %#v", claimed) + } + + persisted, err := store.GetWorkerRecord(7) + if err != nil { + t.Fatalf("GetWorkerRecord: %v", err) + } + if persisted.State != configstore.WorkerStateIdle { + t.Fatalf("expected worker to remain idle, got %q", persisted.State) + } + if persisted.OrgID != "" { + t.Fatalf("expected idle worker org to remain empty, got %q", persisted.OrgID) + } +} + +func 
TestExpireControlPlaneInstancesPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + startedAt := time.Date(2026, time.March, 26, 14, 0, 0, 0, time.UTC) + staleHeartbeat := startedAt.Add(5 * time.Second) + freshHeartbeat := startedAt.Add(40 * time.Second) + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-stale:boot-a", + PodName: "duckgres-stale", + PodUID: "pod-stale", + BootID: "boot-a", + State: configstore.ControlPlaneInstanceStateActive, + StartedAt: startedAt, + LastHeartbeatAt: staleHeartbeat, + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance(stale): %v", err) + } + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-fresh:boot-b", + PodName: "duckgres-fresh", + PodUID: "pod-fresh", + BootID: "boot-b", + State: configstore.ControlPlaneInstanceStateActive, + StartedAt: startedAt, + LastHeartbeatAt: freshHeartbeat, + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance(fresh): %v", err) + } + + expired, err := store.ExpireControlPlaneInstances(startedAt.Add(20 * time.Second)) + if err != nil { + t.Fatalf("ExpireControlPlaneInstances: %v", err) + } + if expired != 1 { + t.Fatalf("expected 1 expired instance, got %d", expired) + } + + stale, err := store.GetControlPlaneInstance("cp-stale:boot-a") + if err != nil { + t.Fatalf("GetControlPlaneInstance(stale): %v", err) + } + if stale.State != configstore.ControlPlaneInstanceStateExpired { + t.Fatalf("expected stale instance to be expired, got %q", stale.State) + } + if stale.ExpiredAt == nil { + t.Fatal("expected expired_at to be set for stale instance") + } + + fresh, err := store.GetControlPlaneInstance("cp-fresh:boot-b") + if err != nil { + t.Fatalf("GetControlPlaneInstance(fresh): %v", err) + } + if fresh.State != configstore.ControlPlaneInstanceStateActive { + t.Fatalf("expected fresh instance to stay active, got %q", fresh.State) + } + if fresh.ExpiredAt != nil { + t.Fatalf("expected fresh instance expired_at to stay nil, got %v", fresh.ExpiredAt) + } +} + +func TestExpireDrainingControlPlaneInstancesPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + startedAt := time.Date(2026, time.March, 26, 14, 0, 0, 0, time.UTC) + oldDrain := startedAt.Add(5 * time.Minute) + recentDrain := startedAt.Add(20 * time.Minute) + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-draining-old:boot-a", + PodName: "duckgres-old", + PodUID: "pod-old", + BootID: "boot-a", + State: configstore.ControlPlaneInstanceStateDraining, + StartedAt: startedAt, + LastHeartbeatAt: recentDrain, + DrainingAt: &oldDrain, + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance(old draining): %v", err) + } + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-draining-recent:boot-b", + PodName: "duckgres-recent", + PodUID: "pod-recent", + BootID: "boot-b", + State: configstore.ControlPlaneInstanceStateDraining, + StartedAt: startedAt, + LastHeartbeatAt: recentDrain, + DrainingAt: &recentDrain, + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance(recent draining): %v", err) + } + + expired, err := store.ExpireDrainingControlPlaneInstances(startedAt.Add(15 * time.Minute)) + if err != nil { + t.Fatalf("ExpireDrainingControlPlaneInstances: %v", err) + } + if expired != 1 { + t.Fatalf("expected 1 overdue draining instance, got %d", expired) + } + + old, err := store.GetControlPlaneInstance("cp-draining-old:boot-a") + if err != nil { + t.Fatalf("GetControlPlaneInstance(old draining): 
%v", err) + } + if old.State != configstore.ControlPlaneInstanceStateExpired { + t.Fatalf("expected old draining instance to be expired, got %q", old.State) + } + if old.ExpiredAt == nil { + t.Fatal("expected expired_at to be set for old draining instance") + } + + recent, err := store.GetControlPlaneInstance("cp-draining-recent:boot-b") + if err != nil { + t.Fatalf("GetControlPlaneInstance(recent draining): %v", err) + } + if recent.State != configstore.ControlPlaneInstanceStateDraining { + t.Fatalf("expected recent draining instance to remain draining, got %q", recent.State) + } + if recent.ExpiredAt != nil { + t.Fatalf("expected recent draining instance expired_at to stay nil, got %v", recent.ExpiredAt) + } +} + +func TestCreateSpawningWorkerSlotPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + + slot, err := store.CreateSpawningWorkerSlot("cp-new:boot-b", "analytics", 1, "duckgres-worker-test-cp", 3, 5) + if err != nil { + t.Fatalf("CreateSpawningWorkerSlot: %v", err) + } + if slot == nil { + t.Fatal("expected spawning worker slot to be created") + } + if slot.WorkerID <= 0 { + t.Fatalf("expected positive worker id, got %d", slot.WorkerID) + } + if slot.State != configstore.WorkerStateSpawning { + t.Fatalf("expected spawning state, got %q", slot.State) + } + if slot.PodName != "duckgres-worker-test-cp-"+strconv.Itoa(slot.WorkerID) { + t.Fatalf("unexpected pod name %q for worker id %d", slot.PodName, slot.WorkerID) + } + if slot.OwnerCPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected owner cp-instance cp-new:boot-b, got %q", slot.OwnerCPInstanceID) + } + if slot.OwnerEpoch != 1 { + t.Fatalf("expected owner epoch 1, got %d", slot.OwnerEpoch) + } + if slot.OrgID != "analytics" { + t.Fatalf("expected org analytics, got %q", slot.OrgID) + } + + persisted, err := store.GetWorkerRecord(slot.WorkerID) + if err != nil { + t.Fatalf("GetWorkerRecord: %v", err) + } + if persisted.State != configstore.WorkerStateSpawning { + t.Fatalf("expected persisted spawning state, got %q", persisted.State) + } +} + +func TestCreateSpawningWorkerSlotRespectsOrgAndGlobalCaps(t *testing.T) { + store := newIsolatedConfigStore(t) + now := time.Date(2026, time.March, 27, 13, 0, 0, 0, time.UTC) + + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 9, + PodName: "duckgres-worker-existing-9", + State: configstore.WorkerStateHot, + OrgID: "analytics", + OwnerCPInstanceID: "cp-old:boot-a", + OwnerEpoch: 4, + LastHeartbeatAt: now, + }); err != nil { + t.Fatalf("UpsertWorkerRecord(existing): %v", err) + } + + orgLimited, err := store.CreateSpawningWorkerSlot("cp-new:boot-b", "analytics", 1, "duckgres-worker-test-cp", 1, 5) + if err != nil { + t.Fatalf("CreateSpawningWorkerSlot(org cap): %v", err) + } + if orgLimited != nil { + t.Fatalf("expected org cap to block spawning, got %#v", orgLimited) + } + + globalLimited, err := store.CreateSpawningWorkerSlot("cp-new:boot-b", "sales", 1, "duckgres-worker-test-cp", 2, 1) + if err != nil { + t.Fatalf("CreateSpawningWorkerSlot(global cap): %v", err) + } + if globalLimited != nil { + t.Fatalf("expected global cap to block spawning, got %#v", globalLimited) + } +} + +func TestCreateNeutralWarmWorkerSlotRespectsSharedWarmTarget(t *testing.T) { + store := newIsolatedConfigStore(t) + now := time.Date(2026, time.March, 27, 13, 30, 0, 0, time.UTC) + + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 10, + PodName: "duckgres-worker-existing-10", + State: configstore.WorkerStateIdle, + OrgID: "", + OwnerCPInstanceID: 
"cp-old:boot-a", + OwnerEpoch: 0, + LastHeartbeatAt: now, + }); err != nil { + t.Fatalf("UpsertWorkerRecord(existing neutral): %v", err) + } + + blocked, err := store.CreateNeutralWarmWorkerSlot("cp-new:boot-b", "duckgres-worker-test-cp", 1, 5) + if err != nil { + t.Fatalf("CreateNeutralWarmWorkerSlot(shared target): %v", err) + } + if blocked != nil { + t.Fatalf("expected shared warm target to block spawning, got %#v", blocked) + } + + slot, err := store.CreateNeutralWarmWorkerSlot("cp-new:boot-b", "duckgres-worker-test-cp", 2, 5) + if err != nil { + t.Fatalf("CreateNeutralWarmWorkerSlot(expand target): %v", err) + } + if slot == nil { + t.Fatal("expected neutral warm slot to be created") + } + if slot.OrgID != "" { + t.Fatalf("expected neutral warm slot org to be empty, got %q", slot.OrgID) + } + if slot.State != configstore.WorkerStateSpawning { + t.Fatalf("expected spawning state, got %q", slot.State) + } +} + +func TestListOrphanedAndStuckWorkersPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + now := time.Date(2026, time.March, 27, 14, 0, 0, 0, time.UTC) + + if err := store.UpsertControlPlaneInstance(&configstore.ControlPlaneInstance{ + ID: "cp-expired:boot-a", + PodName: "duckgres-old", + PodUID: "pod-old", + BootID: "boot-a", + State: configstore.ControlPlaneInstanceStateExpired, + StartedAt: now.Add(-time.Hour), + LastHeartbeatAt: now.Add(-time.Minute), + ExpiredAt: ptrTime(now.Add(-time.Minute)), + }); err != nil { + t.Fatalf("UpsertControlPlaneInstance(expired): %v", err) + } + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 61, + PodName: "duckgres-worker-61", + State: configstore.WorkerStateReserved, + OrgID: "analytics", + OwnerCPInstanceID: "cp-expired:boot-a", + OwnerEpoch: 2, + LastHeartbeatAt: now.Add(-time.Minute), + }); err != nil { + t.Fatalf("UpsertWorkerRecord(orphaned): %v", err) + } + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 63, + PodName: "duckgres-worker-63", + State: configstore.WorkerStateRetired, + OrgID: "analytics", + OwnerCPInstanceID: "cp-expired:boot-a", + OwnerEpoch: 3, + LastHeartbeatAt: now.Add(-time.Minute), + RetireReason: "normal", + }); err != nil { + t.Fatalf("UpsertWorkerRecord(retired orphan): %v", err) + } + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 62, + PodName: "duckgres-worker-62", + State: configstore.WorkerStateActivating, + OrgID: "analytics", + OwnerCPInstanceID: "cp-live:boot-b", + OwnerEpoch: 1, + LastHeartbeatAt: now.Add(-time.Minute), + }); err != nil { + t.Fatalf("UpsertWorkerRecord(stuck): %v", err) + } + if err := store.DB().Table(store.RuntimeSchema()+".worker_records"). + Where("worker_id = ?", 62). 
+		Update("updated_at", now.Add(-3*time.Minute)).Error; err != nil {
+		t.Fatalf("age stuck worker: %v", err)
+	}
+
+	orphaned, err := store.ListOrphanedWorkers(now.Add(-30 * time.Second))
+	if err != nil {
+		t.Fatalf("ListOrphanedWorkers: %v", err)
+	}
+	if len(orphaned) != 2 {
+		t.Fatalf("expected orphaned workers 61 and 63, got %#v", orphaned)
+	}
+	if orphaned[0].WorkerID != 61 || orphaned[1].WorkerID != 63 {
+		t.Fatalf("expected orphaned workers 61 and 63, got %#v", orphaned)
+	}
+
+	stuck, err := store.ListStuckWorkers(now.Add(-2*time.Minute), now.Add(-2*time.Minute))
+	if err != nil {
+		t.Fatalf("ListStuckWorkers: %v", err)
+	}
+	if len(stuck) != 1 || stuck[0].WorkerID != 62 {
+		t.Fatalf("expected stuck worker 62, got %#v", stuck)
+	}
+}
+
+func TestExpireFlightSessionRecordsPostgres(t *testing.T) {
+	store := newIsolatedConfigStore(t)
+	now := time.Date(2026, time.March, 27, 15, 0, 0, 0, time.UTC)
+
+	if err := store.UpsertFlightSessionRecord(&configstore.FlightSessionRecord{
+		SessionToken: "flight-expire-me",
+		Username:     "postgres",
+		OrgID:        "analytics",
+		WorkerID:     42,
+		OwnerEpoch:   7,
+		State:        configstore.FlightSessionStateActive,
+		ExpiresAt:    now.Add(-time.Minute),
+		LastSeenAt:   now.Add(-2 * time.Minute),
+	}); err != nil {
+		t.Fatalf("UpsertFlightSessionRecord: %v", err)
+	}
+
+	expired, err := store.ExpireFlightSessionRecords(now)
+	if err != nil {
+		t.Fatalf("ExpireFlightSessionRecords: %v", err)
+	}
+	if expired != 1 {
+		t.Fatalf("expected one expired session record, got %d", expired)
+	}
+
+	record, err := store.GetFlightSessionRecord("flight-expire-me")
+	if err != nil {
+		t.Fatalf("GetFlightSessionRecord: %v", err)
+	}
+	if record.State != configstore.FlightSessionStateExpired {
+		t.Fatalf("expected expired flight session state, got %q", record.State)
+	}
+}
+
+func TestGetTouchAndCloseFlightSessionRecordPostgres(t *testing.T) {
+	store := newIsolatedConfigStore(t)
+	now := time.Date(2026, time.March, 27, 16, 0, 0, 0, time.UTC)
+
+	if err := store.UpsertFlightSessionRecord(&configstore.FlightSessionRecord{
+		SessionToken: "flight-touch-close",
+		Username:     "postgres",
+		OrgID:        "analytics",
+		WorkerID:     42,
+		OwnerEpoch:   8,
+		State:        configstore.FlightSessionStateActive,
+		ExpiresAt:    now.Add(time.Hour),
+		LastSeenAt:   now.Add(-time.Minute),
+	}); err != nil {
+		t.Fatalf("UpsertFlightSessionRecord: %v", err)
+	}
+
+	record, err := store.GetFlightSessionRecord("flight-touch-close")
+	if err != nil {
+		t.Fatalf("GetFlightSessionRecord: %v", err)
+	}
+	if record == nil || record.Username != "postgres" {
+		t.Fatalf("expected durable record with username postgres, got %#v", record)
+	}
+
+	touchedAt := now.Add(2 * time.Minute)
+	if err := store.TouchFlightSessionRecord("flight-touch-close", touchedAt); err != nil {
+		t.Fatalf("TouchFlightSessionRecord: %v", err)
+	}
+	record, err = store.GetFlightSessionRecord("flight-touch-close")
+	if err != nil {
+		t.Fatalf("GetFlightSessionRecord: %v", err)
+	}
+	if !record.LastSeenAt.Equal(touchedAt) {
+		t.Fatalf("expected last_seen_at %v, got %v", touchedAt, record.LastSeenAt)
+	}
+
+	closedAt := now.Add(3 * time.Minute)
+	if err := store.CloseFlightSessionRecord("flight-touch-close", closedAt); err != nil {
+		t.Fatalf("CloseFlightSessionRecord: %v", err)
+	}
+	record, err = store.GetFlightSessionRecord("flight-touch-close")
+	if err != nil {
+		t.Fatalf("GetFlightSessionRecord: %v", err)
+	}
+	if record.State != configstore.FlightSessionStateClosed {
+		t.Fatalf("expected closed state, got %q", record.State)
+	}
+	if !record.LastSeenAt.Equal(closedAt) 
{ + t.Fatalf("expected close timestamp %v, got %v", closedAt, record.LastSeenAt) + } +} + +func TestGetFlightSessionRecordReturnsNilWhenMissing(t *testing.T) { + store := newIsolatedConfigStore(t) + + record, err := store.GetFlightSessionRecord("missing-flight-session") + if err != nil { + t.Fatalf("GetFlightSessionRecord: %v", err) + } + if record != nil { + t.Fatalf("expected nil record for missing session, got %#v", record) + } +} + +func TestTakeOverWorkerPostgres(t *testing.T) { + store := newIsolatedConfigStore(t) + now := time.Date(2026, time.March, 27, 17, 0, 0, 0, time.UTC) + + if err := store.UpsertWorkerRecord(&configstore.WorkerRecord{ + WorkerID: 71, + PodName: "duckgres-worker-71", + State: configstore.WorkerStateHot, + OrgID: "analytics", + OwnerCPInstanceID: "cp-old:boot-a", + OwnerEpoch: 5, + LastHeartbeatAt: now, + }); err != nil { + t.Fatalf("UpsertWorkerRecord: %v", err) + } + + claimed, err := store.TakeOverWorker(71, "cp-new:boot-b", "analytics", 5) + if err != nil { + t.Fatalf("TakeOverWorker: %v", err) + } + if claimed == nil { + t.Fatal("expected takeover to succeed") + } + if claimed.OwnerCPInstanceID != "cp-new:boot-b" { + t.Fatalf("expected owner cp-instance cp-new:boot-b, got %q", claimed.OwnerCPInstanceID) + } + if claimed.OwnerEpoch != 6 { + t.Fatalf("expected owner epoch 6, got %d", claimed.OwnerEpoch) + } + if claimed.State != configstore.WorkerStateReserved { + t.Fatalf("expected reserved state, got %q", claimed.State) + } + + missed, err := store.TakeOverWorker(71, "cp-third:boot-c", "analytics", 5) + if err == nil { + t.Fatal("expected stale takeover attempt to return an epoch mismatch error") + } + if !errors.Is(err, configstore.ErrWorkerOwnerEpochMismatch) { + t.Fatalf("expected ErrWorkerOwnerEpochMismatch, got %v", err) + } + if missed != nil { + t.Fatalf("expected stale takeover attempt to fail, got %#v", missed) + } +} + +func ptrTime(t time.Time) *time.Time { + return &t +} diff --git a/tests/k8s/k8s_test.go b/tests/k8s/k8s_test.go index ae9757f1..bb1ce7f6 100644 --- a/tests/k8s/k8s_test.go +++ b/tests/k8s/k8s_test.go @@ -130,10 +130,8 @@ func TestK8sWorkerPodCreation(t *testing.T) { if cpLabel == "" { t.Errorf("worker pod %s missing duckgres/control-plane label", pod.Name) } - - // Verify owner references - if len(pod.OwnerReferences) == 0 { - t.Errorf("worker pod %s has no owner references", pod.Name) + if pod.Labels["duckgres/worker-id"] == "" { + t.Errorf("worker pod %s missing duckgres/worker-id label", pod.Name) } } } @@ -276,40 +274,51 @@ func TestK8sCPDeletionGarbageCollects(t *testing.T) { t.Skip("no worker pods found — cannot test GC") } - // Verify worker pods have owner references pointing to the CP pod - for _, wp := range workerPods.Items { - hasOwner := false - for _, ref := range wp.OwnerReferences { - if ref.Kind == "Pod" { - hasOwner = true - t.Logf("Worker %s owned by %s (UID %s)", wp.Name, ref.Name, ref.UID) - } - } - if !hasOwner { - t.Errorf("worker pod %s has no Pod owner reference — GC will not work", wp.Name) + ownedWorkers := workerPodsByControlPlaneLabel(workerPods.Items) + if len(ownedWorkers) == 0 { + t.Skip("no worker pods with duckgres/control-plane label found") + } + + // Delete a CP pod that currently owns at least one worker. + var cpName string + var workerNames []string + for ownerName, owned := range ownedWorkers { + if len(owned) > 0 { + cpName = ownerName + workerNames = append([]string(nil), owned...) 
+ break } } + if cpName == "" { + t.Skip("no control-plane-owned worker pods found") + } - // Delete the CP pod (the deployment will recreate it) cpPods, err := clientset.CoreV1().Pods(namespace).List(context.Background(), metav1.ListOptions{ LabelSelector: "app=duckgres-control-plane", }) if err != nil || len(cpPods.Items) == 0 { t.Fatalf("failed to find CP pod: %v", err) } - cpName := cpPods.Items[0].Name - t.Logf("Deleting CP pod %s to test garbage collection", cpName) - err = clientset.CoreV1().Pods(namespace).Delete(context.Background(), cpName, metav1.DeleteOptions{}) + foundCP := false + for _, pod := range cpPods.Items { + if pod.Name == cpName { + foundCP = true + break + } + } + if !foundCP { + t.Skipf("control-plane pod %s no longer exists", cpName) + } + gracePeriodSeconds := int64(0) + t.Logf("Force deleting CP pod %s to test crash-style garbage collection", cpName) + err = clientset.CoreV1().Pods(namespace).Delete(context.Background(), cpName, metav1.DeleteOptions{ + GracePeriodSeconds: &gracePeriodSeconds, + }) if err != nil { t.Fatalf("failed to delete CP pod: %v", err) } - // Wait for old worker pods to be garbage collected - workerNames := make([]string, len(workerPods.Items)) - for i, wp := range workerPods.Items { - workerNames[i] = wp.Name - } - + // Wait for the deleted control plane's worker pods to be retired. allGone := false deadline := time.Now().Add(90 * time.Second) for time.Now().Before(deadline) { diff --git a/tests/k8s/setup_config_test.go b/tests/k8s/setup_config_test.go index 975ec497..4153eaf0 100644 --- a/tests/k8s/setup_config_test.go +++ b/tests/k8s/setup_config_test.go @@ -163,6 +163,27 @@ func TestLocalDependencyPortsStayFixedAndPreflighted(t *testing.T) { } } +func TestControlPlaneRBACIncludesLeaseAccess(t *testing.T) { + root := findProjectRootForUnitTest(t) + + rbacPath := filepath.Join(root, "k8s", "rbac.yaml") + manifest, err := os.ReadFile(rbacPath) + if err != nil { + t.Fatalf("read rbac manifest: %v", err) + } + + content := string(manifest) + for _, want := range []string{ + `apiGroups: ["coordination.k8s.io"]`, + `resources: ["leases"]`, + `verbs: ["create", "delete", "get", "list", "patch", "update", "watch"]`, + } { + if !strings.Contains(content, want) { + t.Fatalf("expected %q in %s", want, rbacPath) + } + } +} + func findProjectRootForUnitTest(t *testing.T) string { t.Helper() diff --git a/tests/k8s/worker_owner_helper_test.go b/tests/k8s/worker_owner_helper_test.go new file mode 100644 index 00000000..d94c319a --- /dev/null +++ b/tests/k8s/worker_owner_helper_test.go @@ -0,0 +1,15 @@ +package k8s_test + +import corev1 "k8s.io/api/core/v1" + +func workerPodsByControlPlaneLabel(pods []corev1.Pod) map[string][]string { + owned := make(map[string][]string) + for _, pod := range pods { + cpName := pod.Labels["duckgres/control-plane"] + if cpName == "" { + continue + } + owned[cpName] = append(owned[cpName], pod.Name) + } + return owned +} diff --git a/tests/k8s/worker_owner_helper_unit_test.go b/tests/k8s/worker_owner_helper_unit_test.go new file mode 100644 index 00000000..8a2b8e2c --- /dev/null +++ b/tests/k8s/worker_owner_helper_unit_test.go @@ -0,0 +1,47 @@ +package k8s_test + +import ( + "reflect" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestWorkerPodsByControlPlaneLabel(t *testing.T) { + pods := []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-a", + Labels: map[string]string{"duckgres/control-plane": "cp-a"}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ 
+ Name: "worker-b", + Labels: map[string]string{"duckgres/control-plane": "cp-b"}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-c", + Labels: map[string]string{"duckgres/control-plane": "cp-a"}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-unowned", + Labels: map[string]string{}, + }, + }, + } + + got := workerPodsByControlPlaneLabel(pods) + want := map[string][]string{ + "cp-a": {"worker-a", "worker-c"}, + "cp-b": {"worker-b"}, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("workerPodsByControlPlaneLabel() = %#v, want %#v", got, want) + } +}
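
The drain API added above (BeginDrain, WaitForZeroSessions, and the draining check in Create) implies a planned-handover sequence that no single hunk shows end to end. Below is a minimal sketch of that sequence, assuming a *FlightIngress value and a caller-supplied drain timeout; only BeginDrain, WaitForZeroSessions, and Shutdown come from this diff, everything else is illustrative.

package flightsqlingress

import (
	"context"
	"log/slog"
	"time"
)

// drainAndShutdown is a sketch, not part of the diff: a planned handover
// drains the Flight ingress before falling through to a forced shutdown.
func drainAndShutdown(fi *FlightIngress, drainTimeout time.Duration) {
	// Stop admitting new Flight sessions; existing session tokens keep
	// working, per TestFlightAuthSessionStoreRejectsNewSessionsWhileDraining.
	fi.BeginDrain()

	ctx, cancel := context.WithTimeout(context.Background(), drainTimeout)
	defer cancel()

	// Block until every session is closed or the drain budget runs out.
	if !fi.WaitForZeroSessions(ctx) {
		slog.Warn("Flight drain timed out with sessions still active; forcing shutdown.")
	}

	// Shutdown invokes BeginDrain again inside its shutdownOnce-guarded
	// cleanup, so calling it after a completed drain is safe; it closes the
	// listener and tears down whatever sessions remain.
	fi.Shutdown()
}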
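
captureDurableSessionStore pins down the DurableSessionStore method set, and the configstore tests pin down the matching FlightSessionRecord operations. A sketch of the adapter that would bridge the two follows. The concrete *configstore.Store type, the state mapping, matching field types, and the dropped CPInstanceID are all assumptions; FlightSessionRecord is never shown carrying a CP-instance column, and only the two method sets appear in this diff.

package flightsqlingress

import (
	"time"

	"github.com/posthog/duckgres/controlplane/configstore"
)

// configStoreDurableSessions is a hypothetical DurableSessionStore backed by
// the configstore runtime tables.
type configStoreDurableSessions struct {
	store *configstore.Store // assumed concrete store type
}

func (c *configStoreDurableSessions) UpsertSession(record DurableSessionRecord) error {
	state := configstore.FlightSessionStateClosed
	if record.State == DurableSessionStateActive {
		state = configstore.FlightSessionStateActive
	}
	return c.store.UpsertFlightSessionRecord(&configstore.FlightSessionRecord{
		SessionToken: record.SessionToken,
		Username:     record.Username,
		OrgID:        record.OrgID,
		WorkerID:     record.WorkerID,
		OwnerEpoch:   record.OwnerEpoch,
		State:        state,
		ExpiresAt:    record.ExpiresAt,
		LastSeenAt:   record.LastSeenAt,
	})
}

func (c *configStoreDurableSessions) GetSession(token string) (*DurableSessionRecord, error) {
	rec, err := c.store.GetFlightSessionRecord(token)
	if err != nil || rec == nil {
		return nil, err
	}
	// Expired and closed rows both map to closed here; reconnectByToken only
	// honors DurableSessionStateActive anyway.
	state := DurableSessionStateClosed
	if rec.State == configstore.FlightSessionStateActive {
		state = DurableSessionStateActive
	}
	return &DurableSessionRecord{
		SessionToken: rec.SessionToken,
		Username:     rec.Username,
		OrgID:        rec.OrgID,
		WorkerID:     rec.WorkerID,
		OwnerEpoch:   rec.OwnerEpoch,
		State:        state,
		ExpiresAt:    rec.ExpiresAt,
		LastSeenAt:   rec.LastSeenAt,
	}, nil
}

func (c *configStoreDurableSessions) TouchSession(token string, lastSeenAt time.Time) error {
	return c.store.TouchFlightSessionRecord(token, lastSeenAt)
}

func (c *configStoreDurableSessions) CloseSession(token string, closedAt time.Time) error {
	return c.store.CloseFlightSessionRecord(token, closedAt)
}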
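
The control-plane instance tests describe a heartbeat-plus-expiry protocol: each replica upserts its own row, stale heartbeats are expired against a cutoff, and draining rows get a separate, longer budget. A sketch of the loop that would drive those store calls, with the intervals, package placement, and single-loop structure all assumed; only the three store methods come from the tests above.

package controlplane

import (
	"context"
	"log/slog"
	"time"

	"github.com/posthog/duckgres/controlplane/configstore"
)

// heartbeatLoop is a sketch of the coordination loop the configstore tests
// imply: refresh this instance's heartbeat, then expire overdue peers.
func heartbeatLoop(ctx context.Context, store *configstore.Store, self *configstore.ControlPlaneInstance,
	heartbeatEvery, heartbeatTTL, drainBudget time.Duration) {
	ticker := time.NewTicker(heartbeatEvery)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			self.LastHeartbeatAt = now
			if err := store.UpsertControlPlaneInstance(self); err != nil {
				slog.Warn("Control-plane heartbeat upsert failed.", "error", err)
				continue
			}
			// Peers silent for longer than heartbeatTTL are marked expired,
			// as in TestExpireControlPlaneInstancesPostgres.
			if _, err := store.ExpireControlPlaneInstances(now.Add(-heartbeatTTL)); err != nil {
				slog.Warn("Expiring stale control-plane instances failed.", "error", err)
			}
			// Draining peers get a longer budget before being forced out,
			// as in TestExpireDrainingControlPlaneInstancesPostgres.
			if _, err := store.ExpireDrainingControlPlaneInstances(now.Add(-drainBudget)); err != nil {
				slog.Warn("Expiring overdue draining instances failed.", "error", err)
			}
		}
	}
}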
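
TestTakeOverWorkerPostgres fixes the takeover semantics: a claim must present the current owner epoch, a successful claim increments it and moves the worker to reserved, and a stale claim surfaces ErrWorkerOwnerEpochMismatch. A sketch of the guarded UPDATE this typically reduces to; the gorm wiring, column names, and the elided runtime-schema qualification are assumptions drawn from the test fixtures, not from the implementation.

package configstore

import "gorm.io/gorm"

// takeOverWorkerCAS is a sketch of the epoch-fenced compare-and-swap the
// tests describe; only the observable behavior comes from the diff.
func takeOverWorkerCAS(tx *gorm.DB, workerID int, newOwnerID, orgID string, expectedEpoch int64) (*WorkerRecord, error) {
	res := tx.Model(&WorkerRecord{}).
		Where("worker_id = ? AND owner_epoch = ?", workerID, expectedEpoch).
		Updates(map[string]any{
			"owner_cp_instance_id": newOwnerID,
			"org_id":               orgID,
			"owner_epoch":          expectedEpoch + 1,
			"state":                WorkerStateReserved,
		})
	if res.Error != nil {
		return nil, res.Error
	}
	if res.RowsAffected == 0 {
		// The stored epoch already moved on (or the row is gone): a newer
		// owner claimed the worker, so the caller's view is stale.
		return nil, ErrWorkerOwnerEpochMismatch
	}
	var claimed WorkerRecord
	if err := tx.First(&claimed, "worker_id = ?", workerID).Error; err != nil {
		return nil, err
	}
	return &claimed, nil
}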
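
The worker-side half of the fencing story is not in this diff: WorkerControlMetadata rides on every create/destroy/health payload, but the validation lives in the worker. A hypothetical sketch of a monotonic-epoch guard; every name here except WorkerControlMetadata is invented for illustration.

package server

import (
	"fmt"
	"sync"
)

// ownerGuard is a hypothetical worker-side check built on the
// WorkerControlMetadata payload: accept requests only from the newest owner
// epoch the worker has seen.
type ownerGuard struct {
	mu            sync.Mutex
	workerID      int
	lastSeenEpoch int64
}

func (g *ownerGuard) admit(meta WorkerControlMetadata) error {
	g.mu.Lock()
	defer g.mu.Unlock()
	if meta.WorkerID != g.workerID {
		return fmt.Errorf("request for worker %d reached worker %d", meta.WorkerID, g.workerID)
	}
	if meta.OwnerEpoch < g.lastSeenEpoch {
		// A newer control plane has taken over (its takeover bumped the
		// epoch); requests still carrying the old epoch are fenced out.
		return fmt.Errorf("stale owner epoch %d, current is %d", meta.OwnerEpoch, g.lastSeenEpoch)
	}
	g.lastSeenEpoch = meta.OwnerEpoch
	return nil
}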