diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index d33794006..37d467f8a 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -712,17 +712,24 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq pure (clnt, sess) newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv = - tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case - Right sess -> do - atomically $ putTMVar (sessionVar rv) (Right sess) - pure $ Right sess - Left e -> do - atomically $ do - unless (serverHostError e) $ do - removeSessVar rv destSrv prs - TM.delete destSess smpProxiedRelays - putTMVar (sessionVar rv) (Left e) - pure $ Left e + -- As in newProtocolClient: if the connect to the destination relay is interrupted before the var + -- is filled, release waiters with a BROKER TIMEOUT error and drop the empty proxied-relay var. 
+ ExceptT $ runExceptT connect `E.onException` clearOnInterrupt + where + connect = + tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case + Right sess -> do + atomically $ putTMVar (sessionVar rv) (Right sess) + pure $ Right sess + Left e -> do + atomically $ do + unless (serverHostError e) $ do + removeSessVar rv destSrv prs + TM.delete destSess smpProxiedRelays + putTMVar (sessionVar rv) (Left e) + pure $ Left e + clearOnInterrupt = + liftIO $ atomically $ clearSessVarOnInterrupt prs destSrv rv (Left (BROKER (B.unpack $ strEncode destSrv) TIMEOUT)) waitForProxiedRelay :: SMPTransportSession -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) waitForProxiedRelay (_, srv, _) rv = do NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c @@ -922,22 +929,30 @@ newProtocolClient :: ClientVar msg -> AM (Client msg) newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = - tryAllErrors (connectClient v) >>= \case - Right client -> do - logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" - atomically $ putTMVar (sessionVar v) (Right client) - liftIO $ nonBlockingWriteTBQueue (subQ c) ("", "", AEvt SAENone $ hostEvent CONNECT client) - pure client - Left e -> do - ei <- asks $ persistErrorInterval . config - if ei == 0 - then atomically $ do - removeSessVar v tSess clients - putTMVar (sessionVar v) (Left (e, Nothing)) - else do - ts <- addUTCTime ei <$> liftIO getCurrentTime - atomically $ putTMVar (sessionVar v) (Left (e, Just ts)) - throwE e -- signal error to caller + -- connectClient can be interrupted by an async exception, and tryAllErrors rethrows those, so if the thread is + -- killed before the session var is filled below, the var would be left empty in the map forever. 
+ ExceptT $ runExceptT connect `E.onException` clearOnInterrupt + where + connect = + tryAllErrors (connectClient v) >>= \case + Right client -> do + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" + atomically $ putTMVar (sessionVar v) (Right client) + liftIO $ nonBlockingWriteTBQueue (subQ c) ("", "", AEvt SAENone $ hostEvent CONNECT client) + pure client + Left e -> do + ei <- asks $ persistErrorInterval . config + if ei == 0 + then atomically $ do + removeSessVar v tSess clients + putTMVar (sessionVar v) (Left (e, Nothing)) + else do + ts <- addUTCTime ei <$> liftIO getCurrentTime + atomically $ putTMVar (sessionVar v) (Left (e, Just ts)) + throwE e -- signal error to caller + clearOnInterrupt = + liftIO $ atomically $ clearSessVarOnInterrupt clients tSess v (Left (BROKER (B.unpack $ strEncode srv) TIMEOUT, Nothing)) hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> AEvent 'AENone) -> Client msg -> AEvent 'AENone hostEvent event = hostEvent' event . protocolClient diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 76b2a7cf9..f290f463f 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -226,7 +226,9 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke Nothing -> throwE PCEResponseTimeout newSMPClient :: SMPClientVar -> IO (Either SMPClientError (OwnServer, SMPClient)) - newSMPClient v = do + -- if the thread is killed before the session var is filled below, release waiters with an + -- error and drop the empty var instead of leaking it. 
+ newSMPClient v = (`E.onException` clearOnInterrupt) $ do r <- connectClient ca srv v `E.catches` clientHandlers case r of Right smp -> do @@ -249,6 +251,8 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke atomically $ putTMVar (sessionVar v) (Left (e, Just ts)) reconnectClient ca srv pure $ Left e + where + clearOnInterrupt = atomically $ clearSessVarOnInterrupt smpClients srv v (Left (PCEResponseTimeout, Nothing)) isOwnServer :: SMPClientAgent p -> SMPServer -> OwnServer isOwnServer SMPClientAgent {agentCfg} ProtocolServer {host} = diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 02429e910..a5542b086 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -648,12 +648,15 @@ getOrCreatePushWorker :: NtfPushServer -> (Maybe T.Text, PushProvider) -> OwnSer getOrCreatePushWorker s@NtfPushServer {pushWorkers, pushWorkerSeq, pushQSize} key@(srvHost_, _) isOwn = do ts <- liftIO getCurrentTime atomically (getSessVar pushWorkerSeq key pushWorkers ts) >>= \case - Left v -> do + -- drop the empty worker var if the create is interrupted, else it leaks and blocks all lookups. NOTE(review): this only removes the var from the map; a thread already blocked in readTMVar on it is not released (there is no PushWorker value to fill it with) - confirm this is acceptable + Left v -> (`onException` clearOnInterrupt v) $ do q <- liftIO $ newTBQueueIO pushQSize tId <- mkWeakThreadId =<< forkIO (runPushWorker s srvHost_ isOwn q) atomically $ putTMVar (sessionVar v) PushWorker {workerQ = q, workerThreadId = tId} pure q Right v -> workerQ <$> atomically (readTMVar $ sessionVar v) + where + clearOnInterrupt v = atomically $ whenM (isEmptyTMVar $ sessionVar v) $ removeSessVar v key pushWorkers runPushWorker :: NtfPushServer -> Maybe T.Text -> OwnServer -> TBQueue (NtfTknRec, PushNotification) -> M () runPushWorker s srvHost_ isOwn q = forever $ do diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index ff5d7e0a0..f0aaaa943 100644 --- a/src/Simplex/Messaging/Session.hs +++ 
b/src/Simplex/Messaging/Session.hs @@ -6,6 +6,7 @@ module Simplex.Messaging.Session ( SessionVar (..), getSessVar, removeSessVar, + clearSessVarOnInterrupt, tryReadSessVar, ) where @@ -13,7 +14,7 @@ import Control.Concurrent.STM import Data.Time (UTCTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (($>>=)) +import Simplex.Messaging.Util (whenM, ($>>=)) data SessionVar a = SessionVar { sessionVar :: TMVar a, @@ -38,5 +39,13 @@ removeSessVar v sessKey vs = Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs _ -> pure () +-- | For an exception handler around the action that fills a freshly created session var: +-- if the var is still empty (the action was interrupted before filling it, e.g. by an async +-- exception during connect), fill it with @onInterrupt@ to release waiters, then remove it from +-- the map (via 'removeSessVar') so the next request creates a fresh session. A no-op if the action already filled it. +clearSessVarOnInterrupt :: Ord k => TMap k (SessionVar a) -> k -> SessionVar a -> a -> STM () +clearSessVarOnInterrupt vs sessKey v onInterrupt = + whenM (tryPutTMVar (sessionVar v) onInterrupt) $ removeSessVar v sessKey vs + tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 580f4e9b0..c274cbf32 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -50,6 +50,13 @@ utilTests = do runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) it "should return no errors as Right" $ runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors") + -- tryAllErrors rethrows asynchronous exceptions (it uses UnliftIO.catch). Any recovery placed + -- after `tryAllErrors action` - e.g. 
putTMVar to fill a SessionVar - is therefore SKIPPED when + -- the thread is killed mid-action. Unlike tryAllOwnErrors, it also rethrows the overflow exceptions. + it "should rethrow ThreadKilled" $ + runExceptT (tryAllErrors $ throwAsync ThreadKilled) `shouldThrow` (\e -> e == ThreadKilled) + it "should rethrow StackOverflow" $ + runExceptT (tryAllErrors $ throwAsync StackOverflow) `shouldThrow` (\e -> e == StackOverflow) describe "catchAllErrors" $ do it "should catch ExceptT error" $ runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\"" diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index d043fd3c8..85bd08a1b 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -25,7 +25,7 @@ import Network.Socket import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) -import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) +import Simplex.Messaging.Client (NetworkConfig (..), NetworkTimeout (..), ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -339,6 +339,16 @@ proxyCfgJ2QS = \case SQSMemory -> journalCfg (proxyCfgMS $ ASType SQSMemory SMSJournal) testStoreLogFile2 testStoreMsgsDir2 SQSPostgres -> journalCfgDB (proxyCfgMS $ ASType SQSPostgres SMSJournal) testStoreDBOpts2 testStoreMsgsDir2 +-- Proxy config with a short relay-connection timeout (tcpConnectTimeout set to 4_000000 for both the background and +-- interactive cases), to bound how long a failing proxy->relay connection attempt blocks in the relay reconnection tests. 
+proxyCfgShortTimeout :: AServerConfig +proxyCfgShortTimeout = + updateCfg proxyCfg $ \cfg' -> + let aCfg = smpAgentCfg cfg' + cCfg = smpCfg aCfg + nt = NetworkTimeout {backgroundTimeout = 4_000000, interactiveTimeout = 4_000000} + in cfg' {smpAgentCfg = aCfg {smpCfg = cCfg {networkConfig = (networkConfig cCfg) {tcpConnectTimeout = nt}}}} + proxyVRangeV8 :: VersionRangeSMP proxyVRangeV8 = mkVersionRange minServerSMPRelayVersion sendingProxySMPVersion @@ -383,6 +393,15 @@ serverBracket process afterProcess f = do Nothing -> error $ "server did not " <> s _ -> pure () +-- A TCP server that accepts connections but never performs a TLS handshake, so a client +-- connecting to it stays blocked in the TLS handshake until its connection timeout. NOTE(review): each connection handler only sleeps in threadDelay maxBound; this assumes that delay never returns early on the platforms the tests run on - confirm. +withStallingServerOn :: HasCallStack => ServiceName -> IO a -> IO a +withStallingServerOn port action = + serverBracket + (\started -> runLocalTCPServer started port (\_ -> threadDelay maxBound)) + (pure ()) + (const action) + withSmpServerOn :: HasCallStack => (ASrvTransport, AStoreType) -> ServiceName -> IO a -> IO a withSmpServerOn ps port' = withSmpServerThreadOn ps port' . 
const diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 0d8ccdf89..88525eac0 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -58,6 +58,14 @@ smpProxyTests = do describe "server configuration" $ do it "refuses proxy handshake unless enabled" testNoProxy it "checks basic auth in proxy requests" testProxyAuth + describe "relay reconnection" $ do + it "recovers when unresponsive relay restarts (control, no disconnect)" $ \_ -> + testProxyRecoversWithoutDisconnect + it "reconnects to relay after sender disconnects mid-connection" $ \_ -> + testProxyReconnectAfterRelayRestart + describe "agent client reconnection" $ do + it "reconnects after a connect is cancelled mid-flight" $ \_ -> + testAgentClientReconnectAfterCancel describe "proxy requests" $ do describe "bad relay URIs" $ do xit "host not resolved" todo @@ -447,6 +455,89 @@ testProxyAuth msType = do where proxyCfgAuth = updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {newQueueBasicAuth = Just "correct"} +-- Connect a sender client to the proxy and request a relay session to testSMPServer2 (PRXY). +-- On success the reply is PKEY; otherwise it is the proxy error for the relay connection, delivered as a Right ERR broker message (the control test matches Right (ERR (PROXY (BROKER _)))). +requestRelaySession :: IO (Either SMP.ErrorType SMP.BrokerMsg) +requestRelaySession = + testSMPClient_ "localhost" testPort proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> + (\(_, _, reply) -> reply) <$> sendRecv th (Nothing, "1", NoEntity, SMP.PRXY testSMPServer2 Nothing) + +-- Shared "phase 2" of the reconnection tests: start a healthy relay, confirm it is reachable +-- directly (PING, not via the proxy) so a proxy failure can only mean the proxy didn't reconnect, +-- let any stored connection error expire, then require the proxy to establish the session (PKEY). 
+requireProxyReconnect :: IO () +requireProxyReconnect = + withSmpServerConfigOn (transport @TLS) proxyCfgJ2 testPort2 $ \_ -> do + testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do + (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PING) + reply `shouldBe` Right SMP.PONG + threadDelay 1500000 -- > persistErrorInterval (1s), so the stored connection error has expired + requestRelaySession >>= \case + Right SMP.PKEY {} -> pure () + reply -> expectationFailure $ "proxy failed to reach the healthy relay; expected PKEY, got: " <> show reply + +-- Control: same stalling relay and proxy config as the bug test (testProxyReconnectAfterRelayRestart), but the sender stays connected. +-- The connect fails by timing out (storing a Left error that self-heals via persistErrorInterval), +-- so once a healthy relay is running the proxy reconnects. This proves the stalling relay alone +-- does not cause the permanent failure - only the mid-connection disconnect does. +testProxyRecoversWithoutDisconnect :: IO () +testProxyRecoversWithoutDisconnect = + withSmpServerConfigOn (transport @TLS) proxyCfgShortTimeout testPort $ \_ -> do + withStallingServerOn testPort2 $ + requestRelaySession >>= \case + Right (SMP.ERR (SMP.PROXY (SMP.BROKER _))) -> pure () + reply -> expectationFailure $ "expected a proxy broker error from the unresponsive relay, got: " <> show reply + requireProxyReconnect + +-- Reproduces the production bug: an SMP proxy permanently fails to reconnect to a destination +-- relay after the relay restarts (logs: repeated PCEResponseTimeout). +-- +-- A PRXY request makes the proxy worker (forked via forkClient, registered in the sender's +-- endThreads) insert an empty SessionVar into smpClients and then block in connectClient. If the +-- sender disconnects while that connect is in flight, clientDisconnected kills the worker; +-- clientHandlers re-throws the async exception, so the SessionVar is never filled. 
Nothing removes +-- an empty SessionVar, so every later request waits the connection timeout on it - PROXY (BROKER +-- TIMEOUT) - forever, even once the relay is healthy again. +-- +-- The stalling relay (accepts TCP, never completes TLS) holds the connect open long enough to +-- interleave the disconnect. Phase 2 (requireProxyReconnect) is identical to the control above; +-- the only difference is this disconnect. +testProxyReconnectAfterRelayRestart :: IO () +testProxyReconnectAfterRelayRestart = + withSmpServerConfigOn (transport @TLS) proxyCfgShortTimeout testPort $ \_ -> do + -- disconnect the sender 1s into the 4s connect to the stalling relay, killing the in-flight worker + withStallingServerOn testPort2 $ + race_ (threadDelay 1000000) requestRelaySession + requireProxyReconnect + +-- Agent-side variant of the same bug (same root cause as the proxy case, but in the messaging agent): getSMPServerClient inserts an +-- empty SessionVar into smpClients, then connects inside newProtocolClient's tryAllErrors, which +-- rethrows async exceptions. If the connecting thread is cancelled mid-connect, putTMVar is +-- skipped and the empty var is left in smpClients, so every later connection to that server times +-- out on it. Phase 1 cancels a connect to a stalling relay; phase 2 requires a fresh connect to a +-- healthy relay to succeed. 
+testAgentClientReconnectAfterCancel :: IO () +testAgentClientReconnectAfterCancel = + withAgent 1 agentCfg agentServersLeak testDB $ \a -> do + withStallingServerOn testPort2 $ do + t <- async $ runExceptT $ A.createConnection a NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + threadDelay 1000000 -- wait 1s (well inside the 4000000 tcpConnectTimeout set below) so the connect to the stalling relay is in flight, then kill it + cancel t + withSmpServerConfigOn (transport @TLS) cfgJ2 testPort2 $ \_ -> do + testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do + (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PING) + reply `shouldBe` Right SMP.PONG -- the relay is up and reachable, so a timeout can only be the poisoned var + r <- timeout 8000000 $ runExceptT $ A.createConnection a NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + case r of + Just (Right _) -> pure () + _ -> expectationFailure $ "agent failed to connect after a cancelled connect; got: " <> show r + where + agentServersLeak = + initAgentServers + { smp = userServers [testSMPServer2], + netCfg = (netCfg initAgentServers) {tcpConnectTimeout = NetworkTimeout 4000000 4000000} + } + todo :: AStoreType -> IO () todo _ = fail "TODO"