diff --git a/packages/api/internal/handlers/admin_kill_team_sandboxes.go b/packages/api/internal/handlers/admin_kill_team_sandboxes.go index cfe8fd2d1c..5f3d589254 100644 --- a/packages/api/internal/handlers/admin_kill_team_sandboxes.go +++ b/packages/api/internal/handlers/admin_kill_team_sandboxes.go @@ -43,7 +43,7 @@ func (a *APIStore) PostAdminTeamsTeamIDSandboxesKill(c *gin.Context, teamID uuid // Kill each sandbox for _, sbx := range sandboxes { wg.Go(func() error { - err := a.orchestrator.RemoveSandbox(ctx, sbx, sandbox.StateActionKill) + err := a.orchestrator.RemoveSandbox(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) if err != nil { logger.L().Error(ctx, "Failed to kill sandbox", logger.WithSandboxID(sbx.SandboxID), diff --git a/packages/api/internal/handlers/sandbox_kill.go b/packages/api/internal/handlers/sandbox_kill.go index 1ebfe654d9..08de5d09a0 100644 --- a/packages/api/internal/handlers/sandbox_kill.go +++ b/packages/api/internal/handlers/sandbox_kill.go @@ -66,33 +66,21 @@ func (a *APIStore) DeleteSandboxesSandboxID( killedOrRemoved := false - sbx, err := a.orchestrator.GetSandbox(ctx, teamID, sandboxID) - if err == nil { - if sbx.TeamID != teamID { - logger.L().Debug(ctx, "Sandbox team mismatch on kill", logger.WithSandboxID(sandboxID), logger.WithTeamID(teamID.String())) - a.sendAPIStoreError(c, http.StatusNotFound, sandboxNotFoundMsg(sandboxID)) - - return - } - - err = a.orchestrator.RemoveSandbox(ctx, sbx, sandbox.StateActionKill) - switch { - case err == nil: - killedOrRemoved = true - case errors.Is(err, orchestrator.ErrSandboxNotFound): - logger.L().Debug(ctx, "Sandbox not found", logger.WithSandboxID(sandboxID)) - case errors.Is(err, orchestrator.ErrSandboxOperationFailed): - a.sendAPIStoreError(c, http.StatusInternalServerError, fmt.Sprintf("Error killing sandbox: %s", err)) - - return - default: - telemetry.ReportError(ctx, "error killing sandbox", err) - a.sendAPIStoreError(c, http.StatusInternalServerError, fmt.Sprintf("Error killing sandbox: %s", err)) - - return - } - } else { - logger.L().Debug(ctx, "Sandbox not found", logger.WithSandboxID(sandboxID)) + err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionKill) + switch { + case err == nil: + killedOrRemoved = true + case errors.Is(err, orchestrator.ErrSandboxNotFound): + logger.L().Debug(ctx, "Running sandbox not found", logger.WithSandboxID(sandboxID)) + case errors.Is(err, orchestrator.ErrSandboxOperationFailed): + a.sendAPIStoreError(c, http.StatusInternalServerError, fmt.Sprintf("Error killing sandbox: %s", err)) + + return + default: + telemetry.ReportError(ctx, "error killing sandbox", err) + a.sendAPIStoreError(c, http.StatusInternalServerError, fmt.Sprintf("Error killing sandbox: %s", err)) + + return } // remove any snapshots when the sandbox is not running diff --git a/packages/api/internal/handlers/sandbox_pause.go b/packages/api/internal/handlers/sandbox_pause.go index d8fb35fdb9..54210231b2 100644 --- a/packages/api/internal/handlers/sandbox_pause.go +++ b/packages/api/internal/handlers/sandbox_pause.go @@ -40,23 +40,7 @@ func (a *APIStore) PostSandboxesSandboxIDPause(c *gin.Context, sandboxID api.San traceID := span.SpanContext().TraceID().String() c.Set("traceID", traceID) - sbx, err := a.orchestrator.GetSandbox(ctx, teamID, sandboxID) - if err != nil { - apiErr := pauseHandleNotRunningSandbox(ctx, a.sqlcDB, sandboxID, teamID) - a.sendAPIStoreError(c, apiErr.Code, apiErr.ClientMsg) - - return - } - - if sbx.TeamID != teamID { - logger.L().Debug(ctx, "Sandbox team mismatch on pause", logger.WithSandboxID(sandboxID), logger.WithTeamID(teamID.String())) - a.sendAPIStoreError(c, http.StatusNotFound, sandboxNotFoundMsg(sandboxID)) - - return - } - - err = a.orchestrator.RemoveSandbox(ctx, sbx, sandbox.StateActionPause) - + err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionPause) var transErr *sandbox.InvalidStateTransitionError switch { diff --git a/packages/api/internal/handlers/snapshot_template_create.go b/packages/api/internal/handlers/snapshot_template_create.go index 2d5b146fc5..0b1d832cdb 100644 --- a/packages/api/internal/handlers/snapshot_template_create.go +++ b/packages/api/internal/handlers/snapshot_template_create.go @@ -95,7 +95,9 @@ func (a *APIStore) PostSandboxesSandboxIDSnapshots(c *gin.Context, sandboxID api opts.Namespace = &teamInfo.Slug } - sbx, err := a.orchestrator.GetSandbox(ctx, teamID, sandboxID) + telemetry.ReportEvent(ctx, "Creating snapshot template") + + result, err := a.orchestrator.CreateSnapshotTemplate(ctx, teamID, sandboxID, opts) if err != nil { var notFoundErr *sandbox.NotFoundError if errors.As(err, ¬FoundErr) { @@ -104,22 +106,7 @@ func (a *APIStore) PostSandboxesSandboxIDSnapshots(c *gin.Context, sandboxID api return } - a.sendAPIStoreError(c, http.StatusInternalServerError, "Error getting sandbox") - - return - } - - if sbx.TeamID != teamID { - logger.L().Debug(ctx, "Sandbox team mismatch on snapshot", logger.WithSandboxID(sandboxID), logger.WithTeamID(teamID.String())) - a.sendAPIStoreError(c, http.StatusNotFound, sandboxNotFoundMsg(sandboxID)) - return - } - - telemetry.ReportEvent(ctx, "Creating snapshot template") - - result, err := a.orchestrator.CreateSnapshotTemplate(ctx, teamID, sandboxID, opts) - if err != nil { var transErr *sandbox.InvalidStateTransitionError if errors.As(err, &transErr) { a.sendAPIStoreError(c, http.StatusConflict, fmt.Sprintf("Sandbox '%s' cannot be snapshotted while in '%s' state", sandboxID, transErr.CurrentState)) diff --git a/packages/api/internal/orchestrator/analytics.go b/packages/api/internal/orchestrator/analytics.go index cabaf5a1f2..f3661bbbbf 100644 --- a/packages/api/internal/orchestrator/analytics.go +++ b/packages/api/internal/orchestrator/analytics.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/google/uuid" "github.com/posthog/posthog-go" "go.opentelemetry.io/otel/metric" "go.uber.org/zap" @@ -89,6 +90,6 @@ func (o *Orchestrator) sandboxCounterInsert(ctx context.Context, sandbox sandbox o.sandboxCounter.Add(ctx, 1, metric.WithAttributes(telemetry.WithTeamID(sandbox.TeamID.String()))) } -func (o *Orchestrator) countersRemove(ctx context.Context, sandbox sandbox.Sandbox, _ sandbox.StateAction) { - o.sandboxCounter.Add(ctx, -1, metric.WithAttributes(telemetry.WithTeamID(sandbox.TeamID.String()))) +func (o *Orchestrator) countersRemove(ctx context.Context, teamID uuid.UUID, _ sandbox.StateAction) { + o.sandboxCounter.Add(ctx, -1, metric.WithAttributes(telemetry.WithTeamID(teamID.String()))) } diff --git a/packages/api/internal/orchestrator/delete_instance.go b/packages/api/internal/orchestrator/delete_instance.go index ade69deaaf..c2ec8e3c65 100644 --- a/packages/api/internal/orchestrator/delete_instance.go +++ b/packages/api/internal/orchestrator/delete_instance.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/google/uuid" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/api/internal/sandbox" @@ -14,15 +15,21 @@ import ( sbxlogger "github.com/e2b-dev/infra/packages/shared/pkg/logger/sandbox" ) -func (o *Orchestrator) RemoveSandbox(ctx context.Context, sbx sandbox.Sandbox, stateAction sandbox.StateAction) error { +func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error { ctx, span := tracer.Start(ctx, "remove-sandbox") defer span.End() - sandboxID := sbx.SandboxID - alreadyDone, finish, err := o.sandboxStore.StartRemoving(ctx, sbx.TeamID, sandboxID, stateAction) + sbx, alreadyDone, finish, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, stateAction) if err != nil { switch stateAction { case sandbox.StateActionKill: + var notFoundErr *sandbox.NotFoundError + if errors.As(err, ¬FoundErr) { + logger.L().Info(ctx, "Sandbox not found, already removed", logger.WithSandboxID(sandboxID)) + + return ErrSandboxNotFound + } + switch sbx.State { case sandbox.StateKilling: logger.L().Info(ctx, "Sandbox is already killed", logger.WithSandboxID(sandboxID)) @@ -34,6 +41,13 @@ func (o *Orchestrator) RemoveSandbox(ctx context.Context, sbx sandbox.Sandbox, s return ErrSandboxOperationFailed } case sandbox.StateActionPause: + var notFoundErrPause *sandbox.NotFoundError + if errors.As(err, ¬FoundErrPause) { + logger.L().Info(ctx, "Sandbox not found for pause", logger.WithSandboxID(sandboxID)) + + return ErrSandboxNotFound + } + var transErr *sandbox.InvalidStateTransitionError if errors.As(err, &transErr) { if transErr.CurrentState == sandbox.StateKilling { @@ -64,9 +78,9 @@ func (o *Orchestrator) RemoveSandbox(ctx context.Context, sbx sandbox.Sandbox, s return nil } - defer func() { go o.countersRemove(context.WithoutCancel(ctx), sbx, stateAction) }() + defer func() { go o.countersRemove(context.WithoutCancel(ctx), teamID, stateAction) }() defer func() { go o.analyticsRemove(context.WithoutCancel(ctx), sbx, stateAction) }() - defer o.sandboxStore.Remove(ctx, sbx.TeamID, sbx.SandboxID) + defer o.sandboxStore.Remove(ctx, teamID, sandboxID) err = o.removeSandboxFromNode(ctx, sbx, stateAction) if err != nil { logger.L().Error(ctx, "Error pausing sandbox", zap.Error(err), logger.WithSandboxID(sbx.SandboxID)) diff --git a/packages/api/internal/orchestrator/evictor/evict.go b/packages/api/internal/orchestrator/evictor/evict.go index ff9a116615..8a1622cb43 100644 --- a/packages/api/internal/orchestrator/evictor/evict.go +++ b/packages/api/internal/orchestrator/evictor/evict.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/google/uuid" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -17,12 +18,12 @@ const ( type Evictor struct { store *sandbox.Store - removeSandbox func(ctx context.Context, sandbox sandbox.Sandbox, stateAction sandbox.StateAction) error + removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error } func New( store *sandbox.Store, - removeSandbox func(ctx context.Context, sandbox sandbox.Sandbox, stateAction sandbox.StateAction) error, + removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error, ) *Evictor { return &Evictor{ store: store, @@ -58,7 +59,7 @@ func (e *Evictor) Start(ctx context.Context) { } logger.L().Debug(ctx, "Evicting sandbox", logger.WithSandboxID(item.SandboxID), zap.String("state_action", stateAction.Name)) - if err := e.removeSandbox(context.WithoutCancel(ctx), item, stateAction); err != nil { + if err := e.removeSandbox(context.WithoutCancel(ctx), item.TeamID, item.SandboxID, stateAction); err != nil { logger.L().Debug(ctx, "Evicting sandbox failed", zap.Error(err), logger.WithSandboxID(item.SandboxID)) } diff --git a/packages/api/internal/orchestrator/snapshot_template.go b/packages/api/internal/orchestrator/snapshot_template.go index b76788c90b..af70dfb185 100644 --- a/packages/api/internal/orchestrator/snapshot_template.go +++ b/packages/api/internal/orchestrator/snapshot_template.go @@ -39,12 +39,7 @@ func (o *Orchestrator) CreateSnapshotTemplate(ctx context.Context, teamID uuid.U ctx, span := tracer.Start(ctx, "create-snapshot-template") defer span.End() - sbx, err := o.sandboxStore.Get(ctx, teamID, sandboxID) - if err != nil { - return SnapshotTemplateResult{}, fmt.Errorf("failed to get sandbox: %w", err) - } - - alreadyDone, finishSnapshotting, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, sandbox.StateActionSnapshot) + sbx, alreadyDone, finishSnapshotting, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, sandbox.StateActionSnapshot) if err != nil { return SnapshotTemplateResult{}, fmt.Errorf("failed to start snapshotting: %w", err) } @@ -102,7 +97,7 @@ func (o *Orchestrator) CreateSnapshotTemplate(ctx context.Context, teamID uuid.U // so RemoveSandbox can proceed without deadlock. finish(err) - if killErr := o.RemoveSandbox(ctx, sbx, sandbox.StateActionKill); killErr != nil { + if killErr := o.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionKill); killErr != nil { telemetry.ReportError(ctx, "error killing sandbox after failed checkpoint", killErr) } diff --git a/packages/api/internal/sandbox/storage/memory/operations.go b/packages/api/internal/sandbox/storage/memory/operations.go index c1420719c7..6a3d9bf138 100644 --- a/packages/api/internal/sandbox/storage/memory/operations.go +++ b/packages/api/internal/sandbox/storage/memory/operations.go @@ -117,13 +117,20 @@ func (s *Storage) Update(_ context.Context, _ uuid.UUID, sandboxID string, updat return sbx, nil } -func (s *Storage) StartRemoving(ctx context.Context, _ uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (alreadyDone bool, callback func(context.Context, error), err error) { +func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { sbx, err := s.get(sandboxID) if err != nil { - return false, nil, err + return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} } - return startRemoving(ctx, sbx, stateAction) + data := sbx.Data() + if data.TeamID != teamID { + return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} + } + + alreadyDone, callback, err := startRemoving(ctx, sbx, stateAction) + + return data, alreadyDone, callback, err } func startRemoving(ctx context.Context, sbx *memorySandbox, stateAction sandbox.StateAction) (alreadyDone bool, callback func(ctx context.Context, err error), err error) { diff --git a/packages/api/internal/sandbox/storage/memory/operations_test.go b/packages/api/internal/sandbox/storage/memory/operations_test.go index 8373c3d5c9..b030de1c3b 100644 --- a/packages/api/internal/sandbox/storage/memory/operations_test.go +++ b/packages/api/internal/sandbox/storage/memory/operations_test.go @@ -449,7 +449,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) require.NoError(t, err) assert.False(t, snapAlreadyDone) require.NotNil(t, finishSnap) @@ -461,7 +461,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { go func() { defer close(pauseDone) - pauseAlreadyDone, pauseFinish, pauseErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, pauseAlreadyDone, pauseFinish, pauseErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) }() time.Sleep(50 * time.Millisecond) @@ -505,7 +505,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) require.NoError(t, err) assert.False(t, snapAlreadyDone) @@ -517,7 +517,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { go func() { defer close(killDone) - killAlreadyDone, killFinish, killErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, killAlreadyDone, killFinish, killErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) }() // Give the kill goroutine time to start waiting @@ -564,7 +564,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, _, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) require.NoError(t, err) // Finish with error — state stays Snapshotting, transition cleared @@ -575,7 +575,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { assert.Equal(t, sandbox.StateSnapshotting, got.State) // Kill proceeds immediately — no active transition, Snapshotting→Killing is allowed - killAlreadyDone, killFinish, killErr := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, killAlreadyDone, killFinish, killErr := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) require.NoError(t, killErr) assert.False(t, killAlreadyDone) require.NotNil(t, killFinish) @@ -607,7 +607,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) require.NoError(t, err) assert.False(t, snapAlreadyDone) diff --git a/packages/api/internal/sandbox/storage/populate_redis/main.go b/packages/api/internal/sandbox/storage/populate_redis/main.go index 5c76b2ffde..6640cf2e54 100644 --- a/packages/api/internal/sandbox/storage/populate_redis/main.go +++ b/packages/api/internal/sandbox/storage/populate_redis/main.go @@ -82,7 +82,7 @@ func (m *PopulateRedisStorage) Update(ctx context.Context, teamID uuid.UUID, san return sbx, nil } -func (m *PopulateRedisStorage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (alreadyDone bool, callback func(context.Context, error), err error) { +func (m *PopulateRedisStorage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { return m.memoryBackend.StartRemoving(ctx, teamID, sandboxID, stateAction) } diff --git a/packages/api/internal/sandbox/storage/redis/state_change.go b/packages/api/internal/sandbox/storage/redis/state_change.go index 761b469ee7..b399ccf794 100644 --- a/packages/api/internal/sandbox/storage/redis/state_change.go +++ b/packages/api/internal/sandbox/storage/redis/state_change.go @@ -31,7 +31,7 @@ import ( // // The callback is critical: it deletes the transition key // and sets the result value with short TTL to notify waiters of the outcome. -func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (alreadyDone bool, callback func(context.Context, error), err error) { +func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { newState := stateAction.TargetState key := getSandboxKey(teamID.String(), sandboxID) @@ -40,7 +40,7 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID // Acquire distributed lock lock, err := s.lockService.Obtain(ctx, redis_utils.GetLockKey(key), lockTimeout, s.lockOption) if err != nil { - return false, nil, fmt.Errorf("failed to obtain lock: %w", err) + return sandbox.Sandbox{}, false, nil, fmt.Errorf("failed to obtain lock: %w", err) } // Ensure lock is released once @@ -58,21 +58,21 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID // Get current sandbox state first data, err := s.redisClient.Get(ctx, key).Bytes() if errors.Is(err, redis.Nil) { - return false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} } if err != nil { - return false, nil, fmt.Errorf("failed to get sandbox from Redis: %w", err) + return sandbox.Sandbox{}, false, nil, fmt.Errorf("failed to get sandbox from Redis: %w", err) } var sbx sandbox.Sandbox if err = json.Unmarshal(data, &sbx); err != nil { - return false, nil, fmt.Errorf("failed to unmarshal sandbox: %w", err) + return sandbox.Sandbox{}, false, nil, fmt.Errorf("failed to unmarshal sandbox: %w", err) } // Check if there's an existing transition transactionID, err := s.redisClient.Get(ctx, transitionKey).Result() if err != nil && !errors.Is(err, redis.Nil) { - return false, nil, fmt.Errorf("failed to check transition key: %w", err) + return sbx, false, nil, fmt.Errorf("failed to check transition key: %w", err) } if transactionID != "" { @@ -88,25 +88,27 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID if sbx.State == newState { logger.L().Debug(ctx, "Already in the same state", logger.WithSandboxID(sandboxID), zap.String("state", string(newState))) - return true, func(context.Context, error) {}, nil + return sbx, true, func(context.Context, error) {}, nil } // Validate state transition is allowed if !sandbox.AllowedTransitions[sbx.State][newState] { - return false, nil, &sandbox.InvalidStateTransitionError{CurrentState: sbx.State, TargetState: newState} + return sbx, false, nil, &sandbox.InvalidStateTransitionError{CurrentState: sbx.State, TargetState: newState} } - // Update sandbox state - sbx.State = newState + // Build the updated sandbox for Redis without mutating the original. + // This ensures that on failure the caller sees the pre-mutation state, + updated := sbx + updated.State = newState if stateAction.Effect == sandbox.TransitionExpires { - if !sbx.IsExpired() { - sbx.EndTime = time.Now() + if !updated.IsExpired() { + updated.EndTime = time.Now() } } - newData, err := json.Marshal(sbx) + newData, err := json.Marshal(updated) if err != nil { - return false, nil, fmt.Errorf("failed to marshal sandbox: %w", err) + return sbx, false, nil, fmt.Errorf("failed to marshal sandbox: %w", err) } // Generate transition ID @@ -119,12 +121,12 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID err = startTransitionScript.Run(ctx, s.redisClient, []string{key, transitionKey, resultKey}, newData, transitionID, ttlSeconds, resultTtlSeconds).Err() if err != nil { - return false, nil, fmt.Errorf("failed to update sandbox state: %w", err) + return sbx, false, nil, fmt.Errorf("failed to update sandbox state: %w", err) } logger.L().Debug(ctx, "Started state transition", logger.WithSandboxID(sandboxID), zap.String("state", string(newState)), zap.String("transitionID", transitionID)) - return false, s.createCallback(teamID, sandboxID, transitionKey, resultKey, transitionID, stateAction), nil + return updated, false, s.createCallback(teamID, sandboxID, transitionKey, resultKey, transitionID, stateAction), nil } // createCallback returns a callback function for completing a transition. @@ -263,7 +265,7 @@ func (s *Storage) handleExistingTransition( stateAction sandbox.StateAction, newState sandbox.State, transactionID string, -) (bool, func(context.Context, error), error) { +) (sandbox.Sandbox, bool, func(context.Context, error), error) { if sbx.State == newState { // Same target state - wait for completion and return alreadyDone=true logger.L().Debug(ctx, "State transition already in progress to the same state, waiting", @@ -272,20 +274,20 @@ func (s *Storage) handleExistingTransition( err := s.waitForTransition(ctx, teamID, sbx.SandboxID, transactionID) if err != nil { - return false, nil, fmt.Errorf("failed waiting for transition: %w", err) + return sbx, false, nil, fmt.Errorf("failed waiting for transition: %w", err) } - return true, func(context.Context, error) {}, nil + return sbx, true, func(context.Context, error) {}, nil } // Different state - validate transition and wait if !sandbox.AllowedTransitions[sbx.State][newState] { - return false, nil, &sandbox.InvalidStateTransitionError{CurrentState: sbx.State, TargetState: newState} + return sbx, false, nil, &sandbox.InvalidStateTransitionError{CurrentState: sbx.State, TargetState: newState} } err := s.waitForTransition(ctx, teamID, sbx.SandboxID, transactionID) if err != nil { - return false, nil, fmt.Errorf("failed waiting for transition: %w", err) + return sbx, false, nil, fmt.Errorf("failed waiting for transition: %w", err) } // Retry with new state after transition completes diff --git a/packages/api/internal/sandbox/storage/redis/state_change_test.go b/packages/api/internal/sandbox/storage/redis/state_change_test.go index 1e0843942b..3469e8b4d5 100644 --- a/packages/api/internal/sandbox/storage/redis/state_change_test.go +++ b/packages/api/internal/sandbox/storage/redis/state_change_test.go @@ -74,7 +74,7 @@ func TestStartRemoving_BasicTransitions(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, tt.stateAction) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, tt.stateAction) switch { case tt.shouldError: @@ -119,7 +119,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { require.NoError(t, err) // Start pause operation - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -145,7 +145,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { // Meanwhile, another request tries to kill the sandbox start := time.Now() - alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) elapsed := time.Since(start) // Should have waited for the pause to complete @@ -186,7 +186,7 @@ func TestStartRemoving_ConcurrentSameState(t *testing.T) { // Three concurrent requests to pause the sandbox for range 3 { go func() { - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) if err != nil { results <- struct { alreadyDone bool @@ -251,7 +251,7 @@ func TestStartRemoving_NotFound(t *testing.T) { ctx := context.Background() teamID := uuid.New() - alreadyDone, callback, err := storage.StartRemoving(ctx, teamID, "non-existent", sandbox.StateActionKill) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, teamID, "non-existent", sandbox.StateActionKill) require.Error(t, err) assert.False(t, alreadyDone) assert.Nil(t, callback) @@ -270,7 +270,7 @@ func TestStartRemoving_ContextCancellation(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone1, callback1, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone1, callback1, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone1) require.NotNil(t, callback1) @@ -280,12 +280,13 @@ func TestStartRemoving_ContextCancellation(t *testing.T) { defer cancel() start := time.Now() - _, _, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, _, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) elapsed := time.Since(start) // Should timeout require.Error(t, err2) require.ErrorIs(t, err2, context.DeadlineExceeded) + assert.False(t, alreadyDone2) assert.Greater(t, elapsed, 20*time.Millisecond) assert.Less(t, elapsed, 200*time.Millisecond) @@ -319,7 +320,7 @@ func TestWaitForStateChange_WaitForCompletion(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -358,7 +359,7 @@ func TestWaitForStateChange_ContextCancellation(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -399,7 +400,7 @@ func TestWaitForStateChange_MultipleWaiters(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -443,7 +444,7 @@ func TestStartRemoving_TransitionKeyTTL(t *testing.T) { require.NoError(t, err) // Start a transition but don't complete it - alreadyDone, _, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, _, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) @@ -471,7 +472,7 @@ func TestStartRemoving_CallbackMarksTransitionCompleted(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -513,7 +514,7 @@ func TestStartRemoving_CallbackSetsErrorOnFailure(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -558,7 +559,7 @@ func TestStartRemoving_SetsEndTimeWhenNotExpired(t *testing.T) { beforeTransition := time.Now() // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -585,7 +586,7 @@ func TestStartRemoving_WaiterCompletesOnCallbackSuccess(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -613,7 +614,7 @@ func TestStartRemoving_WaiterCompletesOnCallbackSuccess(t *testing.T) { } // Retry should work now - sandbox is already in pausing state - alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err2) // Already in pausing state from first transition assert.True(t, alreadyDone2) @@ -631,7 +632,7 @@ func TestStartRemoving_WaiterReceivesErrorOnCallbackFailure(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -674,7 +675,7 @@ func TestStartRemoving_DifferentExecutionID(t *testing.T) { require.NoError(t, err) // Start a transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -697,7 +698,7 @@ func TestStartRemoving_DifferentExecutionID(t *testing.T) { require.NoError(t, err) // Now start a new pause transition - should work since previous transition completed - alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err2) assert.False(t, alreadyDone2, "Should not be alreadyDone since we have a new execution") require.NotNil(t, callback2) @@ -727,7 +728,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-restore") require.NoError(t, storage.Add(ctx, sbx)) - _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) require.NoError(t, err) finish(ctx, nil) @@ -746,7 +747,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-fail-result") require.NoError(t, storage.Add(ctx, sbx)) - _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) require.NoError(t, err) transitionKey := getTransitionKey(sbx.TeamID.String(), sbx.SandboxID) @@ -772,7 +773,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-restore-fail") require.NoError(t, storage.Add(ctx, sbx)) - _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) require.NoError(t, err) // Remove the sandbox key to force restoreToRunning to fail @@ -805,13 +806,13 @@ func TestStartRemoving_CompletedTransitionAllowsNewTransition(t *testing.T) { require.NoError(t, err) // Start and complete a pause transition - alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) require.NoError(t, err) assert.False(t, alreadyDone) callback(ctx, nil) // Immediately try to kill - should work since pause is completed - alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) require.NoError(t, err2) assert.False(t, alreadyDone2) require.NotNil(t, callback2) diff --git a/packages/api/internal/sandbox/store.go b/packages/api/internal/sandbox/store.go index 43f2fac067..52bf71e1de 100644 --- a/packages/api/internal/sandbox/store.go +++ b/packages/api/internal/sandbox/store.go @@ -36,7 +36,7 @@ type Storage interface { //nolint: interfacebloat TeamsWithSandboxCount(ctx context.Context) (map[uuid.UUID]int64, error) Update(ctx context.Context, teamID uuid.UUID, sandboxID string, updateFunc func(sandbox Sandbox) (Sandbox, error)) (Sandbox, error) - StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (alreadyDone bool, callback func(context.Context, error), err error) + StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (Sandbox, bool, func(context.Context, error), error) WaitForStateChange(ctx context.Context, teamID uuid.UUID, sandboxID string) error Sync(sandboxes []Sandbox, nodeID string) []Sandbox } @@ -150,7 +150,7 @@ func (s *Store) Update(ctx context.Context, teamID uuid.UUID, sandboxID string, return s.storage.Update(ctx, teamID, sandboxID, updateFunc) } -func (s *Store) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (alreadyDone bool, callback func(context.Context, error), err error) { +func (s *Store) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (Sandbox, bool, func(context.Context, error), error) { return s.storage.StartRemoving(ctx, teamID, sandboxID, stateAction) }