Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ __pycache__
.idea

# coverage
.coverage
.coverage
.venv
4 changes: 2 additions & 2 deletions slowapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .extension import Limiter, _rate_limit_exceeded_handler
from .extension import Limiter, _current_request, _rate_limit_exceeded_handler

__all__ = ["Limiter", "_rate_limit_exceeded_handler"]
__all__ = ["Limiter", "_current_request", "_rate_limit_exceeded_handler"]
47 changes: 32 additions & 15 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import time
from contextvars import ContextVar
from datetime import datetime
from email.utils import formatdate, parsedate_to_datetime
from functools import wraps
Expand Down Expand Up @@ -39,6 +40,11 @@

# used to annotate get_app_config method
T = TypeVar("T")
# ContextVar to hold the current request, set by middleware so that
# endpoints don't need an explicit ``request`` parameter.
_current_request: ContextVar[Optional[Request]] = ContextVar(
"_current_request", default=None
)
# Define an alias for the most commonly used type
StrOrCallableStr = Union[str, Callable[..., str]]

Expand Down Expand Up @@ -387,7 +393,7 @@ def _inject_headers(
window_stats: Tuple[int, int] = self.limiter.get_window_stats(
current_limit[0], *current_limit[1]
)
reset_in = 1 + window_stats[0]
reset_in = int(1 + window_stats[0])
response.headers.append(
self._header_mapping[HEADERS.LIMIT], str(current_limit[0].amount)
)
Expand Down Expand Up @@ -443,7 +449,7 @@ def _inject_asgi_headers(
window_stats: Tuple[int, int] = self.limiter.get_window_stats(
current_limit[0], *current_limit[1]
)
reset_in = 1 + window_stats[0]
reset_in = int(1 + window_stats[0])
headers[self._header_mapping[HEADERS.LIMIT]] = str(
current_limit[0].amount
)
Expand Down Expand Up @@ -705,25 +711,40 @@ def decorator(func: Callable[..., Response]):
self._route_limits.setdefault(name, []).extend(static_limits)

sig = inspect.signature(func)
idx: Optional[int] = None
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
break
else:
raise Exception(
f'No "request" or "websocket" argument on function "{func}"'
)
# No explicit request param – will resolve from ContextVar at runtime
idx = None

def _resolve_request(args: Any, kwargs: Any) -> Request:
"""Get the request from function args or fall back to ContextVar."""
request = None
if idx is not None:
request = kwargs.get("request", args[idx] if args else None)
if request is None:
request = _current_request.get(None)
if not isinstance(request, Request):
if idx is not None:
raise Exception(
"parameter `request` must be an instance of starlette.requests.Request"
)
raise Exception(
f'No "request" or "websocket" argument on function "{func}" '
"and no request found in context. Either add a `request: Request` "
"parameter or ensure SlowAPI middleware is installed."
)
return request

if asyncio.iscoroutinefunction(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
# get the request object from the decorated endpoint function
if self.enabled:
request = kwargs.get("request", args[idx] if args else None)
if not isinstance(request, Request):
raise Exception(
"parameter `request` must be an instance of starlette.requests.Request"
)
request = _resolve_request(args, kwargs)

if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
Expand Down Expand Up @@ -752,11 +773,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
# get the request object from the decorated endpoint function
if self.enabled:
request = kwargs.get("request", args[idx] if args else None)
if not isinstance(request, Request):
raise Exception(
"parameter `request` must be an instance of starlette.requests.Request"
)
request = _resolve_request(args, kwargs)

if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
Expand Down
65 changes: 38 additions & 27 deletions slowapi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette.routing import BaseRoute, Match
from starlette.types import ASGIApp, Message, Scope, Receive, Send

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi import Limiter, _current_request, _rate_limit_exceeded_handler


def _find_route_handler(
Expand Down Expand Up @@ -120,23 +120,29 @@ async def dispatch(
app: Starlette = request.app
limiter: Limiter = app.state.limiter

if not limiter.enabled:
return await call_next(request)
token = _current_request.set(request)
try:
if not limiter.enabled:
return await call_next(request)

handler = _find_route_handler(app.routes, request.scope)
if _should_exempt(limiter, handler):
return await call_next(request)
handler = _find_route_handler(app.routes, request.scope)
if _should_exempt(limiter, handler):
return await call_next(request)

error_response, should_inject_headers = sync_check_limits(
limiter, request, handler, app
)
if error_response is not None:
return error_response
error_response, should_inject_headers = sync_check_limits(
limiter, request, handler, app
)
if error_response is not None:
return error_response

response = await call_next(request)
if should_inject_headers:
response = limiter._inject_headers(response, request.state.view_rate_limit)
return response
response = await call_next(request)
if should_inject_headers:
response = limiter._inject_headers(
response, request.state.view_rate_limit
)
return response
finally:
_current_request.reset(token)


class SlowAPIASGIMiddleware:
Expand Down Expand Up @@ -189,18 +195,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

handler = _find_route_handler(_app.routes, scope)
request = Request(scope, receive=receive, send=self.send)
if _should_exempt(limiter, handler):
return await self.app(scope, receive, self.send)

error_response, should_inject_headers = await async_check_limits(
limiter, request, handler, _app
)
if error_response is not None:
return await error_response(scope, receive, self.send_wrapper)
token = _current_request.set(request)
try:
if _should_exempt(limiter, handler):
return await self.app(scope, receive, self.send)

error_response, should_inject_headers = await async_check_limits(
limiter, request, handler, _app
)
if error_response is not None:
return await error_response(scope, receive, self.send_wrapper)

if should_inject_headers:
self.inject_headers = True
self.limiter = limiter
self.request = request
if should_inject_headers:
self.inject_headers = True
self.limiter = limiter
self.request = request

return await self.app(scope, receive, self.send_wrapper)
return await self.app(scope, receive, self.send_wrapper)
finally:
_current_request.reset(token)
87 changes: 53 additions & 34 deletions tests/test_fastapi_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,48 +144,48 @@ async def t1(request: Request, response: Response):
== 429
)

def test_endpoint_missing_request_param(self, build_fastapi_app):
def test_endpoint_no_request_param_async(self, build_fastapi_app):
"""Endpoint without explicit request param works via ContextVar."""
app, limiter = build_fastapi_app(key_func=get_ipaddr)

with pytest.raises(Exception) as exc_info:

@app.get("/t3")
@limiter.limit("5/minute")
async def t3():
return PlainTextResponse("test")
@app.get("/t3")
@limiter.limit("5/minute")
async def t3():
return PlainTextResponse("test")

assert exc_info.match(
r"""^No "request" or "websocket" argument on function .*"""
)
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t3")
assert response.status_code == 200 if i < 5 else 429

def test_endpoint_missing_request_param_sync(self, build_fastapi_app):
def test_endpoint_no_request_param_sync(self, build_fastapi_app):
"""Sync endpoint without explicit request param works via ContextVar."""
app, limiter = build_fastapi_app(key_func=get_ipaddr)

with pytest.raises(Exception) as exc_info:

@app.get("/t3_sync")
@limiter.limit("5/minute")
def t3():
return PlainTextResponse("test")
@app.get("/t3_sync")
@limiter.limit("5/minute")
def t3():
return PlainTextResponse("test")

assert exc_info.match(
r"""^No "request" or "websocket" argument on function .*"""
)
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t3_sync")
assert response.status_code == 200 if i < 5 else 429

def test_endpoint_request_param_invalid(self, build_fastapi_app):
def test_endpoint_request_param_wrong_type_hint(self, build_fastapi_app):
"""Even with wrong type hint, FastAPI injects the real Request object
and ContextVar provides a fallback, so rate limiting still works."""
app, limiter = build_fastapi_app(key_func=get_ipaddr)

@app.get("/t4")
@limiter.limit("5/minute")
async def t4(request: str = None):
return PlainTextResponse("test")

with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t4")
assert exc_info.match(
r"""parameter `request` must be an instance of starlette.requests.Request"""
)
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t4")
assert response.status_code == 200 if i < 5 else 429

def test_endpoint_response_param_invalid(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True)
Expand All @@ -202,20 +202,20 @@ async def t4(request: Request, response: str = None):
r"""parameter `response` must be an instance of starlette.responses.Response"""
)

def test_endpoint_request_param_invalid_sync(self, build_fastapi_app):
def test_endpoint_request_param_wrong_type_hint_sync(self, build_fastapi_app):
"""Even with wrong type hint, FastAPI injects the real Request object
and ContextVar provides a fallback, so rate limiting still works."""
app, limiter = build_fastapi_app(key_func=get_ipaddr)

@app.get("/t5")
@limiter.limit("5/minute")
def t5(request: str = None):
return PlainTextResponse("test")

with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t5")
assert exc_info.match(
r"""parameter `request` must be an instance of starlette.requests.Request"""
)
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t5")
assert response.status_code == 200 if i < 5 else 429

def test_endpoint_response_param_invalid_sync(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True)
Expand All @@ -232,6 +232,25 @@ def t5(request: Request, response: str = None):
r"""parameter `response` must be an instance of starlette.responses.Response"""
)

def test_endpoint_no_request_param_no_middleware(self):
"""Without middleware and without request param, a clear error is raised."""
from fastapi import FastAPI
from slowapi import Limiter
from slowapi.util import get_ipaddr as _get_ipaddr

limiter = Limiter(key_func=_get_ipaddr)
app = FastAPI()
app.state.limiter = limiter

@app.get("/t_no_mw")
@limiter.limit("5/minute")
async def t_no_mw():
return PlainTextResponse("test")

with pytest.raises(Exception, match=r"no request found in context"):
client = TestClient(app)
client.get("/t_no_mw")

def test_dynamic_limit_provider_depending_on_key(self, build_fastapi_app):
def custom_key_func(request: Request):
if request.headers.get("TOKEN") == "secret":
Expand Down
Loading