Add rate limiting to FastAPI server

- Default rate limit: 2 req/sec
- Static files: 10 req/sec
- WebSocket endpoint: 1 req/5s
- Authenticate endpoint: 1 req/5s

Uses fastapi-limiter with Redis backend for rate limiting implementation.
This commit is contained in:
openhands
2024-11-12 21:13:04 +00:00
parent 59f7093428
commit ce9963db01

View File

@@ -7,7 +7,10 @@ import uuid
import warnings
import jwt
import redis.asyncio as redis
import requests
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
@@ -29,6 +32,7 @@ with warnings.catch_warnings():
from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
HTTPException,
Request,
@@ -73,7 +77,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
app = FastAPI()
app = FastAPI(dependencies=[Depends(RateLimiter(times=2, seconds=1))]) # Default 2 req/sec
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
@@ -81,6 +85,11 @@ app.add_middleware(
allow_headers=['*'],
)
@app.on_event("startup")
async def startup():
redis_instance = redis.from_url("redis://localhost", encoding="utf-8", decode_responses=True)
await FastAPILimiter.init(redis_instance)
app.add_middleware(NoCacheMiddleware)
@@ -251,6 +260,7 @@ async def attach_session(request: Request, call_next):
@app.websocket('/ws')
@RateLimiter(times=1, seconds=5) # 1 request per 5 seconds
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
@@ -861,6 +871,7 @@ def github_callback(auth_code: AuthCode):
@app.post('/api/authenticate')
@RateLimiter(times=1, seconds=5) # 1 request per 5 seconds
async def authenticate(request: Request):
token = request.headers.get('X-GitHub-Token')
if not await authenticate_github_user(token):
@@ -900,5 +911,12 @@ class SPAStaticFiles(StaticFiles):
# FIXME: just making this HTTPException doesn't work for some reason
return await super().get_response('index.html', scope)
async def __call__(self, scope, receive, send) -> None:
if scope["type"] == "http":
# Apply rate limiting
limiter = RateLimiter(times=10, seconds=1) # 10 requests per second
await limiter(scope, receive, send)
return await super().__call__(scope, receive, send)
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')