diff --git a/multinode/README.md b/multinode/README.md index 12aef2d..306b499 100644 --- a/multinode/README.md +++ b/multinode/README.md @@ -19,4 +19,32 @@ Manages all nodes performing node selection and load balancing, health checks an Used to poll for new heads and finalized heads within subscriptions. ### Transaction Sender -Used to send transactions to all healthy RPCs and aggregate the results. \ No newline at end of file +Used to send transactions to all healthy RPCs and aggregate the results. + +## States diagram + +```mermaid +graph TD + Undialed --> Dialed + Undialed --> Unreachable + Dialed --> Alive + Dialed --> InvalidChainID + Dialed --> Syncing + Dialed --> Unreachable + Alive --> OutOfSync + Alive --> Unreachable + OutOfSync --> Alive + OutOfSync --> InvalidChainID + OutOfSync --> Syncing + OutOfSync --> Unreachable + InvalidChainID --> Alive + InvalidChainID --> Syncing + InvalidChainID --> Unreachable + Syncing --> Alive + Syncing --> OutOfSync + Syncing --> InvalidChainID + Syncing --> Unreachable + Unreachable --> Dialed + Unusable:::terminal + Closed:::terminal +``` diff --git a/multinode/config/config.go b/multinode/config/config.go index 4d7d8ca..be7540e 100644 --- a/multinode/config/config.go +++ b/multinode/config/config.go @@ -44,6 +44,10 @@ func (c *MultiNodeConfig) PollFailureThreshold() uint32 { return *c.MultiNode.PollFailureThreshold } +func (c *MultiNodeConfig) PollSuccessThreshold() uint32 { + return 0 // retaining source compat for -solana; -evm sets via NodePoolConfig +} + func (c *MultiNodeConfig) PollInterval() time.Duration { return c.MultiNode.PollInterval.Duration() } diff --git a/multinode/ctx_test.go b/multinode/ctx_test.go index 8466325..35ca4d5 100644 --- a/multinode/ctx_test.go +++ b/multinode/ctx_test.go @@ -4,12 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) func TestContext(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() assert.False(t, CtxIsHealthCheckRequest(ctx), "expected false for test context") ctx = CtxAddHealthCheckFlag(ctx) assert.True(t, CtxIsHealthCheckRequest(ctx), "expected context to contain the healthcheck flag") diff --git a/multinode/multi_node_test.go b/multinode/multi_node_test.go index fea7fe6..7993079 100644 --- a/multinode/multi_node_test.go +++ b/multinode/multi_node_test.go @@ -81,7 +81,7 @@ func TestMultiNode_Dial(t *testing.T) { selectionMode: NodeSelectionModeRoundRobin, chainID: RandomID(), }) - err := mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorContains(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID)) }) t.Run("Fails with wrong node's chainID", func(t *testing.T) { @@ -97,7 +97,7 @@ func TestMultiNode_Dial(t *testing.T) { chainID: multiNodeChainID, nodes: []Node[ID, multiNodeRPCClient]{node}, }) - err := mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorContains(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID)) }) t.Run("Fails if node fails", func(t *testing.T) { @@ -113,7 +113,7 @@ func TestMultiNode_Dial(t *testing.T) { chainID: chainID, nodes: []Node[ID, multiNodeRPCClient]{node}, }) - err := mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorIs(t, err, expectedError) }) @@ -132,7 +132,7 @@ func TestMultiNode_Dial(t *testing.T) { chainID: chainID, nodes: []Node[ID, multiNodeRPCClient]{node1, node2}, }) - err := 
mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorIs(t, err, expectedError) }) t.Run("Fails with wrong send only node's chainID", func(t *testing.T) { @@ -151,7 +151,7 @@ func TestMultiNode_Dial(t *testing.T) { nodes: []Node[ID, multiNodeRPCClient]{node}, sendonlys: []SendOnlyNode[ID, multiNodeRPCClient]{sendOnly}, }) - err := mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorContains(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID)) }) @@ -178,7 +178,7 @@ func TestMultiNode_Dial(t *testing.T) { nodes: []Node[ID, multiNodeRPCClient]{node}, sendonlys: []SendOnlyNode[ID, multiNodeRPCClient]{sendOnly1, sendOnly2}, }) - err := mn.Start(tests.Context(t)) + err := mn.Start(t.Context()) assert.ErrorIs(t, err, expectedError) }) t.Run("Starts successfully with healthy nodes", func(t *testing.T) { @@ -192,7 +192,7 @@ func TestMultiNode_Dial(t *testing.T) { sendonlys: []SendOnlyNode[ID, multiNodeRPCClient]{newHealthySendOnly(t, chainID)}, }) servicetest.Run(t, mn) - selectedNode, err := mn.selectNode(tests.Context(t)) + selectedNode, err := mn.selectNode(t.Context()) require.NoError(t, err) assert.Equal(t, node, selectedNode) }) @@ -336,7 +336,7 @@ func TestMultiNode_selectNode(t *testing.T) { t.Parallel() t.Run("Returns same node, if it's still healthy", func(t *testing.T) { t.Parallel() - ctx := tests.Context(t) + ctx := t.Context() chainID := RandomID() node1 := newMockNode[ID, multiNodeRPCClient](t) node1.On("State").Return(nodeStateAlive).Once() @@ -360,7 +360,7 @@ func TestMultiNode_selectNode(t *testing.T) { }) t.Run("Updates node if active is not healthy", func(t *testing.T) { t.Parallel() - ctx := tests.Context(t) + ctx := t.Context() chainID := RandomID() oldBest := newMockNode[ID, multiNodeRPCClient](t) oldBest.On("String").Return("oldBest").Maybe() @@ -387,7 +387,7 @@ func TestMultiNode_selectNode(t *testing.T) { }) t.Run("No active nodes - reports critical error", func(t *testing.T) { t.Parallel() - ctx := tests.Context(t) + ctx := t.Context() chainID := RandomID() lggr, observedLogs := logger.TestObserved(t, zap.InfoLevel) mn := newTestMultiNode(t, multiNodeOpts{ diff --git a/multinode/node.go b/multinode/node.go index 6729459..5c83545 100644 --- a/multinode/node.go +++ b/multinode/node.go @@ -18,6 +18,7 @@ var errInvalidChainID = errors.New("invalid chain id") type NodeConfig interface { PollFailureThreshold() uint32 + PollSuccessThreshold() uint32 PollInterval() time.Duration SelectionMode() string SyncThreshold() uint32 diff --git a/multinode/node_lifecycle.go b/multinode/node_lifecycle.go index 8fd4e9d..9941d01 100644 --- a/multinode/node_lifecycle.go +++ b/multinode/node_lifecycle.go @@ -123,7 +123,10 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr.Debugw("Ping successful", "nodeState", n.State()) n.metrics.RecordNodeClientVersion(ctx, n.name, version) n.metrics.IncrementPollsSuccess(ctx, n.name) - pollFailures = 0 + // Decay rather than reset; detects sustained failure rates above 1:1 + if pollFailures > 0 { + pollFailures-- + } } if pollFailureThreshold > 0 && pollFailures >= pollFailureThreshold { lggr.Errorw(fmt.Sprintf("RPC endpoint failed to respond to %d consecutive polls", pollFailures), "pollFailures", pollFailures, "nodeState", n.getCachedState()) @@ -356,7 +359,13 @@ func (n *node[CHAIN_ID, HEAD, RPC]) isOutOfSyncWithPool() (outOfSync bool, liveN } if outOfSync && n.getCachedState() == nodeStateAlive { - 
n.lfcLog.Errorw("RPC endpoint has fallen behind", "blockNumber", localChainInfo.BlockNumber, "bestLatestBlockNumber", ci.BlockNumber, "totalDifficulty", localChainInfo.TotalDifficulty) + n.lfcLog.Errorw( + "RPC endpoint has fallen behind", + "blockNumber", localChainInfo.BlockNumber, + "bestLatestBlockNumber", ci.BlockNumber, + "totalDifficulty", localChainInfo.TotalDifficulty, + "blockDifference", localChainInfo.BlockNumber-ci.BlockNumber, + ) } return outOfSync, ln } @@ -518,6 +527,39 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(syncIssues syncStatus) { } } +// probeUntilStable polls the node PollSuccessThreshold consecutive times before allowing it back into +// the alive pool. Returns true if all probes pass, false if any probe fails or ctx is cancelled. +// When threshold is 0 the probe is disabled and the function returns true immediately. +func (n *node[CHAIN_ID, HEAD, RPC]) probeUntilStable(ctx context.Context, lggr logger.Logger) bool { + threshold := n.nodePoolCfg.PollSuccessThreshold() + pollInterval := n.nodePoolCfg.PollInterval() + if threshold == 0 || pollInterval <= 0 { + return true + } + var successes uint32 + for successes < threshold { + select { + case <-ctx.Done(): + return false + case <-time.After(pollInterval): + } + n.metrics.IncrementPolls(ctx, n.name) + pollCtx, cancel := context.WithTimeout(ctx, pollInterval) + version, err := n.RPC().ClientVersion(pollCtx) + cancel() + if err != nil { + n.metrics.IncrementPollsFailed(ctx, n.name) + lggr.Warnw("Recovery probe poll failed; restarting redial", "err", err, "successesSoFar", successes, "threshold", threshold) + return false + } + n.metrics.IncrementPollsSuccess(ctx, n.name) + n.metrics.RecordNodeClientVersion(ctx, n.name, version) + successes++ + lggr.Debugw("Recovery probe poll succeeded", "successes", successes, "threshold", threshold) + } + return true +} + func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { defer n.wg.Done() ctx, cancel := n.newCtx() @@ -563,6 +605,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { n.setState(nodeStateUnreachable) continue case nodeStateAlive: + if !n.probeUntilStable(ctx, lggr) { + n.rpc.Close() + n.setState(nodeStateUnreachable) + continue + } lggr.Infow(fmt.Sprintf("Successfully redialled and verified RPC node %s. Node was offline for %s", n.String(), time.Since(unreachableAt)), "nodeState", n.getCachedState()) fallthrough default: diff --git a/multinode/node_lifecycle_test.go b/multinode/node_lifecycle_test.go index f1eb250..3ecdfcb 100644 --- a/multinode/node_lifecycle_test.go +++ b/multinode/node_lifecycle_test.go @@ -1,6 +1,7 @@ package multinode import ( + "context" "errors" "fmt" "math/big" @@ -110,7 +111,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { tests.AssertLogEventually(t, observedLogs, "Polling disabled") assert.Equal(t, nodeStateAlive, node.State()) }) - t.Run("stays alive while below pollFailureThreshold and resets counter on success", func(t *testing.T) { + t.Run("stays alive while below pollFailureThreshold, success decrements failure count", func(t *testing.T) { t.Parallel() rpc := newMockRPCClient[ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}) @@ -132,9 +133,9 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { // stays healthy while below threshold assert.Equal(t, nodeStateAlive, node.State()) }).Times(pollFailureThreshold - 1) - // 2. Successful call that is expected to reset counter + // 2. 
Successful call that is expected to decrement the counter (counter: 2 → 1) rpc.On("ClientVersion", mock.Anything).Return("", nil).Once() - // 3. Return error. If we have not reset the timer, we'll transition to nonAliveState + // 3. Return error. Counter was decremented (not reset), so it reaches 2 — still below threshold. rpc.On("ClientVersion", mock.Anything).Return("", pollError).Once() // 4. Once during the call, check if node is alive var ensuredAlive atomic.Bool @@ -176,6 +177,37 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { return nodeStateUnreachable == node.State() }) }) + t.Run("transitions to unreachable when net poll failures accumulate despite intermittent successes", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}) + const pollFailureThreshold = 3 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollFailureThreshold: pollFailureThreshold, + pollInterval: tests.TestInterval, + }, + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + + pollError := errors.New("failed to get ClientVersion") + // Pattern F·F·S·F·F: with the decay counter the net failure debt reaches + // threshold=3 at the 5th poll (counter: 1→2→1→2→3). With the old + // reset-on-success behaviour the counter resets to 0 at S and peaks at only + // 2 before the next success, never tripping. + rpc.On("ClientVersion", mock.Anything).Return("", pollError).Times(2) + rpc.On("ClientVersion", mock.Anything).Return("", nil).Once() + rpc.On("ClientVersion", mock.Anything).Return("", pollError).Times(2) + // Unlimited successes after: ensures old code stays alive indefinitely so + // the test correctly fails (times out) when run against the old behaviour. 
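A self-contained sketch of the bookkeeping the F·F·S·F·F comment above walks through (illustrative only, outside the patch; the `tripsThreshold` helper and the hard-coded poll pattern are invented for the example and roughly mirror the `pollFailures` handling in `aliveLoop`):

```go
package main

import "fmt"

// tripsThreshold is a hypothetical stand-in for the pollFailures bookkeeping in
// aliveLoop: a failed poll increments the counter, a successful poll decays it
// by one (never below zero), and the node is declared unreachable once the
// counter reaches the threshold.
func tripsThreshold(pollSucceeded []bool, threshold uint32) bool {
	var pollFailures uint32
	for _, ok := range pollSucceeded {
		if ok {
			if pollFailures > 0 {
				pollFailures--
			}
		} else {
			pollFailures++
		}
		if threshold > 0 && pollFailures >= threshold {
			return true // would transition to nodeStateUnreachable
		}
	}
	return false // stays alive
}

func main() {
	// F F S F F with threshold 3: the counter runs 1, 2, 1, 2, 3 and trips on the 5th poll.
	fmt.Println(tripsThreshold([]bool{false, false, true, false, false}, 3)) // true
	// A reset-on-success counter would peak at 2 on the same input and never trip.
}
```

Swapping the decrement for a reset to zero makes the same call return false, which is the behavioural gap this subtest is designed to catch.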
+ rpc.On("ClientVersion", mock.Anything).Return("", nil) + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) t.Run("with threshold poll failures, but we are the last node alive, forcibly keeps it alive", func(t *testing.T) { t.Parallel() rpc := newMockRPCClient[ID, Head](t) @@ -729,7 +761,7 @@ func writeHeads(t *testing.T, ch chan<- Head, heads ...head) { h := head.ToMockHead(t) select { case ch <- h: - case <-tests.Context(t).Done(): + case <-t.Context().Done(): return } } @@ -1464,6 +1496,208 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { return node.State() == nodeStateAlive }) }) + t.Run("with PollSuccessThreshold set, without isSyncing, node becomes alive once all probe polls succeed", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + nodeChainID := RandomID() + const pollSuccessThreshold = 2 + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + config: testNodeConfig{ + pollSuccessThreshold: pollSuccessThreshold, + pollInterval: tests.TestInterval, + }, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("ClientVersion", mock.Anything).Return("", nil).Twice() + setupRPCForAliveLoop(t, rpc) + + node.declareUnreachable() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) + t.Run("with PollSuccessThreshold set, node becomes alive once all probe polls succeed", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + nodeChainID := RandomID() + const pollSuccessThreshold = 2 + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + config: testNodeConfig{ + nodeIsSyncingEnabled: true, + pollSuccessThreshold: pollSuccessThreshold, + pollInterval: tests.TestInterval, + }, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("IsSyncing", mock.Anything).Return(false, nil) + rpc.On("ClientVersion", mock.Anything).Return("", nil).Twice() + setupRPCForAliveLoop(t, rpc) + + node.declareUnreachable() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) + t.Run("with PollSuccessThreshold set, probe poll failure keeps node unreachable and restarts redial", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + nodeChainID := RandomID() + lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) + const pollSuccessThreshold = 2 + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + config: testNodeConfig{ + pollSuccessThreshold: pollSuccessThreshold, + pollInterval: tests.TestInterval, + }, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("ClientVersion", mock.Anything).Return("", nil).Once() + rpc.On("ClientVersion", mock.Anything).Return("", errors.New("probe poll failed")).Once() + // after the probe aborts, rpc.Close() is called and the redial backoff fires again; keep failing + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + // guard: if current code (no probe) enters aliveLoop, fail the 
subscribe so the node returns to unreachable + rpc.On("SubscribeToHeads", mock.Anything).Return(nil, nil, errors.New("unexpected")).Maybe() + + node.declareUnreachable() + tests.AssertLogEventually(t, observedLogs, "Recovery probe poll failed; restarting redial") + assert.Equal(t, nodeStateUnreachable, node.State()) + }) +} + +func TestUnit_NodeLifecycle_probeUntilStable(t *testing.T) { + t.Parallel() + + t.Run("returns true immediately when pollInterval is zero, skipping probe", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + // ClientVersion is intentionally NOT mocked: if the guard is missing the loop fires + // immediately (time.After(0)) and calls ClientVersion, which makes the test fail. + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + config: testNodeConfig{ + pollSuccessThreshold: 2, + pollInterval: 0, + }, + }) + result := node.probeUntilStable(t.Context(), logger.Test(t)) + assert.True(t, result) + }) + t.Run("returns true immediately when pollInterval is negative, skipping probe", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + // ClientVersion is intentionally NOT mocked: same reasoning as above. + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + config: testNodeConfig{ + pollSuccessThreshold: 2, + pollInterval: -1, + }, + }) + result := node.probeUntilStable(t.Context(), logger.Test(t)) + assert.True(t, result) + }) + t.Run("returns true immediately when threshold is zero, skipping probe", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + // ClientVersion is intentionally NOT mocked: probing must be entirely skipped. + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + config: testNodeConfig{ + pollSuccessThreshold: 0, + pollInterval: tests.TestInterval, + }, + }) + result := node.probeUntilStable(t.Context(), logger.Test(t)) + assert.True(t, result) + }) + t.Run("returns false when context is already cancelled", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + // ClientVersion must never be called: ctx is done before the first timer fires. 
+ node := newTestNode(t, testNodeOpts{ + rpc: rpc, + config: testNodeConfig{ + pollSuccessThreshold: 2, + pollInterval: tests.TestInterval, + }, + }) + ctx, cancel := context.WithCancel(t.Context()) + cancel() + result := node.probeUntilStable(ctx, logger.Test(t)) + assert.False(t, result) + }) + t.Run("returns false when first poll fails", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + lggr: lggr, + config: testNodeConfig{ + pollSuccessThreshold: 2, + pollInterval: tests.TestInterval, + }, + }) + rpc.On("ClientVersion", mock.Anything).Return("", errors.New("rpc unavailable")).Once() + result := node.probeUntilStable(t.Context(), lggr) + assert.False(t, result) + tests.AssertLogEventually(t, observedLogs, "Recovery probe poll failed; restarting redial") + }) + t.Run("returns true when all threshold polls succeed", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) + const threshold = 3 + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + lggr: lggr, + config: testNodeConfig{ + pollSuccessThreshold: threshold, + pollInterval: tests.TestInterval, + }, + }) + rpc.On("ClientVersion", mock.Anything).Return("v1.0.0", nil).Times(threshold) + result := node.probeUntilStable(t.Context(), lggr) + assert.True(t, result) + tests.AssertLogCountEventually(t, observedLogs, "Recovery probe poll succeeded", threshold) + }) + t.Run("returns false when a later probe poll fails, logging correct successesSoFar", func(t *testing.T) { + t.Parallel() + rpc := newMockRPCClient[ID, Head](t) + lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) + const threshold = 3 + node := newTestNode(t, testNodeOpts{ + rpc: rpc, + lggr: lggr, + config: testNodeConfig{ + pollSuccessThreshold: threshold, + pollInterval: tests.TestInterval, + }, + }) + rpc.On("ClientVersion", mock.Anything).Return("v1.0.0", nil).Times(threshold - 1) + rpc.On("ClientVersion", mock.Anything).Return("", errors.New("rpc unavailable")).Once() + result := node.probeUntilStable(t.Context(), lggr) + assert.False(t, result) + // threshold-1 successes logged before the failure + tests.AssertLogCountEventually(t, observedLogs, "Recovery probe poll succeeded", threshold-1) + tests.AssertLogEventually(t, observedLogs, "Recovery probe poll failed; restarting redial") + }) } func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { @@ -1612,7 +1846,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Dial failed: Node is unreachable") tests.AssertEventually(t, func() bool { @@ -1635,7 +1869,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { assert.Equal(t, nodeStateDialed, node.State()) }).Return(nodeChainID, errors.New("failed to get chain id")) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Failed to verify chain ID for node") tests.AssertEventually(t, func() bool { @@ -1656,7 +1890,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", 
mock.Anything).Return(rpcChainID, nil) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID @@ -1682,7 +1916,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { }).Return(nodeChainID, nil).Once() rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check syncing status")) rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Unexpected error while verifying RPC node synchronization status") tests.AssertEventually(t, func() bool { @@ -1704,7 +1938,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(true, nil) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return node.State() == nodeStateSyncing @@ -1725,7 +1959,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { rpc.On("IsSyncing", mock.Anything).Return(false, nil) setupRPCForAliveLoop(t, rpc) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive @@ -1744,7 +1978,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) setupRPCForAliveLoop(t, rpc) - err := node.Start(tests.Context(t)) + err := node.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive diff --git a/multinode/node_test.go b/multinode/node_test.go index e3c8d71..ce97b4c 100644 --- a/multinode/node_test.go +++ b/multinode/node_test.go @@ -16,6 +16,7 @@ import ( type testNodeConfig struct { pollFailureThreshold uint32 + pollSuccessThreshold uint32 pollInterval time.Duration selectionMode string syncThreshold uint32 @@ -34,6 +35,10 @@ func (n testNodeConfig) PollFailureThreshold() uint32 { return n.pollFailureThreshold } +func (n testNodeConfig) PollSuccessThreshold() uint32 { + return n.pollSuccessThreshold +} + func (n testNodeConfig) PollInterval() time.Duration { return n.pollInterval } diff --git a/multinode/poller_test.go b/multinode/poller_test.go index 85a46b8..6510024 100644 --- a/multinode/poller_test.go +++ b/multinode/poller_test.go @@ -20,7 +20,7 @@ func Test_Poller(t *testing.T) { lggr := logger.Test(t) t.Run("Test multiple start", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() pollFunc := func(ctx context.Context) (Head, error) { return nil, nil } @@ -35,7 +35,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test polling for heads", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() // Mock polling function that returns a new value every time it's called var pollNumber int pollLock := sync.Mutex{} @@ -65,7 +65,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test polling errors", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() // Mock polling function that returns an error var pollNumber int pollLock := sync.Mutex{} @@ -97,7 +97,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test polling timeout", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() pollFunc := func(ctx context.Context) (Head, error) { if <-ctx.Done(); true { return nil, ctx.Err() @@ -123,7 +123,7 @@ 
func Test_Poller(t *testing.T) { }) t.Run("Test unsubscribe during polling", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() wait := make(chan struct{}) closeOnce := sync.OnceFunc(func() { close(wait) }) pollFunc := func(ctx context.Context) (Head, error) { @@ -172,7 +172,7 @@ func Test_Poller_Unsubscribe(t *testing.T) { } t.Run("Test multiple unsubscribe", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) err := poller.Start(ctx) require.NoError(t, err) @@ -183,7 +183,7 @@ func Test_Poller_Unsubscribe(t *testing.T) { }) t.Run("Read channel after unsubscribe", func(t *testing.T) { - ctx := tests.Context(t) + ctx := t.Context() poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) err := poller.Start(ctx) require.NoError(t, err) diff --git a/multinode/send_only_node_test.go b/multinode/send_only_node_test.go index d0eba66..f0a95ab 100644 --- a/multinode/send_only_node_test.go +++ b/multinode/send_only_node_test.go @@ -49,7 +49,7 @@ func TestStartSendOnlyNode(t *testing.T) { defer func() { assert.NoError(t, s.Close()) }() assert.Equal(t, nodeStateUndialed, s.State()) - err := s.Start(tests.Context(t)) + err := s.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return s.State() == nodeStateUnusable }) @@ -65,7 +65,7 @@ func TestStartSendOnlyNode(t *testing.T) { defer func() { assert.NoError(t, s.Close()) }() assert.Equal(t, nodeStateUndialed, s.State()) - err := s.Start(tests.Context(t)) + err := s.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return s.State() == nodeStateAlive }) @@ -87,7 +87,7 @@ func TestStartSendOnlyNode(t *testing.T) { defer func() { assert.NoError(t, s.Close()) }() assert.Equal(t, nodeStateUndialed, s.State()) - err := s.Start(tests.Context(t)) + err := s.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return s.State() == nodeStateUnreachable }) @@ -113,7 +113,7 @@ func TestStartSendOnlyNode(t *testing.T) { defer func() { assert.NoError(t, s.Close()) }() assert.Equal(t, nodeStateUndialed, s.State()) - err := s.Start(tests.Context(t)) + err := s.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return s.State() == nodeStateInvalidChainID }) @@ -136,7 +136,7 @@ func TestStartSendOnlyNode(t *testing.T) { s := NewSendOnlyNode(lggr, makeMockNodeMetrics(t), url.URL{}, t.Name(), configuredChainID, client) defer func() { assert.NoError(t, s.Close()) }() - err := s.Start(tests.Context(t)) + err := s.Start(t.Context()) require.NoError(t, err) tests.AssertEventually(t, func() bool { return s.State() == nodeStateAlive diff --git a/multinode/transaction_sender_test.go b/multinode/transaction_sender_test.go index 731eee9..1c4bbce 100644 --- a/multinode/transaction_sender_test.go +++ b/multinode/transaction_sender_test.go @@ -91,7 +91,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { t.Run("Fails if there is no nodes available", func(t *testing.T) { lggr := logger.Test(t) _, txSender := newTestTransactionSender(t, RandomID(), lggr, nil, nil) - _, _, err := txSender.SendTransaction(tests.Context(t), nil) + _, _, err := txSender.SendTransaction(t.Context(), nil) assert.EqualError(t, err, ErrNodeError.Error()) }) @@ -104,7 +104,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{mainNode}, []SendOnlyNode[ID, TestSendTxRPCClient]{newNode(t, 
errors.New("unexpected error"), nil)}) - _, code, err := txSender.SendTransaction(tests.Context(t), nil) + _, code, err := txSender.SendTransaction(t.Context(), nil) require.ErrorIs(t, err, expectedError) require.Equal(t, Fatal, code) tests.AssertLogCountEventually(t, observedLogs, "Node sent transaction", 2) @@ -119,7 +119,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{mainNode}, []SendOnlyNode[ID, TestSendTxRPCClient]{newNode(t, errors.New("unexpected error"), nil)}) - _, code, err := txSender.SendTransaction(tests.Context(t), nil) + _, code, err := txSender.SendTransaction(t.Context(), nil) require.NoError(t, err) require.Equal(t, Successful, code) tests.AssertLogCountEventually(t, observedLogs, "Node sent transaction", 2) @@ -127,7 +127,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { }) t.Run("Context expired before collecting sufficient results", func(t *testing.T) { - testContext, testCancel := context.WithCancel(tests.Context(t)) + testContext, testCancel := context.WithCancel(t.Context()) defer testCancel() mainNode := newNode(t, nil, func(_ mock.Arguments) { @@ -140,14 +140,14 @@ func TestTransactionSender_SendTransaction(t *testing.T) { _, txSender := newTestTransactionSender(t, RandomID(), lggr, []Node[ID, TestSendTxRPCClient]{mainNode}, nil) - requestContext, cancel := context.WithCancel(tests.Context(t)) + requestContext, cancel := context.WithCancel(t.Context()) cancel() _, _, err := txSender.SendTransaction(requestContext, nil) require.EqualError(t, err, "context canceled") }) t.Run("Context cancelled while sending results does not cause invariant violation", func(t *testing.T) { - requestContext, cancel := context.WithCancel(tests.Context(t)) + requestContext, cancel := context.WithCancel(t.Context()) mainNode := newNode(t, nil, func(_ mock.Arguments) { cancel() }) @@ -159,7 +159,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { lggr, makeMockMultiNodeMetrics(t), NodeSelectionModeRoundRobin, 0, []Node[ID, TestSendTxRPCClient]{mainNode}, nil, chainID, "chainFamily", 0)} txSender := NewTransactionSender[any, any, ID, TestSendTxRPCClient](lggr, chainID, mn.chainFamily, mn.MultiNode, makeMockTxSenderMetrics(t), func(err error) SendTxReturnCode { return 0 }, tests.TestInterval) - require.NoError(t, txSender.Start(tests.Context(t))) + require.NoError(t, txSender.Start(t.Context())) _, _, err := txSender.SendTransaction(requestContext, nil) require.EqualError(t, err, "context canceled") @@ -173,7 +173,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { fastNode := newNode(t, expectedError, nil) // hold reply from the node till end of the test - testContext, testCancel := context.WithCancel(tests.Context(t)) + testContext, testCancel := context.WithCancel(t.Context()) defer testCancel() slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) { // block caller til end of the test @@ -183,14 +183,14 @@ func TestTransactionSender_SendTransaction(t *testing.T) { lggr := logger.Test(t) _, txSender := newTestTransactionSender(t, chainID, lggr, []Node[ID, TestSendTxRPCClient]{fastNode, slowNode}, nil) - _, _, err := txSender.SendTransaction(tests.Context(t), nil) + _, _, err := txSender.SendTransaction(t.Context(), nil) require.EqualError(t, err, expectedError.Error()) }) t.Run("Returns success without waiting for the rest of the nodes", func(t *testing.T) { chainID := RandomID() fastNode := newNode(t, nil, nil) // hold reply from the node till end of the test - testContext, 
testCancel := context.WithCancel(tests.Context(t)) + testContext, testCancel := context.WithCancel(t.Context()) defer testCancel() slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) { // block caller til end of the test @@ -205,7 +205,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{fastNode, slowNode}, []SendOnlyNode[ID, TestSendTxRPCClient]{slowSendOnly}) - _, code, err := txSender.SendTransaction(tests.Context(t), nil) + _, code, err := txSender.SendTransaction(t.Context(), nil) require.NoError(t, err) require.Equal(t, Successful, code) }) @@ -214,7 +214,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { fastNode := newNode(t, nil, nil) fastNode.On("ConfiguredChainID").Return(chainID).Maybe() // hold reply from the node till end of the test - testContext, testCancel := context.WithCancel(tests.Context(t)) + testContext, testCancel := context.WithCancel(t.Context()) defer testCancel() slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) { // block caller til end of the test @@ -233,16 +233,16 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{fastNode, slowNode}, []SendOnlyNode[ID, TestSendTxRPCClient]{slowSendOnly}) - require.NoError(t, mn.Start(tests.Context(t))) + require.NoError(t, mn.Start(t.Context())) require.NoError(t, mn.Close()) - _, _, err := txSender.SendTransaction(tests.Context(t), nil) + _, _, err := txSender.SendTransaction(t.Context(), nil) require.EqualError(t, err, "service is stopped") }) t.Run("Fails when closed", func(t *testing.T) { chainID := RandomID() fastNode := newNode(t, nil, nil) // hold reply from the node till end of the test - testContext, testCancel := context.WithCancel(tests.Context(t)) + testContext, testCancel := context.WithCancel(t.Context()) defer testCancel() slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) { // block caller til end of the test @@ -256,7 +256,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { var txSender *TransactionSender[any, any, ID, TestSendTxRPCClient] t.Cleanup(func() { // after txSender.Close() - _, _, err := txSender.SendTransaction(tests.Context(t), nil) + _, _, err := txSender.SendTransaction(t.Context(), nil) assert.EqualError(t, err, "TransactionSender not started") }) @@ -275,7 +275,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{primary}, []SendOnlyNode[ID, TestSendTxRPCClient]{sendOnly}) - _, _, err := txSender.SendTransaction(tests.Context(t), nil) + _, _, err := txSender.SendTransaction(t.Context(), nil) assert.EqualError(t, err, ErrNodeError.Error()) }) @@ -294,7 +294,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { []Node[ID, TestSendTxRPCClient]{mainNode, unhealthyNode}, []SendOnlyNode[ID, TestSendTxRPCClient]{unhealthySendOnlyNode}) - _, code, err := txSender.SendTransaction(tests.Context(t), nil) + _, code, err := txSender.SendTransaction(t.Context(), nil) require.NoError(t, err) require.Equal(t, Successful, code) }) @@ -304,7 +304,7 @@ func TestTransactionSender_SendTransaction(t *testing.T) { fastNode := newNode(t, expectedError, nil) // hold reply from the node till SendTransaction returns result - sendTxContext, sendTxCancel := context.WithCancel(tests.Context(t)) + sendTxContext, sendTxCancel := context.WithCancel(t.Context()) slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) { <-sendTxContext.Done() })
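As a closing illustration (outside the patch): how the two thresholds and the poll interval combine for a recovering node. The `nodePoolSettings` struct and its values below are invented for the sketch; in this change `MultiNodeConfig.PollSuccessThreshold()` is hard-coded to 0, and the config.go comment indicates product repos supply their own value through their NodePoolConfig.

```go
package main

import (
	"fmt"
	"time"
)

// nodePoolSettings is a hypothetical grouping of the poll-related knobs that the
// NodeConfig interface now exposes; it is not a type from this repository.
type nodePoolSettings struct {
	PollFailureThreshold uint32        // net failed polls (with decay) before a node is marked unreachable
	PollSuccessThreshold uint32        // probe polls that must succeed before a redialled node rejoins the pool
	PollInterval         time.Duration // spacing between health-check polls and recovery probes
}

func main() {
	cfg := nodePoolSettings{
		PollFailureThreshold: 5,
		PollSuccessThreshold: 2, // 0 disables the recovery probe, matching the pre-patch behaviour
		PollInterval:         10 * time.Second,
	}
	// probeUntilStable waits PollInterval before each probe, so a node coming back
	// from Unreachable spends at least this long proving itself before going Alive.
	minProbeWindow := time.Duration(cfg.PollSuccessThreshold) * cfg.PollInterval
	fmt.Printf("minimum probe window before rejoining the pool: %s\n", minProbeWindow)
}
```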