Skip to content

Commit 4119051

Browse files
authored
Merge pull request #158 from igorbenav/rate-limiter-fix
rate limiter changed from module to class
2 parents 61bb4ce + a0cad4c commit 4119051

File tree

6 files changed

+83
-49
lines changed

6 files changed

+83
-49
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,17 +1342,17 @@ async def your_background_function(
13421342
13431343
### 5.11 Rate Limiting
13441344

1345-
To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter` dependency:
1345+
To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter_dependency` dependency:
13461346

13471347
```python
13481348
from fastapi import Depends
13491349

1350-
from app.api.dependencies import rate_limiter
1350+
from app.api.dependencies import rate_limiter_dependency
13511351
from app.core.utils import queue
13521352
from app.schemas.job import Job
13531353

13541354

1355-
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)])
1355+
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)])
13561356
async def create_task(message: str):
13571357
job = await queue.pool.enqueue_job("sample_background_task", message)
13581358
return {"id": job.job_id}
@@ -1446,7 +1446,7 @@ curl -X POST 'http://127.0.0.1:8000/api/v1/tasks/task?message=test' \
14461446
```
14471447

14481448
> \[!TIP\]
1449-
> Since the `rate_limiter` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need.
1449+
> Since the `rate_limiter_dependency` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need.
14501450
14511451
To change a user's tier, you may just use the `PATCH api/v1/user/{username}/tier` endpoint.
14521452
Note that for flexibility (since this is a boilerplate), it's not necessary to previously inform a tier_id to create a user, but you probably should set every user to a certain tier (let's say `free`) once they are created.

