mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
178 lines
6.0 KiB
Python
178 lines
6.0 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
from typing import Callable
|
|
from urllib.parse import urlparse
|
|
|
|
import jwt
|
|
from fastapi import APIRouter, Request, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.server.auth import get_sid_from_token
|
|
from openhands.server.github_utils import UserVerifier
|
|
from openhands.server.shared import config, session_manager
|
|
|
|
|
|
class LocalhostCORSMiddleware(CORSMiddleware):
    """
    Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
    while using standard CORS rules for other origins.
    """

    def __init__(self, app: ASGIApp, **kwargs) -> None:
        # All configuration (allow_origins, allow_methods, ...) is forwarded
        # unchanged to the standard CORSMiddleware.
        super().__init__(app, **kwargs)

    def is_allowed_origin(self, origin: str) -> bool:
        """Return True if ``origin`` may make cross-origin requests.

        Any origin whose hostname is ``localhost`` or ``127.0.0.1`` is allowed
        regardless of scheme or port. Every other origin, including an empty
        one or one that cannot be parsed, is judged by the parent class's
        configured rules.
        """
        if origin:
            try:
                hostname = urlparse(origin).hostname or ''
            except ValueError:
                # The Origin header is client-controlled. A malformed value
                # (e.g. an unbalanced IPv6 bracket like ``http://[::1``) makes
                # urllib raise. Such a value is certainly not localhost, so fall
                # back to the standard rules instead of failing the request.
                hostname = ''
            # Allow any localhost/127.0.0.1 origin regardless of port
            if hostname in ('localhost', '127.0.0.1'):
                return True

        # For missing origin or other origins, use the parent class's logic
        return super().is_allowed_origin(origin)
|
|
|
|
|
|
class NoCacheMiddleware(BaseHTTPMiddleware):
    """
    Middleware that marks responses as non-cacheable.

    Responses for request paths starting with ``/assets`` are returned
    untouched (presumably static build assets, which are fine to cache).
    Every other response gets ``Cache-Control``, ``Pragma`` and ``Expires``
    headers that tell clients and intermediaries not to cache it.
    """

    async def dispatch(self, request, call_next):
        result = await call_next(request)

        # Guard clause: leave asset responses exactly as the app produced them.
        if request.url.path.startswith('/assets'):
            return result

        anti_cache_headers = {
            'Cache-Control': 'no-cache, no-store, must-revalidate, max-age=0',
            'Pragma': 'no-cache',
            'Expires': '0',
        }
        for header_name, header_value in anti_cache_headers.items():
            result.headers[header_name] = header_value
        return result
|
|
|
|
|
|
class InMemoryRateLimiter:
    """Per-client sliding-window rate limiter stored in process memory.

    Calls are keyed by the client host. Within a sliding window of ``seconds``:

    * up to ``requests`` calls are allowed immediately;
    * calls beyond ``requests`` and up to ``2 * requests`` are delayed by
      ``sleep_seconds`` and then allowed, or rejected outright when
      ``sleep_seconds`` is 0;
    * calls beyond ``2 * requests`` are rejected.

    Every call, including a rejected one, is recorded in the window.
    """

    history: dict
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        # Maps client key -> timestamps of that client's calls in the window.
        self.history = defaultdict(list)
        # When idle clients were last swept out of ``history``.
        self._last_prune = datetime.now()

    def _clean_old_requests(self, key: str) -> None:
        """Drop ``key``'s timestamps that have fallen out of the window."""
        now = datetime.now()
        cutoff = now - timedelta(seconds=self.seconds)
        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]

    def _prune_idle_clients(self, now: datetime) -> None:
        """Remove clients that made no calls within the current window.

        Without this, every distinct client host ever seen keeps an entry in
        ``history`` forever. The sweep runs at most once per window, so its
        O(number of clients) cost is amortized.
        """
        window = timedelta(seconds=self.seconds)
        if now - self._last_prune < window:
            return
        self._last_prune = now
        cutoff = now - window
        idle = [
            key
            for key, stamps in self.history.items()
            if not stamps or max(stamps) <= cutoff
        ]
        for key in idle:
            del self.history[key]

    async def __call__(self, request: Request) -> bool:
        """Record this call and return whether it may proceed.

        Returns:
            True if the request is within limits (possibly after sleeping),
            False if it should be rejected.
        """
        # ``request.client`` is None when the server does not report a peer
        # address. Bucket those requests together instead of raising.
        client = request.client
        key = client.host if client is not None else 'unknown'
        now = datetime.now()

        self._prune_idle_clients(now)
        self._clean_old_requests(key)
        self.history[key].append(now)

        count = len(self.history[key])
        if count > self.requests * 2:
            return False
        if count > self.requests:
            # Soft-limit zone: slow the client down, or reject if delays are disabled.
            if self.sleep_seconds > 0:
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False
        return True
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Short-circuit over-limit requests with an HTTP 429 JSON response.

    ``rate_limiter`` is awaited with each request. A falsy result returns 429
    without calling the downstream app; otherwise the request proceeds.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        super().__init__(app)
        self.rate_limiter = rate_limiter

    def _retry_after(self) -> str:
        """Return the ``Retry-After`` value: the limiter window in whole seconds, at least 1."""
        import math

        # The limiter counts calls over a ``seconds``-long window (and records
        # rejected calls too). Advertising a fixed 1s for a longer window would
        # invite retries that are still inside it. Fall back to 1s for limiters
        # that expose no window.
        window = getattr(self.rate_limiter, 'seconds', 1)
        return str(max(1, math.ceil(window)))

    async def dispatch(self, request, call_next):
        ok = await self.rate_limiter(request)
        if not ok:
            return JSONResponse(
                status_code=429,
                content={'message': 'Too many requests'},
                headers={'Retry-After': self._retry_after()},
            )
        return await call_next(request)
