diff --git a/CHANGELOG.md b/CHANGELOG.md index 72394c45e0..0d7a48db2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ All notable changes to this project will be documented in this file. From versio ### Fixed - Fix unnecessary connection pool flushes during schema cache reloading by @mkleczek in #4645 +- Fix race condition in pool_available metric causing negative values during network instability by @mkleczek in #4622 +- Limit concurrent schema cache loads by @mkleczek in #4643 ## [14.9] - 2026-04-10 diff --git a/postgrest.cabal b/postgrest.cabal index 2f100b3e92..807e104383 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -162,6 +162,7 @@ library , stm-hamt >= 1.2 && < 2 , focus >= 1.0 && < 2 , some >= 1.0.4.1 && < 2 + , uuid >= 1.3 && < 2 -- -fno-spec-constr may help keep compile time memory use in check, -- see https://gitlab.haskell.org/ghc/ghc/issues/16017#note_219304 -- -optP-Wno-nonportable-include-path diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index fdaac1a908..b2b0b8c01e 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -1,7 +1,9 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE TypeApplications #-} module PostgREST.AppState ( AppState @@ -33,7 +35,8 @@ import qualified Data.ByteString.Char8 as BS import Data.Either.Combinators (whenLeft) import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Config as SQL -import qualified Hasql.Session as SQL +import qualified Hasql.Session as SQL hiding (statement) +import qualified Hasql.Transaction as SQL hiding (sql) import qualified Hasql.Transaction.Sessions as SQL import qualified Network.HTTP.Types.Status as HTTP import qualified PostgREST.Auth.JwtCache as JwtCache @@ -63,11 +66,17 @@ import 
PostgREST.Config.Database (queryDbSettings, import PostgREST.Config.PgVersion (PgVersion (..), minimumPgVersion) import PostgREST.Debounce (makeDebouncer) +import PostgREST.Metrics (MetricsState (connTrack)) import PostgREST.SchemaCache (SchemaCache (..), querySchemaCache, showSummary) import PostgREST.SchemaCache.Identifiers (quoteQi) +import qualified Hasql.Decoders as HD +import qualified Hasql.Encoders as HE +import qualified Hasql.Statement as SQL +import NeatInterpolation (trimming) + import Protolude data AppState = AppState @@ -303,7 +312,7 @@ getObserver = stateObserver -- + Because connections cache the pg catalog(see #2620) -- + For rapid recovery. Otherwise, the pool idle or lifetime timeout would have to be reached for new healthy connections to be acquired. retryingSchemaCacheLoad :: AppState -> IO () -retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId} = +retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId, stateMetrics} = void $ retrying retryPolicy shouldRetry (\RetryStatus{rsIterNumber, rsPreviousDelay} -> do when (rsIterNumber > 0) $ do let delay = fromMaybe 0 rsPreviousDelay `div` oneSecondInUs @@ -342,8 +351,22 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea qSchemaCache :: IO (Maybe SchemaCache) qSchemaCache = do conf@AppConfig{..} <- getConfig appState + -- Throttle concurrent schema cache loads, guarded by advisory locks. + -- This is to prevent thundering herd problem on startup or when many PostgREST + -- instances receive "reload schema" notifications at the same time + -- See get_lock_sql for details of the algorithm. + -- Here we calculate the number of open connections passed to the query. 
+ Metrics.ConnStats connected inUse <- Metrics.connectionCounts $ connTrack stateMetrics + -- Determine whether schema cache loading will create a new session + let + -- if all connections in use but pool not full - schema cache loading will create session + scLoadingSessions = if connected <= inUse && inUse < configDbPoolSize then 1 else 0 + withTxLock = SQL.statement + (fromIntegral $ connected + scLoadingSessions) + (SQL.Statement get_lock_sql get_lock_params HD.noResult configDbPreparedStatements) + (resultTime, result) <- - timeItT $ usePool appState (SQL.transactionNoRetry SQL.ReadCommitted SQL.Read $ querySchemaCache conf) + timeItT $ usePool appState (SQL.transactionNoRetry SQL.ReadCommitted SQL.Read $ withTxLock *> querySchemaCache conf) case result of Left e -> do markSchemaCachePending appState @@ -365,6 +388,43 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea observer $ SchemaCacheLoadedObs loadTime summary markSchemaCacheLoaded appState return $ Just sCache + where + -- Recursive query that tries acquiring locks in order + -- and waits for randomly selected lock if no attempt succeeded. + -- It has a single parameter: this node open connection count. + -- It is used to estimate the number of nodes + -- by counting the number of active sessions for current session_user + -- and dividing it by this node open connections. + -- Assuming load is uniform among cluster nodes, all should have + -- statistically the same number of open connections. 
+ -- Once the number of nodes is known we calculate the number + -- of locks as ceil(log(2, number_of_nodes)) + get_lock_sql = encodeUtf8 [trimming| + WITH RECURSIVE attempts AS ( + SELECT 1 AS lock_number, pg_try_advisory_xact_lock(lock_id, 1) AS success FROM parameters + UNION ALL + SELECT next_lock_number AS lock_number, pg_try_advisory_xact_lock(lock_id, next_lock_number) AS success + FROM + parameters CROSS JOIN LATERAL ( + SELECT lock_number + 1 AS next_lock_number FROM attempts + WHERE NOT success AND lock_number < locks_count + ORDER BY lock_number DESC + LIMIT 1 + ) AS previous_attempt + ), + counts AS ( + SELECT round(log(2, round(count(*)::double precision/$$1)::numeric))::int AS locks_count + FROM + pg_stat_activity WHERE usename = SESSION_USER + ), + parameters AS ( + SELECT locks_count, 50168275 AS lock_id FROM counts WHERE locks_count > 0 + ) + SELECT pg_advisory_xact_lock(lock_id, floor(random() * locks_count)::int + 1) + FROM + parameters WHERE NOT EXISTS (SELECT 1 FROM attempts WHERE success) |] + + get_lock_params = HE.param (HE.nonNullable HE.int4) shouldRetry :: RetryStatus -> (Maybe PgVersion, Maybe SchemaCache) -> IO Bool shouldRetry _ (pgVer, sCache) = do diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs index 75db34b4d6..4a3ad91373 100644 --- a/src/PostgREST/Metrics.hs +++ b/src/PostgREST/Metrics.hs @@ -5,7 +5,10 @@ Description : Metrics based on the Observation module. See Observation.hs. -} module PostgREST.Metrics ( init + , ConnTrack + , ConnStats (..) , MetricsState (..) 
+ , connectionCounts , observationMetrics , metricsToText ) where @@ -19,13 +22,18 @@ import qualified Prometheus.Metric.GHC as PMG import PostgREST.Observation - -import Protolude +import Control.Arrow ((&&&)) +import Data.Bitraversable (bisequenceA) +import Data.Tuple.Extra (both) +import Data.UUID (UUID) +import qualified Focus +import Protolude +import qualified StmHamt.SizedHamt as SH data MetricsState = MetricsState { poolTimeouts :: Counter, - poolAvailable :: Gauge, + connTrack :: ConnTrack, poolWaiting :: Gauge, poolMaxSize :: Gauge, schemaCacheLoads :: Vector Label1 Counter, @@ -40,7 +48,7 @@ init configDbPoolSize = do whenM getRTSStatsEnabled $ void $ register PMG.ghcMetrics metricState <- MetricsState <$> register (counter (Info "pgrst_db_pool_timeouts_total" "The total number of pool connection timeouts")) <*> - register (gauge (Info "pgrst_db_pool_available" "Available connections in the pool")) <*> + register (Metric ((identity &&& dbPoolAvailable) <$> connectionTracker)) <*> register (gauge (Info "pgrst_db_pool_waiting" "Requests waiting to acquire a pool connection")) <*> register (gauge (Info "pgrst_db_pool_max" "Max pool connections")) <*> register (vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded")) <*> @@ -50,20 +58,28 @@ init configDbPoolSize = do register (counter (Info "pgrst_jwt_cache_evictions_total" "The total number of JWT cache evictions")) setGauge (poolMaxSize metricState) (fromIntegral configDbPoolSize) pure metricState + where + dbPoolAvailable = (pure . noLabelsGroup (Info "pgrst_db_pool_available" "Available connections in the pool") GaugeType . calcAvailable <$>) . connectionCounts + where + calcAvailable = liftA2 (-) connected inUse + toSample name labels = Sample name labels . encodeUtf8 . show + noLabelsGroup info sampleType = SampleGroup info sampleType . pure . 
toSample (metricName info) mempty -- Only some observations are used as metrics observationMetrics :: MetricsState -> ObservationHandler observationMetrics MetricsState{..} obs = case obs of PoolAcqTimeoutObs -> do incCounter poolTimeouts - (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of - SQL.ReadyForUseConnectionStatus _ -> do - incGauge poolAvailable - SQL.InUseConnectionStatus -> do - decGauge poolAvailable - SQL.TerminatedConnectionStatus _ -> do - decGauge poolAvailable - SQL.ConnectingConnectionStatus -> pure () + -- Handle pool observations with connection tracking + -- this is necessary because it is not possible + -- to accurately maintain open/in use connection counts + -- statelessly based only on pool observation events. + -- The reason is that hasql-pool emits TerminatedConnectionStatus + -- both for connections successfully established and failed when connecting. + -- When receiving TerminatedConnectionStatus we have to find out + -- if we can decrement established connection count. To do that we have to track + -- established connections. 
+ (HasqlPoolObs sqlObs) -> trackConnections connTrack sqlObs PoolRequest -> incGauge poolWaiting PoolRequestFullfilled -> @@ -81,3 +97,28 @@ observationMetrics MetricsState{..} obs = case obs of metricsToText :: IO LBS.ByteString metricsToText = exportMetricsAsText + +data ConnStats = ConnStats { + connected :: Int, + inUse :: Int +} deriving (Eq, Show) + +data ConnTrack = ConnTrack { connTrackConnected :: SH.SizedHamt UUID, connTrackInUse :: SH.SizedHamt UUID } + +connectionTracker :: IO ConnTrack +connectionTracker = ConnTrack <$> SH.newIO <*> SH.newIO + +trackConnections :: ConnTrack -> SQL.Observation -> IO () +trackConnections ConnTrack{..} (SQL.ConnectionObservation uuid status) = case status of + SQL.ReadyForUseConnectionStatus _ -> atomically $ + SH.insert identity uuid connTrackConnected *> + SH.focus Focus.delete identity uuid connTrackInUse + SQL.TerminatedConnectionStatus _ -> atomically $ + SH.focus Focus.delete identity uuid connTrackConnected *> + SH.focus Focus.delete identity uuid connTrackInUse + SQL.InUseConnectionStatus -> atomically $ + SH.insert identity uuid connTrackInUse + _ -> mempty + +connectionCounts :: ConnTrack -> IO ConnStats +connectionCounts = atomically . fmap (uncurry ConnStats) . bisequenceA . both SH.size . (connTrackConnected &&& connTrackInUse) diff --git a/test/io/test_io.py b/test/io/test_io.py index f2e255bca8..d8b78e58d9 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1,5 +1,6 @@ "Unit tests for Input/Ouput of PostgREST seen as a black box." 
+import contextlib import os import re import signal @@ -19,6 +20,7 @@ sleep_until_postgrest_full_reload, sleep_until_postgrest_scache_reload, wait_until_exit, + wait_until_status_code, ) @@ -1252,6 +1254,93 @@ def test_schema_cache_concurrent_notifications(slow_schema_cache_env): assert response.status_code == 200 +@pytest.mark.parametrize( + "instance_count, expected_concurrency", [(2, 2), (4, 3), (6, 4), (8, 4), (16, 5)] +) +def test_schema_cache_reload_throttled_with_advisory_locks( + instance_count, expected_concurrency, slow_schema_cache_env +): + "schema cache reloads should be throttled across instances" + + internal_sleep_ms = int( + slow_schema_cache_env["PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP"] + ) + lock_wait_threshold_ms = internal_sleep_ms * 2 + query_log_pattern = re.compile(r"Schema cache queried in ([\d.]+) milliseconds") + + def read_available_output_lines(postgrest): + try: + output = postgrest.process.stdout.read() + except BlockingIOError: + return [] + + if not output: + return [] + return output.decode().splitlines() + + with contextlib.ExitStack() as stack: + instances = [ + stack.enter_context( + run( + env=slow_schema_cache_env, + wait_for_readiness=False, + wait_max_seconds=10, + ) + ) + for _ in range(instance_count) + ] + + for postgrest in instances: + wait_until_status_code( + postgrest.admin.baseurl + "/ready", max_seconds=10, status_code=200 + ) + + # Drop startup logs so only reload logs are parsed. + for postgrest in instances: + read_available_output_lines(postgrest) + + response = instances[0].session.get("/rpc/notify_pgrst") + assert response.status_code == 204 + + # Wait long enough for the lock-throttled cache reloads to finish. 
+ time.sleep((internal_sleep_ms / 1000) * 2) + + reload_durations_ms = [] + for postgrest in instances: + output_lines = [] + for _ in range(instance_count * 2): + output_lines.extend(read_available_output_lines(postgrest)) + if any(query_log_pattern.search(line) for line in output_lines): + break + time.sleep(0.2) + + durations = [] + for line in output_lines: + match = query_log_pattern.search(line) + if match: + durations.append(float(match.group(1))) + + assert durations + reload_durations_ms.append(max(durations)) + + assert len(reload_durations_ms) == instance_count + + # expected_concurrency instances should have + # reload_durations_ms <= lock_wait_threshold_ms + # the rest should wait + assert ( + instance_count + - len( + [ + duration + for duration in reload_durations_ms + if duration > lock_wait_threshold_ms + ] + ) + == expected_concurrency + ) + + def test_schema_cache_query_sleep_logs(defaultenv): """Schema cache sleep should be reflected in the logged query duration.""" @@ -1945,7 +2034,7 @@ def test_requests_with_resource_embedding_wait_for_schema_cache_reload(defaulten env = { **defaultenv, "PGRST_DB_POOL": "2", - "PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5100", + "PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5200", } with run(env=env, wait_max_seconds=30) as postgrest: diff --git a/test/observability/Observation/MetricsSpec.hs b/test/observability/Observation/MetricsSpec.hs index 524e0c1018..6057756024 100644 --- a/test/observability/Observation/MetricsSpec.hs +++ b/test/observability/Observation/MetricsSpec.hs @@ -6,17 +6,21 @@ module Observation.MetricsSpec where -import Data.List (lookup) -import Network.Wai (Application) +import Data.List (lookup) +import qualified Hasql.Pool.Observation as SQL +import Network.Wai (Application) import ObsHelper -import qualified PostgREST.AppState as AppState -import PostgREST.Config (AppConfig (configDbSchemas)) -import qualified PostgREST.Metrics as Metrics +import qualified 
PostgREST.AppState as AppState +import PostgREST.Config (AppConfig (configDbSchemas)) +import PostgREST.Metrics (ConnStats (..), + MetricsState (..), + connectionCounts) import PostgREST.Observation -import Prometheus (getCounter, getVectorWith) -import Protolude -import Test.Hspec (SpecWith, describe, it) -import Test.Hspec.Wai (getState) +import Prometheus (getCounter, getVectorWith) +import Test.Hspec (SpecWith, describe, it) +import Test.Hspec.Wai (getState) + +import Protolude spec :: SpecWith (SpecState, Application) spec = describe "Server started with metrics enabled" $ do @@ -71,9 +75,33 @@ spec = describe "Server started with metrics enabled" $ do -- (there should be none but we need to verify that) threadDelay $ 1 * sec + it "Should track in use connections" $ do + SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState + let waitFor = waitForObs specObsChan + + liftIO $ checkState' metrics [ + -- we expect in use connections to be the same once finished + inUseConnections (+ 0) + ] $ do + signal <- newEmptyMVar + -- make sure waiting thread is signaled + bracket_ (pure ()) (putMVar signal ()) $ + -- expecting one more connection in use + checkState' metrics [ + inUseConnections (+ 1) + ] $ do + -- start a thread hanging on a single connection until signaled + void $ forkIO $ void $ AppState.usePool appState $ liftIO (readMVar signal) + -- main thread waits for ConnectionObservation with InUseConnectionStatus + -- after which used connections count should be incremented + waitFor (1 * sec) "InUseConnectionStatus" $ \x -> [ o | o@(HasqlPoolObs (SQL.ConnectionObservation _ SQL.InUseConnectionStatus)) <- pure x] + + -- hanging thread was signaled and should return the connection + waitFor (1 * sec) "ReadyForUseConnectionStatus" $ \x -> [ o | o@(HasqlPoolObs (SQL.ConnectionObservation _ (SQL.ReadyForUseConnectionStatus _))) <- pure x] where -- prometheus-client api to handle vectors is convoluted schemaCacheLoads label = 
expectField @"schemaCacheLoads" $ fmap (maybe (0::Int) round . lookup label) . (`getVectorWith` getCounter) + inUseConnections = expectField @"connTrack" ((inUse <$>) . connectionCounts) sec = 1000000