Fix typing in server directory (#8375)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
Graham Neubig
2025-05-13 17:27:59 -04:00
committed by GitHub
parent f9b0fcd76e
commit 154eed148f
15 changed files with 79 additions and 56 deletions

View File

@@ -1,14 +1,15 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable
from typing import Any
from urllib.parse import urlparse
from fastapi import Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import ASGIApp
from openhands.server import shared
@@ -22,7 +23,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
while using standard CORS rules for other origins.
"""
def __init__(self, app: ASGIApp, **kwargs) -> None:
def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
super().__init__(app, **kwargs)
def is_allowed_origin(self, origin: str) -> bool:
@@ -35,7 +36,8 @@ class LocalhostCORSMiddleware(CORSMiddleware):
return True
# For missing origin or other origins, use the parent class's logic
return super().is_allowed_origin(origin)
result: bool = super().is_allowed_origin(origin)
return result
class CacheControlMiddleware(BaseHTTPMiddleware):
@@ -43,7 +45,9 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
Middleware to disable caching for all routes by adding appropriate headers
"""
async def dispatch(self, request, call_next):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
if request.url.path.startswith('/assets'):
# The content of the assets directory has fingerprinted file names so we cache aggressively
@@ -58,7 +62,7 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
class InMemoryRateLimiter:
history: dict
history: dict[str, list[datetime]]
requests: int
seconds: int
sleep_seconds: int
@@ -100,7 +104,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
super().__init__(app)
self.rate_limiter = rate_limiter
async def dispatch(self, request: StarletteRequest, call_next):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if not self.is_rate_limited_request(request):
return await call_next(request)
ok = await self.rate_limiter(request)
@@ -112,7 +118,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
)
return await call_next(request)
def is_rate_limited_request(self, request: StarletteRequest):
def is_rate_limited_request(self, request: StarletteRequest) -> bool:
if request.url.path.startswith('/assets'):
return False
# Put Other non rate limited checks here
@@ -120,10 +126,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
class AttachConversationMiddleware(SessionMiddlewareInterface):
def __init__(self, app):
def __init__(self, app: ASGIApp) -> None:
self.app = app
def _should_attach(self, request) -> bool:
def _should_attach(self, request: Request) -> bool:
"""
Determine if the middleware should attach a session for the given request.
"""
@@ -168,7 +174,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
request.state.conversation
)
async def __call__(self, request: Request, call_next: Callable):
async def __call__(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if not self._should_attach(request):
return await call_next(request)