mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
"""
|
|
Rate limiting helpers built on the ``limits`` library.

Usage:

Call setup_rate_limit_handler on your FastAPI app to add the exception handler.

Create a rate limiter like:

`rate_limiter = create_redis_rate_limiter("10/second; 100/minute")`

Await hit() with a namespace and a key, and allow the RateLimitException to propagate
(hit() is a coroutine, so it must be awaited):

`await rate_limiter.hit('some action', user_id)`
|
|
"""
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
|
|
import limits
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.applications import Request, Response, Starlette
|
|
from starlette.exceptions import HTTPException
|
|
from storage.redis import get_redis_authed_url
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
|
|
|
|
def setup_rate_limit_handler(app: Starlette) -> None:
    """
    Register the handler that turns RateLimitException into a 429 JSON response.

    Call this once on the FastAPI/Starlette app. Afterwards a RateLimitException
    raised while handling a request is rendered by _rate_limit_exceeded_handler,
    which attaches the rate-limit headers from the exception's result.
    """
    app.add_exception_handler(RateLimitException, _rate_limit_exceeded_handler)
|
|
|
|
|
|
@dataclass
class RateLimitResult:
    """
    Outcome of a rate-limit lookup; all times are expressed in seconds.

    Fields:
        description: human-readable limit, sent as X-RateLimit-Limit.
        remaining: hits left in the current window.
        reset_time: reset timestamp of the window (presumably epoch seconds;
            it is whatever the caller passes in).
        retry_after: seconds to wait before retrying, or None to omit the
            Retry-After header.
    """

    description: str
    remaining: int
    reset_time: int
    retry_after: int | None = None

    def add_headers(self, response: Response) -> None:
        """Copy this result onto ``response`` as rate-limit headers."""
        headers = response.headers
        for header_name, header_value in (
            ('X-RateLimit-Limit', self.description),
            ('X-RateLimit-Remaining', str(self.remaining)),
            ('X-RateLimit-Reset', str(self.reset_time)),
        ):
            headers[header_name] = header_value
        if self.retry_after is None:
            return
        headers['Retry-After'] = str(self.retry_after)
|
|
|
|
|
|
class RateLimiter:
    """
    Apply one or more rate-limit windows to (namespace, key) pairs.

    Wraps an async ``limits`` strategy together with the parsed limit items.
    ``hit`` checks each configured window in order and raises at the first one
    that is exceeded.
    """

    # Async ``limits`` strategy that records hits and reports window stats.
    strategy: limits.aio.strategies.RateLimiter
    # Parsed windows, e.g. "10/second; 100/minute" -> two RateLimitItems.
    limit_items: list[limits.RateLimitItem]

    def __init__(self, strategy: limits.aio.strategies.RateLimiter, windows: str):
        """
        Args:
            strategy: async ``limits`` strategy used to record hits and read
                window stats.
            windows: limit string(s) accepted by ``limits.parse_many``,
                e.g. ``"10/second; 100/minute"``.
        """
        self.strategy = strategy
        self.limit_items = limits.parse_many(windows)

    async def hit(self, namespace: str, key: str):
        """
        Record a hit for ``namespace``/``key`` against each configured window.

        Raises:
            RateLimitException: when a window's limit is exceeded and its stats
                could be looked up.

        This deliberately fails open. If the strategy call itself raises (e.g. a
        storage problem), the error is logged and the hit is treated as allowed.
        If a limit is exceeded but the follow-up stats lookup fails, that is
        logged and swallowed as well, and the next window is checked.
        """
        for lim in self.limit_items:
            allowed = True
            try:
                allowed = await self.strategy.hit(lim, namespace, key)
            except Exception:
                # Fail open: a storage outage must not block every request.
                logger.exception('Rate limit check could not complete, redis issue?')
            if allowed:
                continue
            logger.info('Rate limit hit for %s:%s', namespace, key)
            try:
                result = await self._get_stats_as_result(lim, namespace, key)
            except Exception:
                logger.exception(
                    'Rate limit exceeded but window lookup failed, swallowing'
                )
            else:
                raise RateLimitException(result)

    async def _get_stats_as_result(
        self, lim: limits.RateLimitItem, namespace: str, key: str
    ) -> RateLimitResult:
        """
        Look up the window stats for ``lim`` and build the RateLimitResult used
        for the response headers.

        ``retry_after`` is only set once the window is exhausted. It is rounded
        up to whole seconds and clamped at 0. Truncating would report 0 while up
        to a second still remains, so clients would retry too early. A reset
        time already in the past would produce a negative Retry-After, which is
        not a valid header value.
        """
        import math  # only needed for the Retry-After rounding below

        stats: limits.WindowStats = await self.strategy.get_window_stats(
            lim, namespace, key
        )
        retry_after = None
        if stats.remaining == 0:
            retry_after = max(0, math.ceil(stats.reset_time - time.time()))
        return RateLimitResult(
            description=str(lim),
            remaining=stats.remaining,
            reset_time=int(stats.reset_time),
            retry_after=retry_after,
        )
|
|
|
|
|
|
def create_redis_rate_limiter(windows: str) -> RateLimiter:
    """
    Build a RateLimiter that counts hits in Redis using a fixed-window strategy.

    ``windows`` uses the ``limits`` string syntax, e.g. "10/second; 100/minute".
    The authenticated Redis URL gets the ``async+`` scheme prefix that the
    asyncio storage backends of ``limits`` expect.
    """
    storage_uri = f'async+{get_redis_authed_url()}'
    fixed_window = limits.aio.strategies.FixedWindowRateLimiter(
        limits.aio.storage.RedisStorage(storage_uri)
    )
    return RateLimiter(strategy=fixed_window, windows=windows)
|
|
|
|
|
|
class RateLimitException(HTTPException):
    """
    HTTP 429 error raised when a rate limit is exceeded.

    The ``result`` attribute carries the limit details, which the exception
    handler turns into response headers. The limit description doubles as the
    HTTP ``detail``.
    """

    result: RateLimitResult

    def __init__(self, result: RateLimitResult) -> None:
        self.result = result
        super().__init__(status_code=429, detail=result.description)
|
|
|
|
|
|
def _rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
    """
    Render a 429 JSON response describing the rate limit that was exceeded.

    When ``exc`` is a RateLimitException carrying a result, that result's
    rate-limit headers (and Retry-After, if set) are attached to the response.
    """
    logger.info(exc.__class__.__name__)
    if not isinstance(exc, RateLimitException):
        # Defensive only: setup_rate_limit_handler binds this handler to
        # RateLimitException, so this branch is not expected to run.
        return JSONResponse({'error': 'Rate limit exceeded'}, status_code=429)
    response = JSONResponse(
        {'error': f'Rate limit exceeded: {exc.detail}'}, status_code=429
    )
    if exc.result:
        exc.result.add_headers(response)
    return response
|