From 4e7e7b824e4ba19bc58a79ef7bb77c2c4dcbcea9 Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Fri, 27 Feb 2026 17:26:23 -0500 Subject: [PATCH 1/2] RateLimit admission control dependency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `RateLimit(limit, per=timedelta)` — caps how many times a task (or a per-parameter scope) can execute within a sliding window. Uses a single Redis sorted set per scope with a Lua script that atomically prunes old entries, counts remaining, and either records the execution or computes `retry_after` from the oldest entry. By default, excess tasks are rescheduled to exactly when a slot opens. `drop=True` quietly drops them instead (like Cooldown). Works both as a default parameter and as `Annotated` metadata for per-parameter scoping — same patterns as Cooldown and Debounce. Closes #161. Co-Authored-By: Claude Opus 4.6 --- docs/task-behaviors.md | 60 +++++++++ src/docket/__init__.py | 2 + src/docket/dependencies/__init__.py | 2 + src/docket/dependencies/_ratelimit.py | 155 +++++++++++++++++++++++ tests/test_ratelimit.py | 171 ++++++++++++++++++++++++++ 5 files changed, 390 insertions(+) create mode 100644 src/docket/dependencies/_ratelimit.py create mode 100644 tests/test_ratelimit.py diff --git a/docs/task-behaviors.md b/docs/task-behaviors.md index df472f7..b382274 100644 --- a/docs/task-behaviors.md +++ b/docs/task-behaviors.md @@ -546,6 +546,66 @@ await docket.add(sync_data)(customer_id=2, region="eu") Only one `Debounce` is allowed per task — its reschedule mechanism requires a single settle window. +## Rate Limiting + +Rate limiting caps how many times a task can execute within a sliding time window. Unlike cooldown (which drops duplicates) or debounce (which waits for quiet), rate limiting counts executions and blocks when the count exceeds a threshold. + +By default, excess tasks are rescheduled to exactly when a slot opens. With `drop=True`, they're quietly dropped instead. + +### Per-Task Rate Limit + +```python +from datetime import timedelta +from docket import RateLimit + +async def sync_data( + rate: RateLimit = RateLimit(10, per=timedelta(minutes=1)), +) -> None: + await perform_sync() + +# The first 10 calls within a minute execute immediately. +# The 11th is rescheduled to when the oldest slot frees up. +``` + +### Per-Parameter Rate Limit + +Annotate a parameter with `RateLimit` to apply independent limits per value: + +```python +from typing import Annotated + +async def process_customer( + customer_id: Annotated[int, RateLimit(5, per=timedelta(minutes=1))], +) -> None: + await refresh_customer_data(customer_id) + +# Each customer_id gets its own independent sliding window. +# Customer 1001 can hit 5/min while customer 2002 independently hits 5/min. +``` + +### Dropping Excess Tasks + +When rescheduling isn't appropriate, use `drop=True` to silently discard excess tasks: + +```python +async def fire_webhook( + endpoint: Annotated[str, RateLimit(100, per=timedelta(hours=1), drop=True)], +) -> None: + await send_webhook(endpoint) + +# After 100 webhook calls to the same endpoint in an hour, +# additional calls are dropped with an INFO log. +``` + +### Rate Limit vs. Cooldown vs. Debounce + +| | RateLimit | Cooldown | Debounce | +|---|---|---|---| +| **Behavior** | Allow N per window | Execute first, drop rest | Wait for quiet, then execute | +| **Window anchored to** | Sliding (each execution) | First execution | Last submission | +| **Over-limit default** | Reschedule | Drop | Drop (losers) / Reschedule (winner) | +| **Good for** | Enforcing throughput caps | Deduplicating rapid-fire | Batching bursts into one action | + ### Combining with Other Controls Debounce, cooldown, and concurrency limits can all coexist on the same task: diff --git a/src/docket/__init__.py b/src/docket/__init__.py index 9c886ea..8ea4741 100644 --- a/src/docket/__init__.py +++ b/src/docket/__init__.py @@ -15,6 +15,7 @@ Cooldown, Cron, Debounce, + RateLimit, CurrentDocket, CurrentExecution, CurrentWorker, @@ -42,6 +43,7 @@ "Cooldown", "Cron", "Debounce", + "RateLimit", "CurrentDocket", "CurrentExecution", "CurrentWorker", diff --git a/src/docket/dependencies/__init__.py b/src/docket/dependencies/__init__.py index a18c40a..5481baf 100644 --- a/src/docket/dependencies/__init__.py +++ b/src/docket/dependencies/__init__.py @@ -21,6 +21,7 @@ from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit from ._cooldown import Cooldown from ._debounce import Debounce +from ._ratelimit import RateLimit from ._cron import Cron from ._contextual import ( CurrentDocket, @@ -87,6 +88,7 @@ "ConcurrencyLimit", "Cooldown", "Debounce", + "RateLimit", "Cron", "Perpetual", "Progress", diff --git a/src/docket/dependencies/_ratelimit.py b/src/docket/dependencies/_ratelimit.py new file mode 100644 index 0000000..02de518 --- /dev/null +++ b/src/docket/dependencies/_ratelimit.py @@ -0,0 +1,155 @@ +"""Rate limit admission control dependency. + +Caps how many times a task (or a per-parameter scope) can execute within a +sliding window. Uses a Redis sorted set as a sliding window log: members are +execution keys, scores are millisecond timestamps. +""" + +from __future__ import annotations + +import time +from datetime import timedelta +from types import TracebackType +from typing import Any + +from ._base import AdmissionBlocked, Dependency, current_docket, current_execution + +# Lua script for atomic sliding-window rate limit check. +# +# KEYS[1] = sorted set key (one per scope) +# ARGV[1] = execution key (member) +# ARGV[2] = current time in milliseconds +# ARGV[3] = window size in milliseconds +# ARGV[4] = max allowed count (limit) +# ARGV[5] = key TTL in milliseconds (window * 2, safety net) +# +# Returns: {action, retry_after_ms} +# action: 1=PROCEED, 2=BLOCKED +# retry_after_ms: ms until the oldest entry expires (only for BLOCKED) +_RATELIMIT_LUA = """ +local key = KEYS[1] +local member = ARGV[1] +local now_ms = tonumber(ARGV[2]) +local window_ms = tonumber(ARGV[3]) +local limit = tonumber(ARGV[4]) +local ttl_ms = tonumber(ARGV[5]) + +-- Prune entries older than the window +local cutoff = now_ms - window_ms +redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff) + +-- Count remaining entries +local count = redis.call('ZCARD', key) + +if count < limit then + -- Under limit: record this execution and set safety TTL + redis.call('ZADD', key, now_ms, member) + redis.call('PEXPIRE', key, ttl_ms) + return {1, 0} +end + +-- Over limit: compute when the oldest entry will expire +local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') +local oldest_score = tonumber(oldest[2]) +local retry_after = oldest_score + window_ms - now_ms +if retry_after < 1 then + retry_after = 1 +end +return {2, retry_after} +""" + +_ACTION_PROCEED = 1 +_ACTION_BLOCKED = 2 + + +class RateLimit(Dependency["RateLimit"]): + """Cap executions within a sliding time window. + + Uses a Redis sorted set as a sliding window log. Each execution adds + an entry; entries older than the window are pruned atomically. + + When the limit is reached: + - ``drop=False`` (default): the task is rescheduled to when a slot opens. + - ``drop=True``: the task is quietly dropped. + + Works both as a default parameter and as ``Annotated`` metadata:: + + # Per-task: max 10 per minute, excess rescheduled + async def sync_data( + rate: RateLimit = RateLimit(10, per=timedelta(minutes=1)), + ) -> None: ... + + # Per-parameter: max 5 per minute per customer, excess dropped + async def process_customer( + customer_id: Annotated[int, RateLimit(5, per=timedelta(minutes=1), drop=True)], + ) -> None: ... + """ + + def __init__( + self, + limit: int, + *, + per: timedelta, + drop: bool = False, + scope: str | None = None, + ) -> None: + self.limit = limit + self.per = per + self.drop = drop + self.scope = scope + self._argument_name: str | None = None + self._argument_value: Any = None + + def bind_to_parameter(self, name: str, value: Any) -> RateLimit: + bound = RateLimit(self.limit, per=self.per, drop=self.drop, scope=self.scope) + bound._argument_name = name + bound._argument_value = value + return bound + + async def __aenter__(self) -> RateLimit: + execution = current_execution.get() + docket = current_docket.get() + + scope = self.scope or docket.name + if self._argument_name is not None: + ratelimit_key = ( + f"{scope}:ratelimit:{self._argument_name}:{self._argument_value}" + ) + else: + ratelimit_key = f"{scope}:ratelimit:{execution.function_name}" + + window_ms = int(self.per.total_seconds() * 1000) + now_ms = int(time.time() * 1000) + ttl_ms = window_ms * 2 + + async with docket.redis() as redis: + script = redis.register_script(_RATELIMIT_LUA) + result: list[int] = await script( + keys=[ratelimit_key], + args=[execution.key, now_ms, window_ms, self.limit, ttl_ms], + ) + + action = result[0] + retry_after_ms = result[1] + + if action == _ACTION_PROCEED: + return self + + reason = f"rate limit ({self.limit}/{self.per}) on {ratelimit_key}" + + if self.drop: + raise AdmissionBlocked(execution, reason=reason, reschedule=False) + + raise AdmissionBlocked( + execution, + reason=reason, + retry_delay=timedelta(milliseconds=retry_after_ms), + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py new file mode 100644 index 0000000..5783b90 --- /dev/null +++ b/tests/test_ratelimit.py @@ -0,0 +1,171 @@ +"""Tests for RateLimit admission control dependency.""" + +from __future__ import annotations + +import asyncio +from datetime import timedelta +from typing import Annotated + +from docket import ConcurrencyLimit, Docket, Worker +from docket.dependencies import RateLimit + + +async def test_task_level_rate_limit_drops_excess(docket: Docket, worker: Worker): + """Task-level rate limit drops excess executions within the window.""" + results: list[str] = [] + + async def rated_task( + rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True), + ): + results.append("executed") + + await docket.add(rated_task)() + await docket.add(rated_task)() + await docket.add(rated_task)() + + await worker.run_until_finished() + + assert len(results) == 2 + + +async def test_task_level_rate_limit_allows_after_window( + docket: Docket, worker: Worker +): + """Task-level rate limit allows execution after the window expires.""" + results: list[str] = [] + + async def rated_task( + rate: RateLimit = RateLimit(1, per=timedelta(milliseconds=50), drop=True), + ): + results.append("executed") + + await docket.add(rated_task)() + await worker.run_until_finished() + assert results == ["executed"] + + await asyncio.sleep(0.06) + + await docket.add(rated_task)() + await worker.run_until_finished() + assert results == ["executed", "executed"] + + +async def test_per_parameter_rate_limit_independent_scopes( + docket: Docket, worker: Worker +): + """Per-parameter rate limit scopes independently per value.""" + results: list[int] = [] + + async def rated_task( + customer_id: Annotated[int, RateLimit(1, per=timedelta(seconds=5), drop=True)], + ): + results.append(customer_id) + + await docket.add(rated_task)(customer_id=1) + await docket.add(rated_task)(customer_id=1) + await docket.add(rated_task)(customer_id=2) + + worker.concurrency = 10 + await worker.run_until_finished() + + assert sorted(results) == [1, 2] + assert results.count(1) == 1 + + +async def test_drop_true_drops_excess(docket: Docket, worker: Worker): + """With drop=True, excess tasks are quietly dropped instead of rescheduled.""" + results: list[str] = [] + + async def rated_task( + rate: RateLimit = RateLimit(1, per=timedelta(seconds=5), drop=True), + ): + results.append("executed") + + await docket.add(rated_task)() + await docket.add(rated_task)() + await docket.add(rated_task)() + + await worker.run_until_finished() + + assert results == ["executed"] + + +async def test_drop_false_excess_eventually_executes(docket: Docket, worker: Worker): + """With drop=False (default), excess tasks reschedule and eventually execute.""" + results: list[str] = [] + + async def rated_task( + rate: RateLimit = RateLimit(1, per=timedelta(milliseconds=50)), + ): + results.append("executed") + + await docket.add(rated_task)() + await docket.add(rated_task)() + + await worker.run_until_finished() + + assert len(results) == 2 + + +async def test_multiple_rate_limits_on_different_parameters( + docket: Docket, worker: Worker +): + """Multiple RateLimit annotations on different parameters are independent.""" + results: list[tuple[int, str]] = [] + + async def task( + customer_id: Annotated[int, RateLimit(1, per=timedelta(seconds=5), drop=True)], + region: Annotated[str, RateLimit(1, per=timedelta(seconds=5), drop=True)], + ): + results.append((customer_id, region)) + + await docket.add(task)(customer_id=1, region="us") + await worker.run_until_finished() + + await docket.add(task)(customer_id=1, region="eu") # blocked by customer_id=1 + await docket.add(task)(customer_id=2, region="us") # blocked by region="us" + + await worker.run_until_finished() + + assert results == [(1, "us")] + + +async def test_rate_limit_coexists_with_concurrency_limit( + docket: Docket, worker: Worker +): + """RateLimit + ConcurrencyLimit can coexist on the same task.""" + results: list[str] = [] + + async def task( + customer_id: Annotated[int, ConcurrencyLimit(1)], + rate: RateLimit = RateLimit(10, per=timedelta(seconds=5)), + ): + results.append(f"executed_{customer_id}") + + await docket.add(task)(customer_id=1) + await worker.run_until_finished() + assert results == ["executed_1"] + + +async def test_rate_limit_key_cleaned_up_after_ttl(docket: Docket, worker: Worker): + """Redis key is cleaned up after TTL expires.""" + + async def rated_task( + rate: RateLimit = RateLimit(10, per=timedelta(milliseconds=50)), + ): + pass + + await docket.add(rated_task)() + await worker.run_until_finished() + + # Wait for TTL to expire (key TTL = window * 2) + await asyncio.sleep(0.15) + + async with docket.redis() as redis: + ratelimit_keys: list[str] = [ + key + async for key in redis.scan_iter( # type: ignore[union-attr] + match=f"{docket.name}:ratelimit:*" + ) + ] + assert ratelimit_keys == [] From e331c29719d2471bc5ebcd9bd48be20381d456cc Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Fri, 27 Feb 2026 17:56:31 -0500 Subject: [PATCH 2/2] Fix RateLimit sorted set member uniqueness for Perpetual tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Using `execution.key` as the sorted set member meant Perpetual tasks (which reuse the same key via `replace()`) only ever had one entry — ZADD overwrites the score instead of adding a new member, so ZCARD stays at 1 and the rate limit never fires. The member is now `{execution.key}:{now_ms}`, so each execution attempt gets its own entry. Also cleans up phantom slots in `__aexit__` when a later dependency (like ConcurrencyLimit) blocks — the task never ran but would otherwise consume a rate-limit slot. Caught during PR #356 review. Co-Authored-By: Claude Opus 4.6 --- src/docket/dependencies/_ratelimit.py | 19 ++++++-- tests/test_ratelimit.py | 70 ++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 5 deletions(-) diff --git a/src/docket/dependencies/_ratelimit.py b/src/docket/dependencies/_ratelimit.py index 02de518..ac61565 100644 --- a/src/docket/dependencies/_ratelimit.py +++ b/src/docket/dependencies/_ratelimit.py @@ -2,7 +2,8 @@ Caps how many times a task (or a per-parameter scope) can execute within a sliding window. Uses a Redis sorted set as a sliding window log: members are -execution keys, scores are millisecond timestamps. +``{execution_key}:{now_ms}`` strings (unique per attempt), scores are +millisecond timestamps. """ from __future__ import annotations @@ -17,7 +18,7 @@ # Lua script for atomic sliding-window rate limit check. # # KEYS[1] = sorted set key (one per scope) -# ARGV[1] = execution key (member) +# ARGV[1] = member (execution key + timestamp, unique per attempt) # ARGV[2] = current time in milliseconds # ARGV[3] = window size in milliseconds # ARGV[4] = max allowed count (limit) @@ -99,6 +100,8 @@ def __init__( self.scope = scope self._argument_name: str | None = None self._argument_value: Any = None + self._ratelimit_key: str | None = None + self._member: str | None = None def bind_to_parameter(self, name: str, value: Any) -> RateLimit: bound = RateLimit(self.limit, per=self.per, drop=self.drop, scope=self.scope) @@ -121,18 +124,21 @@ async def __aenter__(self) -> RateLimit: window_ms = int(self.per.total_seconds() * 1000) now_ms = int(time.time() * 1000) ttl_ms = window_ms * 2 + member = f"{execution.key}:{now_ms}" async with docket.redis() as redis: script = redis.register_script(_RATELIMIT_LUA) result: list[int] = await script( keys=[ratelimit_key], - args=[execution.key, now_ms, window_ms, self.limit, ttl_ms], + args=[member, now_ms, window_ms, self.limit, ttl_ms], ) action = result[0] retry_after_ms = result[1] if action == _ACTION_PROCEED: + self._ratelimit_key = ratelimit_key + self._member = member return self reason = f"rate limit ({self.limit}/{self.per}) on {ratelimit_key}" @@ -152,4 +158,9 @@ async def __aexit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - pass + if exc_type is not None and self._member is not None: + if issubclass(exc_type, AdmissionBlocked): + assert self._ratelimit_key is not None + docket = current_docket.get() + async with docket.redis() as redis: + await redis.zrem(self._ratelimit_key, self._member) diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py index 5783b90..fc6efac 100644 --- a/tests/test_ratelimit.py +++ b/tests/test_ratelimit.py @@ -6,7 +6,7 @@ from datetime import timedelta from typing import Annotated -from docket import ConcurrencyLimit, Docket, Worker +from docket import ConcurrencyLimit, Docket, Perpetual, Worker from docket.dependencies import RateLimit @@ -169,3 +169,71 @@ async def rated_task( ) ] assert ratelimit_keys == [] + + +async def test_rate_limit_slot_kept_on_task_failure(docket: Docket, worker: Worker): + """A failed task still counts against the rate limit.""" + results: list[str] = [] + + async def failing_task( + rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True), + ): + results.append("attempted") + if len(results) == 1: + raise RuntimeError("boom") + + await docket.add(failing_task)() + await docket.add(failing_task)() + await docket.add(failing_task)() + + await worker.run_until_finished() + + assert len(results) == 2 + + +async def test_perpetual_task_counts_each_execution(docket: Docket, worker: Worker): + """Perpetual re-executions each count against the rate limit.""" + results: list[str] = [] + + async def perpetual_rated( + perpetual: Perpetual = Perpetual(every=timedelta(milliseconds=10)), + rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True), + ): + results.append("executed") + + execution = await docket.add(perpetual_rated)() + await worker.run_at_most({execution.key: 4}) + + assert len(results) == 2 + + +async def test_rate_limit_slot_freed_when_another_dep_blocks( + docket: Docket, worker: Worker +): + """RateLimit slot is freed if a later dependency blocks the task. + + RateLimit (default-param, resolved first) proceeds and records a slot, + then ConcurrencyLimit (annotation, resolved second) blocks. Without + __aexit__ cleanup the phantom slot stays in the sorted set. + """ + results: list[str] = [] + + async def blocker( + customer_id: Annotated[int, ConcurrencyLimit(1)], + rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True), + ): + results.append(f"executed_{customer_id}") + if len(results) == 1: + await asyncio.sleep(0.3) + + # Task 1 grabs both the rate-limit slot and the concurrency slot. + # Task 2 passes rate-limit (2/2) but is concurrency-blocked, then + # rescheduled. Without cleanup, the phantom slot means task 2's + # retry hits 2/2 and gets dropped. + await docket.add(blocker)(customer_id=1) + await docket.add(blocker)(customer_id=1) + + worker.concurrency = 2 + await worker.run_until_finished() + + assert len(results) == 2