OpenHands/enterprise/server/rate_limit.py
2025-09-04 15:44:54 -04:00

138 lines
4.6 KiB
Python

"""
Usage:
Call setup_rate_limit_handler on your FastAPI app to add the exception handler
Create a rate limiter like:
`rate_limiter = create_redis_rate_limiter("10/second; 100/minute")`
Await hit() with a namespace and a key, and allow the RateLimitException to propagate:
`await rate_limiter.hit('some action', user_id)`
"""
import time
from dataclasses import dataclass
import limits
from fastapi.responses import JSONResponse
from starlette.applications import Request, Response, Starlette
from starlette.exceptions import HTTPException
from storage.redis import get_redis_authed_url
from openhands.core.logger import openhands_logger as logger
def setup_rate_limit_handler(app: Starlette):
    """
    Register the rate limit exception handler on the given application.

    Binds RateLimitException to _rate_limit_exceeded_handler through the
    application's add_exception_handler hook.
    """
    handled_exception = RateLimitException
    app.add_exception_handler(handled_exception, _rate_limit_exceeded_handler)
@dataclass
class RateLimitResult:
    """Outcome of a rate limit check; all time values are in seconds."""

    # Text form of the limit that was checked (sent as X-RateLimit-Limit)
    description: str
    # Sent as X-RateLimit-Remaining
    remaining: int
    # Sent as X-RateLimit-Reset
    reset_time: int
    # Sent as Retry-After; the header is omitted when this is None
    retry_after: int | None = None

    def add_headers(self, response: Response) -> None:
        """Write the rate limit headers for this result onto a response."""
        rate_limit_headers = {
            'X-RateLimit-Limit': self.description,
            'X-RateLimit-Remaining': str(self.remaining),
            'X-RateLimit-Reset': str(self.reset_time),
        }
        if self.retry_after is not None:
            rate_limit_headers['Retry-After'] = str(self.retry_after)
        # Assign one header at a time so only item assignment is required
        # of the response's headers object.
        for header_name, header_value in rate_limit_headers.items():
            response.headers[header_name] = header_value
class RateLimiter:
    """
    Checks a namespace/key pair against one or more rate limit windows.

    The windows are parsed once at construction time; every call to hit()
    records a hit against each of them through the configured strategy.
    """

    strategy: limits.aio.strategies.RateLimiter
    limit_items: list[limits.RateLimitItem]

    def __init__(self, strategy: limits.aio.strategies.RateLimiter, windows: str):
        """
        Args:
            strategy: async strategy used to record hits and read window stats.
            windows: limit specification handed to limits.parse_many,
                e.g. "10/second; 100/minute".
        """
        self.strategy = strategy
        self.limit_items = limits.parse_many(windows)

    async def hit(self, namespace: str, key: str):
        """
        Record a hit for namespace/key against every configured window.

        Raises:
            RateLimitException: when a window rejects the hit and its stats
                could be looked up.

        This check fails open: if the strategy raises, the error is logged and
        the hit is treated as allowed. If a window rejects the hit but the
        stats lookup raises, that error is logged and swallowed as well, and
        the loop moves on to the next window.
        """
        for lim in self.limit_items:
            # Default to allowed so a backend failure does not block the caller.
            allowed = True
            try:
                allowed = await self.strategy.hit(lim, namespace, key)
            except Exception:
                logger.exception('Rate limit check could not complete, redis issue?')
            if not allowed:
                logger.info(f'Rate limit hit for {namespace}:{key}')
                try:
                    result = await self._get_stats_as_result(lim, namespace, key)
                except Exception:
                    logger.exception(
                        'Rate limit exceeded but window lookup failed, swallowing'
                    )
                else:
                    raise RateLimitException(result)

    async def _get_stats_as_result(
        self, lim: limits.RateLimitItem, namespace: str, key: str
    ) -> RateLimitResult:
        """
        Look up the window stats for a limit and package them as a
        RateLimitResult with the data needed for response headers.

        retry_after is only populated when no hits remain in the window, and
        it is never negative.
        """
        stats: limits.WindowStats = await self.strategy.get_window_stats(
            lim, namespace, key
        )
        retry_after = None
        if stats.remaining == 0:
            # The stats are read after the failed hit, so reset_time may
            # already be in the past; clamp so Retry-After is never negative.
            retry_after = max(0, int(stats.reset_time - time.time()))
        return RateLimitResult(
            description=str(lim),
            remaining=stats.remaining,
            reset_time=int(stats.reset_time),
            retry_after=retry_after,
        )
def create_redis_rate_limiter(windows: str) -> RateLimiter:
    """
    Build a RateLimiter that uses Redis storage and the "Fixed Window" strategy.

    windows is a limit specification such as "10/second; 100/minute".
    """
    storage_uri = f'async+{get_redis_authed_url()}'
    redis_storage = limits.aio.storage.RedisStorage(storage_uri)
    fixed_window = limits.aio.strategies.FixedWindowRateLimiter(redis_storage)
    return RateLimiter(fixed_window, windows)
class RateLimitException(HTTPException):
    """
    HTTP 429 error raised when a rate limit is hit.

    Keeps the RateLimitResult for the exceeded limit on the `result`
    attribute and uses its description as the HTTP error detail.
    """

    result: RateLimitResult

    def __init__(self, result: RateLimitResult) -> None:
        self.result = result
        limit_description = result.description
        super().__init__(status_code=429, detail=limit_description)
def _rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
    """
    Turn a rate limit error into a 429 JSON response.

    For a RateLimitException the body names the limit that was hit and, when
    the exception carries a result, the rate limit headers are added to the
    response. Any other exception type gets a generic 429 body.
    """
    logger.info(exc.__class__.__name__)

    if not isinstance(exc, RateLimitException):
        # Shouldn't happen, this handler is only bound to RateLimitException
        return JSONResponse({'error': 'Rate limit exceeded'}, status_code=429)

    payload = {'error': f'Rate limit exceeded: { exc.detail}'}
    response = JSONResponse(payload, status_code=429)
    if exc.result:
        exc.result.add_headers(response)
    return response