diff --git a/.gitignore b/.gitignore index d590aa2..50c6bd5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ __pycache__ .idea # coverage -.coverage \ No newline at end of file +.coverage +.venv diff --git a/slowapi/__init__.py b/slowapi/__init__.py index cfa284e..a25c4be 100644 --- a/slowapi/__init__.py +++ b/slowapi/__init__.py @@ -1,3 +1,3 @@ -from .extension import Limiter, _rate_limit_exceeded_handler +from .extension import Limiter, _current_request, _rate_limit_exceeded_handler -__all__ = ["Limiter", "_rate_limit_exceeded_handler"] +__all__ = ["Limiter", "_current_request", "_rate_limit_exceeded_handler"] diff --git a/slowapi/extension.py b/slowapi/extension.py index 050f882..808e034 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -9,6 +9,7 @@ import logging import os import time +from contextvars import ContextVar from datetime import datetime from email.utils import formatdate, parsedate_to_datetime from functools import wraps @@ -39,6 +40,11 @@ # used to annotate get_app_config method T = TypeVar("T") +# ContextVar to hold the current request, set by middleware so that +# endpoints don't need an explicit ``request`` parameter. +_current_request: ContextVar[Optional[Request]] = ContextVar( + "_current_request", default=None +) # Define an alias for the most commonly used type StrOrCallableStr = Union[str, Callable[..., str]] @@ -387,7 +393,7 @@ def _inject_headers( window_stats: Tuple[int, int] = self.limiter.get_window_stats( current_limit[0], *current_limit[1] ) - reset_in = 1 + window_stats[0] + reset_in = int(1 + window_stats[0]) response.headers.append( self._header_mapping[HEADERS.LIMIT], str(current_limit[0].amount) ) @@ -443,7 +449,7 @@ def _inject_asgi_headers( window_stats: Tuple[int, int] = self.limiter.get_window_stats( current_limit[0], *current_limit[1] ) - reset_in = 1 + window_stats[0] + reset_in = int(1 + window_stats[0]) headers[self._header_mapping[HEADERS.LIMIT]] = str( current_limit[0].amount ) @@ -705,13 +711,32 @@ def decorator(func: Callable[..., Response]): self._route_limits.setdefault(name, []).extend(static_limits) sig = inspect.signature(func) + idx: Optional[int] = None for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": break else: - raise Exception( - f'No "request" or "websocket" argument on function "{func}"' - ) + # No explicit request param – will resolve from ContextVar at runtime + idx = None + + def _resolve_request(args: Any, kwargs: Any) -> Request: + """Get the request from function args or fall back to ContextVar.""" + request = None + if idx is not None: + request = kwargs.get("request", args[idx] if args else None) + if request is None: + request = _current_request.get(None) + if not isinstance(request, Request): + if idx is not None: + raise Exception( + "parameter `request` must be an instance of starlette.requests.Request" + ) + raise Exception( + f'No "request" or "websocket" argument on function "{func}" ' + "and no request found in context. Either add a `request: Request` " + "parameter or ensure SlowAPI middleware is installed." + ) + return request if asyncio.iscoroutinefunction(func): # Handle async request/response functions. @@ -719,11 +744,7 @@ def decorator(func: Callable[..., Response]): async def async_wrapper(*args: Any, **kwargs: Any) -> Response: # get the request object from the decorated endpoint function if self.enabled: - request = kwargs.get("request", args[idx] if args else None) - if not isinstance(request, Request): - raise Exception( - "parameter `request` must be an instance of starlette.requests.Request" - ) + request = _resolve_request(args, kwargs) if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False @@ -752,11 +773,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response: def sync_wrapper(*args: Any, **kwargs: Any) -> Response: # get the request object from the decorated endpoint function if self.enabled: - request = kwargs.get("request", args[idx] if args else None) - if not isinstance(request, Request): - raise Exception( - "parameter `request` must be an instance of starlette.requests.Request" - ) + request = _resolve_request(args, kwargs) if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False diff --git a/slowapi/middleware.py b/slowapi/middleware.py index 76cdeec..dbc6339 100644 --- a/slowapi/middleware.py +++ b/slowapi/middleware.py @@ -12,7 +12,7 @@ from starlette.routing import BaseRoute, Match from starlette.types import ASGIApp, Message, Scope, Receive, Send -from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi import Limiter, _current_request, _rate_limit_exceeded_handler def _find_route_handler( @@ -120,23 +120,29 @@ async def dispatch( app: Starlette = request.app limiter: Limiter = app.state.limiter - if not limiter.enabled: - return await call_next(request) + token = _current_request.set(request) + try: + if not limiter.enabled: + return await call_next(request) - handler = _find_route_handler(app.routes, request.scope) - if _should_exempt(limiter, handler): - return await call_next(request) + handler = _find_route_handler(app.routes, request.scope) + if _should_exempt(limiter, handler): + return await call_next(request) - error_response, should_inject_headers = sync_check_limits( - limiter, request, handler, app - ) - if error_response is not None: - return error_response + error_response, should_inject_headers = sync_check_limits( + limiter, request, handler, app + ) + if error_response is not None: + return error_response - response = await call_next(request) - if should_inject_headers: - response = limiter._inject_headers(response, request.state.view_rate_limit) - return response + response = await call_next(request) + if should_inject_headers: + response = limiter._inject_headers( + response, request.state.view_rate_limit + ) + return response + finally: + _current_request.reset(token) class SlowAPIASGIMiddleware: @@ -189,18 +195,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: handler = _find_route_handler(_app.routes, scope) request = Request(scope, receive=receive, send=self.send) - if _should_exempt(limiter, handler): - return await self.app(scope, receive, self.send) - error_response, should_inject_headers = await async_check_limits( - limiter, request, handler, _app - ) - if error_response is not None: - return await error_response(scope, receive, self.send_wrapper) + token = _current_request.set(request) + try: + if _should_exempt(limiter, handler): + return await self.app(scope, receive, self.send) + + error_response, should_inject_headers = await async_check_limits( + limiter, request, handler, _app + ) + if error_response is not None: + return await error_response(scope, receive, self.send_wrapper) - if should_inject_headers: - self.inject_headers = True - self.limiter = limiter - self.request = request + if should_inject_headers: + self.inject_headers = True + self.limiter = limiter + self.request = request - return await self.app(scope, receive, self.send_wrapper) + return await self.app(scope, receive, self.send_wrapper) + finally: + _current_request.reset(token) diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 42e6322..bb92621 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -144,35 +144,37 @@ async def t1(request: Request, response: Response): == 429 ) - def test_endpoint_missing_request_param(self, build_fastapi_app): + def test_endpoint_no_request_param_async(self, build_fastapi_app): + """Endpoint without explicit request param works via ContextVar.""" app, limiter = build_fastapi_app(key_func=get_ipaddr) - with pytest.raises(Exception) as exc_info: - - @app.get("/t3") - @limiter.limit("5/minute") - async def t3(): - return PlainTextResponse("test") + @app.get("/t3") + @limiter.limit("5/minute") + async def t3(): + return PlainTextResponse("test") - assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t3") + assert response.status_code == 200 if i < 5 else 429 - def test_endpoint_missing_request_param_sync(self, build_fastapi_app): + def test_endpoint_no_request_param_sync(self, build_fastapi_app): + """Sync endpoint without explicit request param works via ContextVar.""" app, limiter = build_fastapi_app(key_func=get_ipaddr) - with pytest.raises(Exception) as exc_info: - - @app.get("/t3_sync") - @limiter.limit("5/minute") - def t3(): - return PlainTextResponse("test") + @app.get("/t3_sync") + @limiter.limit("5/minute") + def t3(): + return PlainTextResponse("test") - assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t3_sync") + assert response.status_code == 200 if i < 5 else 429 - def test_endpoint_request_param_invalid(self, build_fastapi_app): + def test_endpoint_request_param_wrong_type_hint(self, build_fastapi_app): + """Even with wrong type hint, FastAPI injects the real Request object + and ContextVar provides a fallback, so rate limiting still works.""" app, limiter = build_fastapi_app(key_func=get_ipaddr) @app.get("/t4") @@ -180,12 +182,10 @@ def test_endpoint_request_param_invalid(self, build_fastapi_app): async def t4(request: str = None): return PlainTextResponse("test") - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t4") - assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" - ) + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t4") + assert response.status_code == 200 if i < 5 else 429 def test_endpoint_response_param_invalid(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True) @@ -202,7 +202,9 @@ async def t4(request: Request, response: str = None): r"""parameter `response` must be an instance of starlette.responses.Response""" ) - def test_endpoint_request_param_invalid_sync(self, build_fastapi_app): + def test_endpoint_request_param_wrong_type_hint_sync(self, build_fastapi_app): + """Even with wrong type hint, FastAPI injects the real Request object + and ContextVar provides a fallback, so rate limiting still works.""" app, limiter = build_fastapi_app(key_func=get_ipaddr) @app.get("/t5") @@ -210,12 +212,10 @@ def test_endpoint_request_param_invalid_sync(self, build_fastapi_app): def t5(request: str = None): return PlainTextResponse("test") - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t5") - assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" - ) + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t5") + assert response.status_code == 200 if i < 5 else 429 def test_endpoint_response_param_invalid_sync(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True) @@ -232,6 +232,25 @@ def t5(request: Request, response: str = None): r"""parameter `response` must be an instance of starlette.responses.Response""" ) + def test_endpoint_no_request_param_no_middleware(self): + """Without middleware and without request param, a clear error is raised.""" + from fastapi import FastAPI + from slowapi import Limiter + from slowapi.util import get_ipaddr as _get_ipaddr + + limiter = Limiter(key_func=_get_ipaddr) + app = FastAPI() + app.state.limiter = limiter + + @app.get("/t_no_mw") + @limiter.limit("5/minute") + async def t_no_mw(): + return PlainTextResponse("test") + + with pytest.raises(Exception, match=r"no request found in context"): + client = TestClient(app) + client.get("/t_no_mw") + def test_dynamic_limit_provider_depending_on_key(self, build_fastapi_app): def custom_key_func(request: Request): if request.headers.get("TOKEN") == "secret": diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 0e26baa..a7007e5 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -175,16 +175,17 @@ def test_headers_no_breach(self, build_starlette_app): headers_enabled=True, key_func=get_remote_address ) - @app.route("/t1") @limiter.limit("10/minute") def t1(request: Request): return PlainTextResponse("test") - @app.route("/t2") @limiter.limit("2/second; 5 per minute; 10/hour") def t2(request: Request): return PlainTextResponse("test") + app.add_route("/t1", t1) + app.add_route("/t2", t2) + with hiro.Timeline().freeze(): with TestClient(app) as cli: resp = cli.get("/t1") @@ -208,11 +209,12 @@ def test_headers_breach(self, build_starlette_app): headers_enabled=True, key_func=get_remote_address ) - @app.route("/t1") @limiter.limit("2/second; 10 per minute; 20/hour") def t(request: Request): return PlainTextResponse("test") + app.add_route("/t1", t) + with hiro.Timeline().freeze() as timeline: with TestClient(app) as cli: for i in range(11): @@ -233,11 +235,12 @@ def test_retry_after(self, build_starlette_app): headers_enabled=True, key_func=get_remote_address ) - @app.route("/t1") @limiter.limit("1/minute") def t(request: Request): return PlainTextResponse("test") + app.add_route("/t1", t) + with hiro.Timeline().freeze() as timeline: with TestClient(app) as cli: resp = cli.get("/t1") @@ -254,34 +257,37 @@ def test_exempt_decorator(self, build_starlette_app): default_limits=["1/minute"], ) - @app.route("/t1") def t1(request: Request): return PlainTextResponse("test") + app.add_route("/t1", t1) + with TestClient(app) as cli: resp = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp.status_code == 200 resp2 = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp2.status_code == 429 - @app.route("/t2") @limiter.exempt def t2(request: Request): """Exempt a sync route""" return PlainTextResponse("test") + app.add_route("/t2", t2) + with TestClient(app) as cli: resp = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp.status_code == 200 resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp2.status_code == 200 - @app.route("/t3") @limiter.exempt async def t3(request: Request): """Exempt an async route""" return PlainTextResponse("test") + app.add_route("/t3", t3) + with TestClient(app) as cli: resp = cli.get("/t3", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp.status_code == 200