|
|
|
|
|
|
class AttachSessionMiddleware:
    """Attach the caller's conversation to ``request.state`` for a router's routes.

    Applies to requests whose path is one of ``target_router``'s route paths,
    except ``OPTIONS`` requests. For those requests:

    1. If ``UserVerifier`` is active, a ``github_auth`` cookie holding an HS256
       JWT signed with ``config.jwt_secret`` is required (401 otherwise).
    2. A session id is derived from the ``Authorization`` header token via
       ``get_sid_from_token`` and stored in ``request.state.sid`` (401 if the
       header is missing or the result is empty).
    3. The conversation is attached via ``session_manager`` and stored in
       ``request.state.conversation`` (404 if none). It is detached again
       after the downstream handler finishes, even if that handler raised.

    All other requests pass through untouched.
    """

    def __init__(self, app, target_router: APIRouter):
        self.app = app
        self.target_router = target_router
        # Compared by exact string equality with ``request.url.path``.
        # NOTE(review): templated routes (e.g. ``/items/{id}``) store the
        # template here and so can never match a concrete URL; confirm intended.
        self.target_paths = {route.path for route in target_router.routes}

    def _should_attach(self, request: Request) -> bool:
        """Return True for non-preflight requests to one of the target paths."""
        return request.method != 'OPTIONS' and request.url.path in self.target_paths

    def _github_auth_error(self, request: Request):
        """Validate the ``github_auth`` cookie when GitHub user verification is on.

        Returns:
            A 401 ``JSONResponse`` if verification is active and the cookie is
            missing or fails JWT validation; otherwise None.
        """
        if not UserVerifier().is_active():
            return None

        signed_token = request.cookies.get('github_auth')
        if not signed_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={'error': 'Not authenticated'},
            )
        try:
            jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
        except Exception as e:
            logger.warning(f'Invalid token: {e}')
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={'error': 'Invalid token'},
            )
        return None

    @staticmethod
    def _extract_token(authorization: str) -> str:
        """Return the credentials part of an ``Authorization`` header value.

        A leading ``Bearer`` scheme is stripped. The scheme name is matched
        case-insensitively, as HTTP auth schemes are. Any other value is
        returned unchanged.
        """
        # Split on the first whitespace only, so credentials that happen to
        # contain the text "Bearer" are never truncated.
        scheme, _, credentials = authorization.strip().partition(' ')
        if scheme.lower() == 'bearer':
            return credentials.strip()
        return authorization

    async def __call__(self, request: Request, call_next: Callable):
        if not self._should_attach(request):
            return await call_next(request)

        auth_error = self._github_auth_error(request)
        if auth_error is not None:
            return auth_error

        authorization = request.headers.get('Authorization')
        if not authorization:
            logger.warning('Missing Authorization header')
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={'error': 'Missing Authorization header'},
            )
        auth_token = self._extract_token(authorization)

        request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
        if request.state.sid == '':
            logger.warning('Invalid token')
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={'error': 'Invalid token'},
            )

        request.state.conversation = await session_manager.attach_to_conversation(
            request.state.sid
        )
        if request.state.conversation is None:
            return JSONResponse(
                status_code=status.HTTP_404_NOT_FOUND,
                content={'error': 'Session not found'},
            )
        try:
            return await call_next(request)
        finally:
            # Always release the conversation, even if the handler raised.
            await session_manager.detach_from_conversation(request.state.conversation)
|