src/app/api/dependencies.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..core.exceptions.http_exceptions import ForbiddenException, RateLimitException, UnauthorizedException
99
from ..core.logger import logging
1010
from ..core.security import oauth2_scheme, verify_token
11-
from ..core.utils.rate_limit import is_rate_limited
11+
from ..core.utils.rate_limit import rate_limiter
1212
from ..crud.crud_rate_limit import crud_rate_limits
1313
from ..crud.crud_tier import crud_tiers
1414
from ..crud.crud_users import crud_users
@@ -72,9 +72,12 @@ async def get_current_superuser(current_user: Annotated[dict, Depends(get_curren
7272
return current_user
7373

7474

75-
async def rate_limiter(
75+
async def rate_limiter_dependency(
7676
request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], user: User | None = Depends(get_optional_user)
7777
) -> None:
78+
if hasattr(request.app.state, "initialization_complete"):
79+
await request.app.state.initialization_complete.wait()
80+
7881
path = sanitize_path(request.url.path)
7982
if user:
8083
user_id = user["id"]
@@ -96,6 +99,6 @@ async def rate_limiter(
9699
user_id = request.client.host
97100
limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD
98101

99-
is_limited = await is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period)
102+
is_limited = await rate_limiter.is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period)
100103
if is_limited:
101104
raise RateLimitException("Rate limit exceeded.")

src/app/api/v1/rate_limits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ...api.dependencies import get_current_superuser
88
from ...core.db.database import async_get_db
9-
from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException, RateLimitException
9+
from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException
1010
from ...crud.crud_rate_limit import crud_rate_limits
1111
from ...crud.crud_tier import crud_tiers
1212
from ...schemas.rate_limit import RateLimitCreate, RateLimitCreateInternal, RateLimitRead, RateLimitUpdate

src/app/api/v1/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from arq.jobs import Job as ArqJob
44
from fastapi import APIRouter, Depends
55

6-
from ...api.dependencies import rate_limiter
6+
from ...api.dependencies import rate_limiter_dependency
77
from ...core.utils import queue
88
from ...schemas.job import Job
99

1010
router = APIRouter(prefix="/tasks", tags=["tasks"])
1111

1212

13-
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)])
13+
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)])
1414
async def create_task(message: str) -> dict[str, str]:
1515
"""Create a new background task.
1616

src/app/core/setup.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from fastapi.openapi.utils import get_openapi
1313

1414
from ..api.dependencies import get_current_superuser
15+
from ..core.utils.rate_limit import rate_limiter
1516
from ..middleware.client_cache_middleware import ClientCacheMiddleware
17+
from ..models import *
1618
from .config import (
1719
AppSettings,
1820
ClientSideCacheSettings,
@@ -24,9 +26,10 @@
2426
RedisRateLimiterSettings,
2527
settings,
2628
)
27-
from .db.database import Base, async_engine as engine
29+
from .db.database import Base
30+
from .db.database import async_engine as engine
2831
from .utils import cache, queue, rate_limit
29-
from ..models import *
32+
3033

3134
# -------------- database --------------
3235
async def create_tables() -> None:
@@ -55,8 +58,7 @@ async def close_redis_queue_pool() -> None:
5558

5659
# -------------- rate limit --------------
5760
async def create_redis_rate_limit_pool() -> None:
58-
rate_limit.pool = redis.ConnectionPool.from_url(settings.REDIS_RATE_LIMIT_URL)
59-
rate_limit.client = redis.Redis.from_pool(rate_limit.pool) # type: ignore
61+
rate_limiter.initialize(settings.REDIS_RATE_LIMIT_URL) # type: ignore
6062

6163

6264
async def close_redis_rate_limit_pool() -> None:
@@ -85,30 +87,36 @@ def lifespan_factory(
8587

8688
@asynccontextmanager
8789
async def lifespan(app: FastAPI) -> AsyncGenerator:
90+
from asyncio import Event
91+
92+
initialization_complete = Event()
93+
app.state.initialization_complete = initialization_complete
94+
8895
await set_threadpool_tokens()
8996

90-
if isinstance(settings, DatabaseSettings) and create_tables_on_start:
91-
await create_tables()
97+
try:
98+
if isinstance(settings, RedisCacheSettings):
99+
await create_redis_cache_pool()
92100

93-
if isinstance(settings, RedisCacheSettings):
94-
await create_redis_cache_pool()
101+
if isinstance(settings, RedisQueueSettings):
102+
await create_redis_queue_pool()
95103

96-
if isinstance(settings, RedisQueueSettings):
97-
await create_redis_queue_pool()
104+
if isinstance(settings, RedisRateLimiterSettings):
105+
await create_redis_rate_limit_pool()
98106

99-
if isinstance(settings, RedisRateLimiterSettings):
100-
await create_redis_rate_limit_pool()
107+
initialization_complete.set()
101108

102-
yield
109+
yield
103110

104-
if isinstance(settings, RedisCacheSettings):
105-
await close_redis_cache_pool()
111+
finally:
112+
if isinstance(settings, RedisCacheSettings):
113+
await close_redis_cache_pool()
106114

107-
if isinstance(settings, RedisQueueSettings):
108-
await close_redis_queue_pool()
115+
if isinstance(settings, RedisQueueSettings):
116+
await close_redis_queue_pool()
109117

110-
if isinstance(settings, RedisRateLimiterSettings):
111-
await close_redis_rate_limit_pool()
118+
if isinstance(settings, RedisRateLimiterSettings):
119+
await close_redis_rate_limit_pool()
112120

113121
return lifespan
114122

src/app/core/utils/rate_limit.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import UTC, datetime
2+
from typing import Optional
23

34
from redis.asyncio import ConnectionPool, Redis
45
from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,31 +9,53 @@
89

910
logger = logging.getLogger(__name__)
1011

11-
pool: ConnectionPool | None = None
12-
client: Redis | None = None
1312

13+
class RateLimiter:
14+
_instance: Optional["RateLimiter"] = None
15+
pool: Optional[ConnectionPool] = None
16+
client: Optional[Redis] = None
1417

15-
async def is_rate_limited(db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool:
16-
if client is None:
17-
logger.error("Redis client is not initialized.")
18-
raise Exception("Redis client is not initialized.")
18+
def __new__(cls):
19+
if cls._instance is None:
20+
cls._instance = super().__new__(cls)
21+
return cls._instance
1922

20-
current_timestamp = int(datetime.now(UTC).timestamp())
21-
window_start = current_timestamp - (current_timestamp % period)
23+
@classmethod
24+
def initialize(cls, redis_url: str) -> None:
25+
instance = cls()
26+
if instance.pool is None:
27+
instance.pool = ConnectionPool.from_url(redis_url)
28+
instance.client = Redis(connection_pool=instance.pool)
2229

23-
sanitized_path = sanitize_path(path)
24-
key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}"
30+
@classmethod
31+
def get_client(cls) -> Redis:
32+
instance = cls()
33+
if instance.client is None:
34+
logger.error("Redis client is not initialized.")
35+
raise Exception("Redis client is not initialized.")
36+
return instance.client
2537

26-
try:
27-
current_count = await client.incr(key)
28-
if current_count == 1:
29-
await client.expire(key, period)
38+
async def is_rate_limited(self, db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool:
39+
client = self.get_client()
40+
current_timestamp = int(datetime.now(UTC).timestamp())
41+
window_start = current_timestamp - (current_timestamp % period)
3042

31-
if current_count > limit:
32-
return True
43+
sanitized_path = sanitize_path(path)
44+
key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}"
3345

34-
except Exception as e:
35-
logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}")
36-
raise e
46+
try:
47+
current_count = await client.incr(key)
48+
if current_count == 1:
49+
await client.expire(key, period)
3750

38-
return False
51+
if current_count > limit:
52+
return True
53+
54+
except Exception as e:
55+
logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}")
56+
raise e
57+
58+
return False
59+
60+
61+
rate_limiter = RateLimiter()

0 commit comments

Comments
 (0)