From d2b997521a5840de2609b0235f5b655a4aa0e1f3 Mon Sep 17 00:00:00 2001 From: Rockford Lhotka Date: Thu, 7 May 2026 15:46:47 -0500 Subject: [PATCH] LLM gateway: per-tier concurrency cap (phase 1 of #352) Adds the gateway core: a singleton ILlmGateway with per-tier SemaphoreSlim instances that all LLM calls flow through. Bursty parallel callers cannot overwhelm a tier; pending waiters cancel automatically when their ct fires, which is how user work effectively preempts dream-cycle work without an explicit priority queue. Phase 1 scope per design/llm-gateway.md: - LlmGateway singleton with per-tier semaphores (Low/Balanced/High) - LlmGatewayOptions with configurable concurrency caps (defaults: Low=8, Balanced=4, High=2) - LlmClient.CallTierAsync routes through the gateway, preserving existing tier-fallback and SDK-quirk-retry behavior - New metric: rockbot.llm.gateway.slot_wait.duration - Internal counters (Pending, InFlight) surfaced for diagnostics and tests - 8 unit tests covering cap enforcement, ct propagation, cross-tier independence, slot release on exception, and cancellation-while-waiting cleanup Phases 2-4 (retry-on-429, bounded queue, ct audit) land separately. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/RockBot.Host.Abstractions/ILlmClient.cs | 8 +- .../LlmGatewayOptions.cs | 30 +++ src/RockBot.Host/HostDiagnostics.cs | 12 + src/RockBot.Host/ILlmGateway.cs | 27 ++ src/RockBot.Host/LlmClient.cs | 16 +- src/RockBot.Host/LlmGateway.cs | 128 ++++++++++ .../ServiceCollectionExtensions.cs | 2 + tests/RockBot.Host.Tests/LlmGatewayTests.cs | 235 ++++++++++++++++++ 8 files changed, 450 insertions(+), 8 deletions(-) create mode 100644 src/RockBot.Host.Abstractions/LlmGatewayOptions.cs create mode 100644 src/RockBot.Host/ILlmGateway.cs create mode 100644 src/RockBot.Host/LlmGateway.cs create mode 100644 tests/RockBot.Host.Tests/LlmGatewayTests.cs diff --git a/src/RockBot.Host.Abstractions/ILlmClient.cs b/src/RockBot.Host.Abstractions/ILlmClient.cs index e1eb3f3a..6b2ce467 100644 --- a/src/RockBot.Host.Abstractions/ILlmClient.cs +++ b/src/RockBot.Host.Abstractions/ILlmClient.cs @@ -4,9 +4,11 @@ namespace RockBot.Host; /// /// Wrapper around for all LLM calls in an agent process. -/// Adds retry logic for known model-specific SDK quirks. Registered as transient -/// so concurrent callers (user loop, background tasks, dreaming, session evaluation) -/// each get their own instance and never queue behind each other. +/// Adds retry logic for known model-specific SDK quirks and routes every call through +/// the per-tier LlmGateway which caps concurrency and propagates cancellation. +/// +/// Registered as transient so each consumer gets its own instance, but the gateway +/// is a singleton so all consumers share the per-tier concurrency budget. /// /// To avoid starting background LLM work while the user is actively waiting /// for a response, use instead of this interface. diff --git a/src/RockBot.Host.Abstractions/LlmGatewayOptions.cs b/src/RockBot.Host.Abstractions/LlmGatewayOptions.cs new file mode 100644 index 00000000..36b705cb --- /dev/null +++ b/src/RockBot.Host.Abstractions/LlmGatewayOptions.cs @@ -0,0 +1,30 @@ +namespace RockBot.Host; + +/// +/// Options for LlmGateway: the global per-tier concurrency layer that all LLM +/// calls flow through. See design/llm-gateway.md for the full design rationale. +/// +/// +/// Caps are per-process. Across multiple agent processes against the same provider +/// account, total concurrency is the sum. Per-account rate limits ultimately bound the +/// system; the gateway is a per-process governor, not a global one. +/// +public sealed class LlmGatewayOptions +{ + /// + /// Maximum concurrent in-flight LLM calls on the tier. + /// Cheap calls used heavily for batch/extraction work, so a higher cap is appropriate. + /// + public int LowMaxConcurrent { get; set; } = 8; + + /// + /// Maximum concurrent in-flight LLM calls on the tier. + /// + public int BalancedMaxConcurrent { get; set; } = 4; + + /// + /// Maximum concurrent in-flight LLM calls on the tier. + /// Expensive judgment calls; lower cap. + /// + public int HighMaxConcurrent { get; set; } = 2; +} diff --git a/src/RockBot.Host/HostDiagnostics.cs b/src/RockBot.Host/HostDiagnostics.cs index 19f9c5e1..9b4f81d0 100644 --- a/src/RockBot.Host/HostDiagnostics.cs +++ b/src/RockBot.Host/HostDiagnostics.cs @@ -46,6 +46,18 @@ public static class HostDiagnostics unit: "{token}", description: "Total number of output tokens produced"); + /// + /// Time a caller spent waiting for a per-tier gateway slot before its LLM call + /// could proceed. Non-zero values indicate contention; sustained high values + /// indicate the tier's MaxConcurrent cap is too low for the workload + /// (or that callers are issuing too many parallel calls). + /// + public static readonly Histogram LlmGatewaySlotWaitDuration = + Meter.CreateHistogram( + "rockbot.llm.gateway.slot_wait.duration", + unit: "ms", + description: "Time spent waiting for a per-tier LLM gateway slot"); + // ── Agent turn metrics — recorded at architectural boundaries ───────────── /// Duration from user message receipt to final reply published. diff --git a/src/RockBot.Host/ILlmGateway.cs b/src/RockBot.Host/ILlmGateway.cs new file mode 100644 index 00000000..c259b5b0 --- /dev/null +++ b/src/RockBot.Host/ILlmGateway.cs @@ -0,0 +1,27 @@ +namespace RockBot.Host; + +/// +/// Global per-tier concurrency layer for LLM calls. All +/// invocations flow through an implementation of this gateway so that parallel +/// callers cannot overwhelm a tier and so that cancellation reliably drains +/// pending work. See design/llm-gateway.md. +/// +internal interface ILlmGateway +{ + /// + /// Acquires a slot on the per-tier concurrency semaphore, then invokes + /// . If + /// fires while waiting for a slot, the wait aborts with + /// before the operation runs. + /// + /// + /// The same is passed to + /// . Implementations must propagate cancellation + /// end-to-end; any path that drops the token re-introduces the rate-limit and + /// preemption hazards the gateway exists to prevent. + /// + Task ExecuteAsync( + ModelTier tier, + Func> operation, + CancellationToken cancellationToken); +} diff --git a/src/RockBot.Host/LlmClient.cs b/src/RockBot.Host/LlmClient.cs index e764834b..0e2aafaa 100644 --- a/src/RockBot.Host/LlmClient.cs +++ b/src/RockBot.Host/LlmClient.cs @@ -7,13 +7,13 @@ namespace RockBot.Host; /// /// Default implementation of . /// Selects the appropriate from the -/// and adds retry logic for -/// known model-specific SDK quirks. Registered as transient so each consumer -/// gets its own instance — concurrent calls from the user loop, background tasks, -/// dreaming, and session evaluation proceed independently without queuing. +/// , adds retry logic for known +/// model-specific SDK quirks, and routes the call through the singleton +/// which caps per-tier concurrency. /// internal sealed class LlmClient( TieredChatClientRegistry registry, + ILlmGateway gateway, LlmCostEstimator costEstimator, ILogger logger) : ILlmClient { @@ -67,7 +67,13 @@ private async Task CallTierAsync( // next model. The original cancellation token is passed through so // FallbackChatClient can correctly distinguish user cancellation from // provider timeouts. - var response = await InvokeWithNullArgRetryAsync(client, messages, options, cancellationToken); + // + // The gateway gates the actual SDK call on the per-tier concurrency + // semaphore so bursty parallel callers cannot overwhelm a tier. + var response = await gateway.ExecuteAsync( + tier, + ct => InvokeWithNullArgRetryAsync(client, messages, options, ct), + cancellationToken); if (response.Usage is { } usage) { diff --git a/src/RockBot.Host/LlmGateway.cs b/src/RockBot.Host/LlmGateway.cs new file mode 100644 index 00000000..595db073 --- /dev/null +++ b/src/RockBot.Host/LlmGateway.cs @@ -0,0 +1,128 @@ +using System.Diagnostics; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace RockBot.Host; + +/// +/// Global per-tier concurrency layer for all LLM calls. Every call to +/// flows through here so that bursty parallel work +/// (e.g. observation-framework extraction) cannot overwhelm a tier. +/// +/// +/// +/// Cancellation is the priority mechanism. When a caller's ct fires, the +/// pending wait on the per-tier aborts immediately, +/// freeing the slot for other waiters. This is how user-initiated work effectively +/// preempts dream-cycle work without an explicit priority queue: the work-serializer +/// already cancels the dream when a user message arrives, and that cancellation +/// drains the dream's queued LLM calls. +/// +/// +/// Registered as a singleton so all callers share the same per-tier semaphores. +/// +/// +/// See design/llm-gateway.md for the full design. +/// +/// +internal sealed class LlmGateway : ILlmGateway, IDisposable +{ + private readonly TierSlot[] _slots; + private readonly ILogger _logger; + + public LlmGateway(IOptions options, ILogger logger) + { + var opts = options.Value; + + var tierValues = (ModelTier[])Enum.GetValues(typeof(ModelTier)); + _slots = new TierSlot[tierValues.Length]; + foreach (var tier in tierValues) + { + var cap = tier switch + { + ModelTier.Low => opts.LowMaxConcurrent, + ModelTier.High => opts.HighMaxConcurrent, + _ => opts.BalancedMaxConcurrent, + }; + + if (cap < 1) + throw new ArgumentOutOfRangeException( + nameof(options), + $"LlmGatewayOptions {tier}MaxConcurrent must be >= 1 (was {cap})."); + + _slots[(int)tier] = new TierSlot(cap); + } + + _logger = logger; + + _logger.LogInformation( + "LlmGateway: per-tier concurrency caps Low={Low} Balanced={Balanced} High={High}", + opts.LowMaxConcurrent, opts.BalancedMaxConcurrent, opts.HighMaxConcurrent); + } + + /// + /// Returns the current number of waiters on the per-tier semaphore. Useful for + /// diagnostics and tests; values are observational and may race. + /// + internal int GetPendingCount(ModelTier tier) => Volatile.Read(ref _slots[(int)tier].Pending); + + /// + /// Returns the current number of in-flight calls on the tier. Useful for + /// diagnostics and tests; values are observational and may race. + /// + internal int GetInFlightCount(ModelTier tier) => Volatile.Read(ref _slots[(int)tier].InFlight); + + public async Task ExecuteAsync( + ModelTier tier, + Func> operation, + CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(operation); + + var slot = _slots[(int)tier]; + var tierTag = new KeyValuePair("rockbot.llm.tier", tier.ToString()); + + var slotWaitSw = Stopwatch.StartNew(); + Interlocked.Increment(ref slot.Pending); + try + { + await slot.Semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + Interlocked.Decrement(ref slot.Pending); + slotWaitSw.Stop(); + HostDiagnostics.LlmGatewaySlotWaitDuration.Record( + slotWaitSw.Elapsed.TotalMilliseconds, tierTag); + } + + Interlocked.Increment(ref slot.InFlight); + try + { + return await operation(cancellationToken).ConfigureAwait(false); + } + finally + { + Interlocked.Decrement(ref slot.InFlight); + slot.Semaphore.Release(); + } + } + + public void Dispose() + { + foreach (var slot in _slots) + slot.Semaphore.Dispose(); + } + + private sealed class TierSlot + { + public readonly SemaphoreSlim Semaphore; + public int Pending; + public int InFlight; + + public TierSlot(int maxConcurrent) + { + Semaphore = new SemaphoreSlim(maxConcurrent, maxConcurrent); + } + } +} diff --git a/src/RockBot.Host/ServiceCollectionExtensions.cs b/src/RockBot.Host/ServiceCollectionExtensions.cs index 1862aa35..0de383d6 100644 --- a/src/RockBot.Host/ServiceCollectionExtensions.cs +++ b/src/RockBot.Host/ServiceCollectionExtensions.cs @@ -34,6 +34,8 @@ public static IServiceCollection AddRockBotHost( services.AddSingleton(); services.Configure(_ => { }); + services.Configure(_ => { }); + services.AddSingleton(); services.AddTransient(); services.AddSingleton(); services.AddTransient(); diff --git a/tests/RockBot.Host.Tests/LlmGatewayTests.cs b/tests/RockBot.Host.Tests/LlmGatewayTests.cs new file mode 100644 index 00000000..1c1fe9df --- /dev/null +++ b/tests/RockBot.Host.Tests/LlmGatewayTests.cs @@ -0,0 +1,235 @@ +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; + +namespace RockBot.Host.Tests; + +[TestClass] +public class LlmGatewayTests +{ + private static LlmGateway CreateGateway(int low = 2, int balanced = 2, int high = 2) + { + var options = Options.Create(new LlmGatewayOptions + { + LowMaxConcurrent = low, + BalancedMaxConcurrent = balanced, + HighMaxConcurrent = high, + }); + return new LlmGateway(options, NullLogger.Instance); + } + + [TestMethod] + public async Task ExecuteAsync_RunsOperationAndReturnsResult() + { + using var gateway = CreateGateway(); + + var result = await gateway.ExecuteAsync( + ModelTier.Balanced, + ct => Task.FromResult(42), + CancellationToken.None); + + Assert.AreEqual(42, result); + } + + [TestMethod] + public async Task ExecuteAsync_PropagatesCancellationToOperation() + { + using var gateway = CreateGateway(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // SemaphoreSlim.WaitAsync may throw the derived TaskCanceledException + // when the token is cancelled — accept any OperationCanceledException. + await Assert.ThrowsAsync(async () => + await gateway.ExecuteAsync( + ModelTier.Balanced, + ct => + { + ct.ThrowIfCancellationRequested(); + return Task.FromResult(0); + }, + cts.Token)); + } + + [TestMethod] + public async Task ExecuteAsync_EnforcesPerTierConcurrencyCap() + { + using var gateway = CreateGateway(low: 2); + var inFlight = 0; + var maxInFlight = 0; + var gate = new TaskCompletionSource(); + + async Task Op(CancellationToken ct) + { + var current = Interlocked.Increment(ref inFlight); + // Race-tolerant max tracking + int observed; + do + { + observed = Volatile.Read(ref maxInFlight); + if (current <= observed) break; + } while (Interlocked.CompareExchange(ref maxInFlight, current, observed) != observed); + + await gate.Task; + Interlocked.Decrement(ref inFlight); + return 1; + } + + var tasks = Enumerable.Range(0, 5) + .Select(_ => gateway.ExecuteAsync(ModelTier.Low, Op, CancellationToken.None)) + .ToArray(); + + // Wait long enough for the first batch to actually be running + await WaitUntilAsync(() => Volatile.Read(ref inFlight) == 2, TimeSpan.FromSeconds(5)); + + Assert.AreEqual(2, Volatile.Read(ref inFlight), + "Cap should serialize callers at MaxConcurrent"); + + gate.SetResult(); + await Task.WhenAll(tasks); + + Assert.AreEqual(2, Volatile.Read(ref maxInFlight), + "MaxConcurrent should never have been exceeded"); + } + + [TestMethod] + public async Task ExecuteAsync_TiersDoNotShareSlots() + { + using var gateway = CreateGateway(low: 1, balanced: 1, high: 1); + var lowGate = new TaskCompletionSource(); + var balancedGate = new TaskCompletionSource(); + var highGate = new TaskCompletionSource(); + + var lowTask = gateway.ExecuteAsync(ModelTier.Low, async ct => + { + await lowGate.Task; + return 0; + }, CancellationToken.None); + + var balancedTask = gateway.ExecuteAsync(ModelTier.Balanced, async ct => + { + await balancedGate.Task; + return 0; + }, CancellationToken.None); + + var highTask = gateway.ExecuteAsync(ModelTier.High, async ct => + { + await highGate.Task; + return 0; + }, CancellationToken.None); + + // All three should be in-flight despite each tier's cap of 1 + await WaitUntilAsync( + () => gateway.GetInFlightCount(ModelTier.Low) == 1 + && gateway.GetInFlightCount(ModelTier.Balanced) == 1 + && gateway.GetInFlightCount(ModelTier.High) == 1, + TimeSpan.FromSeconds(5)); + + lowGate.SetResult(); + balancedGate.SetResult(); + highGate.SetResult(); + await Task.WhenAll(lowTask, balancedTask, highTask); + } + + [TestMethod] + public async Task ExecuteAsync_CancellationWhileWaitingForSlot_AbortsAndDoesNotConsumeSlot() + { + using var gateway = CreateGateway(low: 1); + var gate = new TaskCompletionSource(); + + // Occupy the only slot + var holder = gateway.ExecuteAsync(ModelTier.Low, async ct => + { + await gate.Task; + return 0; + }, CancellationToken.None); + + await WaitUntilAsync( + () => gateway.GetInFlightCount(ModelTier.Low) == 1, + TimeSpan.FromSeconds(5)); + + // Try to enqueue a second call with a CT that cancels before the slot frees + using var cts = new CancellationTokenSource(); + var operationRan = false; + var blocked = gateway.ExecuteAsync(ModelTier.Low, ct => + { + operationRan = true; + return Task.FromResult(0); + }, cts.Token); + + await WaitUntilAsync( + () => gateway.GetPendingCount(ModelTier.Low) == 1, + TimeSpan.FromSeconds(5)); + + cts.Cancel(); + + await Assert.ThrowsAsync(async () => await blocked); + + Assert.IsFalse(operationRan, "Cancelled waiter must not run the operation"); + + // The pending count should drop to zero; the holder still owns the in-flight slot + await WaitUntilAsync( + () => gateway.GetPendingCount(ModelTier.Low) == 0, + TimeSpan.FromSeconds(5)); + Assert.AreEqual(1, gateway.GetInFlightCount(ModelTier.Low)); + + // Free the holder; subsequent call should proceed normally (slot wasn't leaked) + gate.SetResult(); + await holder; + + var followup = await gateway.ExecuteAsync(ModelTier.Low, + ct => Task.FromResult(99), + CancellationToken.None); + Assert.AreEqual(99, followup); + } + + [TestMethod] + public async Task ExecuteAsync_ExceptionInOperation_ReleasesSlot() + { + using var gateway = CreateGateway(low: 1); + + await Assert.ThrowsExactlyAsync(async () => + await gateway.ExecuteAsync( + ModelTier.Low, + ct => throw new InvalidOperationException("boom"), + CancellationToken.None)); + + // Slot should be free for the next call + var ok = await gateway.ExecuteAsync( + ModelTier.Low, + ct => Task.FromResult(7), + CancellationToken.None); + Assert.AreEqual(7, ok); + Assert.AreEqual(0, gateway.GetInFlightCount(ModelTier.Low)); + } + + [TestMethod] + public async Task ExecuteAsync_NullOperation_Throws() + { + using var gateway = CreateGateway(); + + await Assert.ThrowsExactlyAsync(async () => + await gateway.ExecuteAsync( + ModelTier.Balanced, + operation: null!, + CancellationToken.None)); + } + + [TestMethod] + public void Constructor_CapBelowOne_Throws() + { + var bad = Options.Create(new LlmGatewayOptions { LowMaxConcurrent = 0 }); + Assert.ThrowsExactly(() => + new LlmGateway(bad, NullLogger.Instance)); + } + + private static async Task WaitUntilAsync(Func predicate, TimeSpan timeout) + { + var sw = System.Diagnostics.Stopwatch.StartNew(); + while (sw.Elapsed < timeout) + { + if (predicate()) return; + await Task.Delay(10); + } + Assert.Fail($"Condition not met within {timeout}"); + } +}