diff --git a/slowapi/extension.py b/slowapi/extension.py index 050f882..f7d85f4 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -326,13 +326,6 @@ def emit(*_): self._fallback_storage = MemoryStorage() self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage) - def slowapi_startup(self) -> None: - """ - Starlette startup event handler that links the app with the Limiter instance. - """ - app.state.limiter = self # type: ignore - app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore - def get_app_config(self, key: str, default_value: T = None) -> T: """ Place holder until we find a better way to load config from app @@ -486,7 +479,7 @@ def __evaluate_limits( failed_limit = None limit_for_header = None for lim in limits: - limit_scope = lim.scope or endpoint + limit_scope = lim.scope_for(request) or endpoint if lim.is_exempt(request): continue if lim.methods is not None and request.method.lower() not in lim.methods: diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index d5677d4..94cf790 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -51,18 +51,12 @@ def is_exempt(self, request: Optional[Request] = None) -> bool: return self.exempt_when(request) return self.exempt_when() - @property - def scope(self) -> str: - # flack.request.endpoint is the name of the function for the endpoint - # FIXME: how to get the request here? + def scope_for(self, request: Request) -> str: if self.__scope is None: return "" - else: - return ( - self.__scope(request.endpoint) # type: ignore - if callable(self.__scope) - else self.__scope - ) + if callable(self.__scope): + return self.__scope(request) + return self.__scope class LimitGroup(object): diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 0e26baa..6c0414b 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -121,6 +121,36 @@ def t2(request: Request): # the shared limit has already been hit via t1 assert client.get("/t2").status_code == 429 + def test_shared_decorator_callable_scope(self, build_starlette_app): + """Callable scope receives the request and buckets are keyed by its return value.""" + app, limiter = build_starlette_app(key_func=get_ipaddr) + + def scope_from_tenant(request: Request) -> str: + return request.headers.get("X-Tenant", "default") + + shared_lim = limiter.shared_limit("5/minute", scope=scope_from_tenant) + + @shared_lim + def t1(request: Request): + return PlainTextResponse("test") + + @shared_lim + def t2(request: Request): + return PlainTextResponse("test") + + app.add_route("/t1", t1) + app.add_route("/t2", t2) + + client = TestClient(app) + # tenant A burns its budget on /t1 ... + for i in range(10): + resp = client.get("/t1", headers={"X-Tenant": "A"}) + assert resp.status_code == (200 if i < 5 else 429) + # ... and /t2 is also exhausted for tenant A (shared scope) + assert client.get("/t2", headers={"X-Tenant": "A"}).status_code == 429 + # but tenant B has its own bucket — scope callable isolates them + assert client.get("/t2", headers={"X-Tenant": "B"}).status_code == 200 + def test_multiple_decorators(self, build_starlette_app): app, limiter = build_starlette_app(key_func=get_ipaddr)