diff --git a/actor/transform.go b/actor/transform.go new file mode 100644 index 00000000000..055e0f3382a --- /dev/null +++ b/actor/transform.go @@ -0,0 +1,58 @@ +package actor + +import "context" + +// MapInputRef wraps a TellOnlyRef and transforms incoming messages before +// forwarding them to the target ref. This allows adapting a ref that expects +// message type Out to accept message type In, eliminating the need for +// intermediate adapter actors. +// +// This is particularly useful for notification patterns where a source actor +// sends events of a specific type, but consumers want to receive events in +// their own domain-specific type. +// +// Example usage: +// +// // roundActorRef accepts round.ConfirmationEvent +// // chainsource sends chainsource.ConfirmationEvent +// adaptedRef := actor.NewMapInputRef( +// roundActorRef, +// func(cs chainsource.ConfirmationEvent) round.ConfirmationEvent { +// return round.ConfirmationEvent{ +// TxID: cs.Txid, +// BlockHeight: cs.BlockHeight, +// // ... transform fields +// } +// }, +// ) +// // Now adaptedRef can be used as TellOnlyRef[chainsource.ConfirmationEvent] +type MapInputRef[In Message, Out Message] struct { + targetRef TellOnlyRef[Out] + mapFn func(In) Out +} + +// NewMapInputRef creates a new message-transforming wrapper around a +// TellOnlyRef. The mapFn function is called for each message to transform it +// from type In to type Out before forwarding to targetRef. +func NewMapInputRef[In Message, Out Message]( + targetRef TellOnlyRef[Out], mapFn func(In) Out) *MapInputRef[In, Out] { + + return &MapInputRef[In, Out]{ + targetRef: targetRef, + mapFn: mapFn, + } +} + +// Tell transforms the incoming message using the map function and forwards it +// to the target ref. If the context is cancelled before the message can be sent +// to the target actor's mailbox, the message may be dropped. +func (m *MapInputRef[In, Out]) Tell(ctx context.Context, msg In) { + transformed := m.mapFn(msg) + m.targetRef.Tell(ctx, transformed) +} + +// ID returns a unique identifier for this actor. The ID includes the +// "map-input-" prefix to indicate this is a transformation wrapper. +func (m *MapInputRef[In, Out]) ID() string { + return "map-input-" + m.targetRef.ID() +} diff --git a/actor/transform_test.go b/actor/transform_test.go new file mode 100644 index 00000000000..93f3915c482 --- /dev/null +++ b/actor/transform_test.go @@ -0,0 +1,201 @@ +package actor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// testMessageA is a test message type for transformation testing. +type testMessageA struct { + BaseMessage + Value int + Text string +} + +// MessageType returns the message type identifier. +func (m testMessageA) MessageType() string { + return "testMessageA" +} + +// testMessageB is another test message type for transformation testing. +type testMessageB struct { + BaseMessage + DoubledValue int + UpperText string +} + +// MessageType returns the message type identifier. +func (m testMessageB) MessageType() string { + return "testMessageB" +} + +// mockTellOnlyRef is a mock implementation of TellOnlyRef for testing. +type mockTellOnlyRef[M Message] struct { + id string + received []M +} + +func (m *mockTellOnlyRef[M]) Tell(ctx context.Context, msg M) { + m.received = append(m.received, msg) +} + +func (m *mockTellOnlyRef[M]) ID() string { + return m.id +} + +// TestMapInputRefBasicTransformation tests that messages are correctly +// transformed when sent through a MapInputRef. +func TestMapInputRefBasicTransformation(t *testing.T) { + t.Parallel() + + // Create a mock target ref that expects testMessageB. + targetRef := &mockTellOnlyRef[testMessageB]{ + id: "test-target", + } + + // Create a transformation function from A to B. + transformFn := func(a testMessageA) testMessageB { + return testMessageB{ + DoubledValue: a.Value * 2, + UpperText: a.Text + "-TRANSFORMED", + } + } + + // Create the MapInputRef that accepts testMessageA. + adaptedRef := NewMapInputRef(targetRef, transformFn) + + // Send a message of type A. + ctx := context.Background() + inputMsg := testMessageA{ + Value: 42, + Text: "hello", + } + adaptedRef.Tell(ctx, inputMsg) + + // Verify the target received the transformed message. + require.Len(t, targetRef.received, 1) + received := targetRef.received[0] + require.Equal(t, 84, received.DoubledValue) + require.Equal(t, "hello-TRANSFORMED", received.UpperText) +} + +// TestMapInputRefMultipleMessages tests that multiple messages are all +// correctly transformed. +func TestMapInputRefMultipleMessages(t *testing.T) { + t.Parallel() + + targetRef := &mockTellOnlyRef[testMessageB]{ + id: "test-target", + } + + transformFn := func(a testMessageA) testMessageB { + return testMessageB{ + DoubledValue: a.Value * 2, + UpperText: a.Text, + } + } + + adaptedRef := NewMapInputRef(targetRef, transformFn) + ctx := context.Background() + + // Send multiple messages. + messages := []testMessageA{ + {Value: 1, Text: "one"}, + {Value: 2, Text: "two"}, + {Value: 3, Text: "three"}, + } + + for _, msg := range messages { + adaptedRef.Tell(ctx, msg) + } + + // Verify all messages were transformed and received. + require.Len(t, targetRef.received, 3) + require.Equal(t, 2, targetRef.received[0].DoubledValue) + require.Equal(t, "one", targetRef.received[0].UpperText) + require.Equal(t, 4, targetRef.received[1].DoubledValue) + require.Equal(t, "two", targetRef.received[1].UpperText) + require.Equal(t, 6, targetRef.received[2].DoubledValue) + require.Equal(t, "three", targetRef.received[2].UpperText) +} + +// TestMapInputRefID tests that the ID method returns a prefixed version of +// the target ref's ID. +func TestMapInputRefID(t *testing.T) { + t.Parallel() + + targetRef := &mockTellOnlyRef[testMessageB]{ + id: "my-target-actor", + } + + transformFn := func(a testMessageA) testMessageB { + return testMessageB{} + } + + adaptedRef := NewMapInputRef(targetRef, transformFn) + + // Verify the ID includes the target ID with a prefix. + expectedID := "map-input-my-target-actor" + require.Equal(t, expectedID, adaptedRef.ID()) +} + +// TestMapInputRefTypeSafety tests that the generic type constraints ensure +// compile-time type safety. +func TestMapInputRefTypeSafety(t *testing.T) { + t.Parallel() + + // This test verifies that the type system works correctly. If this + // compiles, it proves type safety is maintained. + targetRef := &mockTellOnlyRef[testMessageB]{ + id: "test-target", + } + + // Create a MapInputRef[A, B]. + var adaptedRef TellOnlyRef[testMessageA] = NewMapInputRef( + targetRef, + func(a testMessageA) testMessageB { + return testMessageB{ + DoubledValue: a.Value, + } + }, + ) + + // Verify we can use it as a TellOnlyRef[testMessageA]. + ctx := context.Background() + adaptedRef.Tell(ctx, testMessageA{Value: 10}) + + // The fact that this compiles and runs proves type safety. + require.Len(t, targetRef.received, 1) +} + +// TestMapInputRefIdentityTransform tests that MapInputRef works when the +// input and output types are the same (identity transformation). +func TestMapInputRefIdentityTransform(t *testing.T) { + t.Parallel() + + targetRef := &mockTellOnlyRef[testMessageA]{ + id: "test-target", + } + + // Identity transformation: A -> A with modified value. + transformFn := func(a testMessageA) testMessageA { + a.Value = a.Value + 100 + return a + } + + adaptedRef := NewMapInputRef(targetRef, transformFn) + + ctx := context.Background() + inputMsg := testMessageA{ + Value: 5, + Text: "test", + } + adaptedRef.Tell(ctx, inputMsg) + + // Verify the transformation was applied. + require.Len(t, targetRef.received, 1) + require.Equal(t, 105, targetRef.received[0].Value) + require.Equal(t, "test", targetRef.received[0].Text) +} diff --git a/go.mod b/go.mod index 7c501f88170..50763762ac0 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,8 @@ require ( pgregory.net/rapid v1.2.0 ) +require github.com/lightningnetwork/lnd/actor v0.0.1-alpha + require ( dario.cat/mergo v1.0.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect @@ -213,6 +215,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 // allows us to specify that as an option. replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display +// Use the local actor module for development. +replace github.com/lightningnetwork/lnd/actor => ./actor + // If you change this please also update docs/INSTALL.md and GO_VERSION in // Makefile (then run `make lint` to see where else it needs to be updated as // well). diff --git a/protofsm/actor_wrapper.go b/protofsm/actor_wrapper.go new file mode 100644 index 00000000000..b0be9a07eb4 --- /dev/null +++ b/protofsm/actor_wrapper.go @@ -0,0 +1,23 @@ +package protofsm + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/actor" +) + +// ActorMessage wraps an Event, in order to create a new message that can be +// used with the actor package. +type ActorMessage[Event any] struct { + actor.BaseMessage + + // Event is the event that is being sent to the actor. + Event Event +} + +// MessageType returns the type of the message. +// +// NOTE: This implements the actor.Message interface. +func (a ActorMessage[Event]) MessageType() string { + return fmt.Sprintf("ActorMessage(%T)", a.Event) +} diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index b3e16f5fd35..984abe7ee76 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/actor" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnutils" @@ -42,6 +43,12 @@ type EmittedEvent[Event any] struct { // ExternalEvent is an optional external event that is to be sent to // the daemon for dispatch. Usually, this is some form of I/O. ExternalEvents DaemonEventSet + + // Outbox is an optional set of events that are accumulated during event + // processing and returned to the caller for processing into the main + // state machine. This enables nested state machines to emit events that + // bubble up to their parent. + Outbox []Event } // StateTransition is a state transition type. It denotes the next state to go @@ -124,6 +131,18 @@ type stateQuery[Event any, Env Environment] struct { CurrentState chan State[Event, Env] } +// syncEventRequest is used to send an event to the state machine synchronously, +// waiting for the event processing to complete and returning the accumulated +// outbox events. +type syncEventRequest[Event any] struct { + // event is the event to process. + event Event + + // promise is used to signal completion and return the accumulated + // outbox events or an error. + promise actor.Promise[[]Event] +} + // StateMachine represents an abstract FSM that is able to process new incoming // events and drive a state machine to termination. This implementation uses // type params to abstract over the types of events and environment. Events @@ -140,6 +159,10 @@ type StateMachine[Event any, Env Environment] struct { // FSM. events chan Event + // syncEvents is the channel that will be used to send synchronous event + // requests to the FSM, returning the accumulated outbox events. + syncEvents chan syncEventRequest[Event] + // newStateEvents is an EventDistributor that will be used to notify // any relevant callers of new state transitions that occur. newStateEvents *fn.EventDistributor[State[Event, Env]] @@ -214,6 +237,7 @@ func NewStateMachine[Event any, Env Environment]( fmt.Sprintf("FSM(%v):", cfg.Env.Name()), ), events: make(chan Event, 1), + syncEvents: make(chan syncEventRequest[Event], 1), stateQuery: make(chan stateQuery[Event, Env]), gm: *fn.NewGoroutineManager(), newStateEvents: fn.NewEventDistributor[State[Event, Env]](), @@ -259,6 +283,84 @@ func (s *StateMachine[Event, Env]) SendEvent(ctx context.Context, event Event) { } } +// AskEvent sends a new event to the state machine using the Ask pattern +// (request-response), waiting for the event to be fully processed. It +// returns a Future that will be resolved with the accumulated outbox events +// from all state transitions triggered by this event, including nested +// internal events. The Future's Await method will return fn.Result[[]Event] +// containing either the accumulated outbox events or an error if processing +// failed. +func (s *StateMachine[Event, Env]) AskEvent(ctx context.Context, + event Event) actor.Future[[]Event] { + + s.log.Debugf("Asking event %T", event) + + // Create a promise to signal completion and return results. + promise := actor.NewPromise[[]Event]() + + req := syncEventRequest[Event]{ + event: event, + promise: promise, + } + + // Check for context cancellation or shutdown first to avoid races. + select { + case <-ctx.Done(): + promise.Complete( + fn.Errf[[]Event]("context cancelled: %w", + ctx.Err()), + ) + + return promise.Future() + + case <-s.quit: + promise.Complete(fn.Err[[]Event](ErrStateMachineShutdown)) + + return promise.Future() + + default: + } + + // Send the request to the state machine. If we can't send it due to + // context cancellation or shutdown, complete the promise with an error. + select { + // Successfully sent, the promise will be completed by driveMachine. + case s.syncEvents <- req: + + case <-ctx.Done(): + promise.Complete( + fn.Errf[[]Event]("context cancelled: %w", + ctx.Err()), + ) + + case <-s.quit: + promise.Complete(fn.Err[[]Event](ErrStateMachineShutdown)) + } + + return promise.Future() +} + +// Receive processes a message and returns a Result containing the accumulated +// outbox events from the state machine. The provided context is the actor's +// internal context, which can be used to detect actor shutdown requests. +// +// This method uses the AskEvent pattern to wait for the event to be fully +// processed and collect any outbox events emitted during state transitions. +// This enables the actor system to propagate events from nested state machines +// up through the actor hierarchy. +// +// NOTE: This implements the actor.ActorBehavior interface. +func (s *StateMachine[Event, Env]) Receive(ctx context.Context, + e ActorMessage[Event]) fn.Result[[]Event] { + + // Use AskEvent to process the event and get the outbox events back. + future := s.AskEvent(ctx, e.Event) + + // Await the result which will contain the accumulated outbox events + // from all state transitions triggered by this event. + return future.Await(ctx) +} + // CanHandle returns true if the target message can be routed to the state // machine. func (s *StateMachine[Event, Env]) CanHandle(msg msgmux.PeerMsg) bool { @@ -563,13 +665,19 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context, // applyEvents applies a new event to the state machine. This will continue // until no further events are emitted by the state machine. Along the way, -// we'll also ensure to execute any daemon events that are emitted. +// we'll also ensure to execute any daemon events that are emitted. The +// function returns the final state, any accumulated outbox events, and an +// error if one occurred. func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context, currentState State[Event, Env], newEvent Event) (State[Event, Env], - error) { + []Event, error) { eventQueue := fn.NewQueue(newEvent) + // outbox accumulates all outbox events from state transitions during + // the entire event processing chain. + var outbox []Event + // Given the next event to handle, we'll process the event, then add // any new emitted internal events to our event queue. This continues // until we reach a terminal state, or we run out of internal events to @@ -613,6 +721,10 @@ func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context, eventQueue.Enqueue(inEvent) } + // Accumulate any outbox events from this state + // transition. + outbox = append(outbox, events.Outbox...) + return nil }) if err != nil { @@ -636,11 +748,11 @@ func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context, return nil }) if err != nil { - return currentState, err + return currentState, nil, err } } - return currentState, nil + return currentState, outbox, nil } // driveMachine is the main event loop of the state machine. It accepts any new @@ -671,7 +783,7 @@ func (s *StateMachine[Event, Env]) driveMachine(ctx context.Context) { // machine forward until we either run out of internal events, // or we reach a terminal state. case newEvent := <-s.events: - newState, err := s.applyEvents( + newState, _, err := s.applyEvents( ctx, currentState, newEvent, ) if err != nil { @@ -688,6 +800,37 @@ func (s *StateMachine[Event, Env]) driveMachine(ctx context.Context) { currentState = newState + // We have a synchronous event request that expects the + // accumulated outbox events to be returned via the promise. + case syncReq := <-s.syncEvents: + newState, outbox, err := s.applyEvents( + ctx, currentState, syncReq.event, + ) + if err != nil { + s.cfg.ErrorReporter.ReportError(err) + + s.log.ErrorS(ctx, "Unable to apply sync event", + err) + + // Complete the promise with the error. + // + // TODO(roasbeef): distinguish between error + // types? state vs processing + syncReq.promise.Complete(fn.Err[[]Event](err)) + + // An error occurred, so we'll tear down the + // entire state machine as we can't proceed. + go s.Stop() + + return + } + + currentState = newState + + // Complete the promise with the accumulated outbox + // events. + syncReq.promise.Complete(fn.Ok(outbox)) + // An outside caller is querying our state, so we'll return the // latest state. case stateQuery := <-s.stateQuery: diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index ca060614f3b..469b76a7756 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -1,6 +1,7 @@ package protofsm import ( + "context" "encoding/hex" "fmt" "sync/atomic" @@ -868,3 +869,347 @@ func TestStateMachineMsgMapper(t *testing.T) { adapters.AssertExpectations(t) env.AssertExpectations(t) } + +// outboxEvent is a test event type that gets added to the outbox. +type outboxEvent struct { + id int +} + +func (o *outboxEvent) dummy() { +} + +// emitOutbox is a test event that triggers a state to emit outbox events. +type emitOutbox struct { + numOutbox int + numInternal int + shouldGoToFin bool +} + +func (e *emitOutbox) dummy() { +} + +// dummyStateOutbox is a test state that emits outbox events during +// transitions. +type dummyStateOutbox struct { + counter int +} + +func (d *dummyStateOutbox) String() string { + return fmt.Sprintf("dummyStateOutbox(%d)", d.counter) +} + +func (d *dummyStateOutbox) ProcessEvent(event dummyEvents, env *dummyEnv, +) (*StateTransition[dummyEvents, *dummyEnv], error) { + + switch newEvent := event.(type) { + case *emitOutbox: + // Create outbox events based on the request. + outbox := make([]dummyEvents, newEvent.numOutbox) + for i := 0; i < newEvent.numOutbox; i++ { + outbox[i] = &outboxEvent{ + id: d.counter*100 + i, + } + } + + // Create internal events that will also emit outbox events. + internalEvents := make([]dummyEvents, newEvent.numInternal) + for i := 0; i < newEvent.numInternal; i++ { + internalEvents[i] = &emitOutbox{ + numOutbox: 1, + numInternal: 0, + shouldGoToFin: false, + } + } + + var nextState State[dummyEvents, *dummyEnv] + if newEvent.shouldGoToFin { + nextState = &dummyStateFin{} + } else { + nextState = &dummyStateOutbox{counter: d.counter + 1} + } + + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: nextState, + NewEvents: fn.Some(EmittedEvent[dummyEvents]{ + InternalEvent: internalEvents, + Outbox: outbox, + }), + }, nil + + case *goToFin: + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: &dummyStateFin{}, + }, nil + + case *outboxEvent: + // When processing an outbox event (shouldn't happen in normal + // flow), just stay in current state. + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: d, + }, nil + } + + return nil, fmt.Errorf("unknown event: %T", event) +} + +func (d *dummyStateOutbox) IsTerminal() bool { + return false +} + +// TestStateMachineAskEvent tests the AskEvent method and outbox event +// accumulation functionality. +func TestStateMachineAskEvent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + event dummyEvents + expectedOutboxCount int + expectError bool + }{ + { + name: "basic outbox accumulation", + event: &emitOutbox{ + numOutbox: 3, + numInternal: 0, + shouldGoToFin: false, + }, + expectedOutboxCount: 3, + expectError: false, + }, + + // 2 from top-level + 3 from internal events (1 each). + { + name: "nested internal events with outbox", + event: &emitOutbox{ + numOutbox: 2, + numInternal: 3, + shouldGoToFin: false, + }, + expectedOutboxCount: 5, + expectError: false, + }, + + { + name: "empty outbox", + event: &emitOutbox{ + numOutbox: 0, + numInternal: 0, + shouldGoToFin: false, + }, + expectedOutboxCount: 0, + expectError: false, + }, + + // 1 from top-level + 5 from internal events. + { + name: "deeply nested outbox", + event: &emitOutbox{ + numOutbox: 1, + numInternal: 5, + shouldGoToFin: false, + }, + expectedOutboxCount: 6, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() + + // Create our state machine with the outbox test state. + env := &dummyEnv{} + startingState := &dummyStateOutbox{counter: 0} + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + } + stateMachine := NewStateMachine(cfg) + + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + stateMachine.Start(ctx) + defer stateMachine.Stop() + + // Wait for initial state. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateOutbox{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + // Send the event using Ask pattern. + future := stateMachine.AskEvent(ctx, tc.event) + require.NotNil(t, future) + + result := future.Await(ctx) + + if tc.expectError { + require.True(t, result.IsErr()) + } else { + require.True(t, result.IsOk()) + + // Extract the outbox events. + outbox := result.UnwrapOr(nil) + require.Len(t, outbox, tc.expectedOutboxCount) + + // Verify outbox events are of the correct type. + for _, event := range outbox { + _, ok := event.(*outboxEvent) + require.True(t, ok, + "expected outboxEvent, got %T", + event) + } + } + + adapters.AssertExpectations(t) + env.AssertExpectations(t) + }) + } +} + +// TestStateMachineOutboxWithMixedEvents tests that outbox accumulation works +// correctly when mixed with regular SendEvent calls. +func TestStateMachineOutboxWithMixedEvents(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + // Create our state machine with the outbox test state. + env := &dummyEnv{} + startingState := &dummyStateOutbox{counter: 0} + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + } + stateMachine := NewStateMachine(cfg) + + // Subscribe to state transitions, then start the main state machine. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + stateMachine.Start(ctx) + defer stateMachine.Stop() + + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateOutbox{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + // Send a regular async event first. + stateMachine.SendEvent(ctx, &emitOutbox{ + numOutbox: 1, + numInternal: 0, + shouldGoToFin: false, + }) + + // Wait for state transition from async event. + expectedStates = []State[dummyEvents, *dummyEnv]{ + &dummyStateOutbox{counter: 1}, + } + assertStateTransitions(t, stateSub, expectedStates) + + // Now send an event using Ask pattern. + future := stateMachine.AskEvent(ctx, &emitOutbox{ + numOutbox: 2, + numInternal: 1, + shouldGoToFin: false, + }) + + result := future.Await(ctx) + require.True(t, result.IsOk()) + + // We should have 3 outbox events (2 from top-level + 1 from internal). + outbox := result.UnwrapOr(nil) + require.Len(t, outbox, 3) + + adapters.AssertExpectations(t) + env.AssertExpectations(t) +} + +// TestStateMachineAskEventContextCancellation tests that context cancellation +// is properly handled in AskEvent. +func TestStateMachineAskEventContextCancellation(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + env := &dummyEnv{} + startingState := &dummyStateOutbox{counter: 0} + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + } + stateMachine := NewStateMachine(cfg) + + stateMachine.Start(ctx) + defer stateMachine.Stop() + + // Create a context that's already cancelled. + cancelledCtx, cancel := context.WithCancel(t.Context()) + cancel() + + // Try to send an event with a cancelled context. + future := stateMachine.AskEvent(cancelledCtx, &emitOutbox{ + numOutbox: 1, + numInternal: 0, + shouldGoToFin: false, + }) + + // The future should be completed with an error. + result := future.Await(ctx) + require.True(t, result.IsErr()) + + adapters.AssertExpectations(t) + env.AssertExpectations(t) +} + +// TestStateMachineAskEventAfterShutdown tests that AskEvent properly handles +// the case where the state machine has been shut down. +func TestStateMachineAskEventAfterShutdown(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + // Create our state machine. + env := &dummyEnv{} + startingState := &dummyStateOutbox{counter: 0} + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + } + stateMachine := NewStateMachine(cfg) + + stateMachine.Start(ctx) + + // Stop the state machine. + stateMachine.Stop() + + // Try to send an event after shutdown. + future := stateMachine.AskEvent(ctx, &emitOutbox{ + numOutbox: 1, + numInternal: 0, + shouldGoToFin: false, + }) + + // The future should be completed with a shutdown error. + result := future.Await(ctx) + require.True(t, result.IsErr()) + require.ErrorIs(t, result.Err(), ErrStateMachineShutdown) + + adapters.AssertExpectations(t) + env.AssertExpectations(t) +}