diff --git a/openhands/server/app.py b/openhands/server/app.py index 4e1a094317..5568ccd00f 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -10,13 +10,6 @@ from fastapi import ( import openhands.agenthub # noqa F401 (we import this to get the agents registered) from openhands import __version__ -from openhands.server.middleware import ( - AttachConversationMiddleware, - CacheControlMiddleware, - InMemoryRateLimiter, - LocalhostCORSMiddleware, - RateLimitMiddleware, -) from openhands.server.routes.conversation import app as conversation_api_router from openhands.server.routes.feedback import app as feedback_api_router from openhands.server.routes.files import app as files_api_router @@ -29,7 +22,6 @@ from openhands.server.routes.security import app as security_api_router from openhands.server.routes.settings import app as settings_router from openhands.server.routes.trajectory import app as trajectory_router from openhands.server.shared import openhands_config, session_manager -from openhands.utils.import_utils import get_impl @asynccontextmanager @@ -44,17 +36,7 @@ app = FastAPI( version=__version__, lifespan=_lifespan, ) -app.add_middleware( - LocalhostCORSMiddleware, - allow_credentials=True, - allow_methods=['*'], - allow_headers=['*'], -) - -app.add_middleware(CacheControlMiddleware) -app.add_middleware( - RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1) -) +openhands_config.attach_middleware(app) @app.get('/health') @@ -71,8 +53,3 @@ app.include_router(manage_conversation_api_router) app.include_router(settings_router) app.include_router(github_api_router) app.include_router(trajectory_router) - -AttachConversationMiddlewareImpl = get_impl( - AttachConversationMiddleware, openhands_config.attach_conversation_middleware_path -) -app.middleware('http')(AttachConversationMiddlewareImpl(app)) diff --git a/openhands/server/config/openhands_config.py b/openhands/server/config/openhands_config.py index c4d472e06b..228d263db2 100644 --- a/openhands/server/config/openhands_config.py +++ b/openhands/server/config/openhands_config.py @@ -1,8 +1,15 @@ import os -from fastapi import HTTPException +from fastapi import FastAPI, HTTPException from openhands.core.logger import openhands_logger as logger +from openhands.server.middleware import ( + AttachConversationMiddleware, + CacheControlMiddleware, + InMemoryRateLimiter, + LocalhostCORSMiddleware, + RateLimitMiddleware, +) from openhands.server.types import AppMode, OpenhandsConfigInterface from openhands.utils.import_utils import get_impl @@ -12,9 +19,6 @@ class OpenhandsConfig(OpenhandsConfigInterface): app_mode = AppMode.OSS posthog_client_key = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA' github_client_id = os.environ.get('GITHUB_APP_CLIENT_ID', '') - attach_conversation_middleware_path = ( - 'openhands.server.middleware.AttachConversationMiddleware' - ) settings_store_class: str = ( 'openhands.storage.settings.file_settings_store.FileSettingsStore' ) @@ -42,6 +46,21 @@ class OpenhandsConfig(OpenhandsConfigInterface): return config + def attach_middleware(self, api: FastAPI) -> None: + api.add_middleware( + LocalhostCORSMiddleware, + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + + api.add_middleware(CacheControlMiddleware) + api.add_middleware( + RateLimitMiddleware, + rate_limiter=InMemoryRateLimiter(requests=10, seconds=1), + ) + api.middleware('http')(AttachConversationMiddleware(api)) + def load_openhands_config(): config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None) diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index eb0e4ef3f2..cd5ceec1db 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -11,7 +11,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request as StarletteRequest from starlette.types import ASGIApp -from openhands.server.shared import session_manager +from openhands.server import shared from openhands.server.types import SessionMiddlewareInterface @@ -146,8 +146,8 @@ class AttachConversationMiddleware(SessionMiddlewareInterface): """ Attach the user's session based on the provided authentication token. """ - request.state.conversation = await session_manager.attach_to_conversation( - request.state.sid + request.state.conversation = ( + await shared.session_manager.attach_to_conversation(request.state.sid) ) if not request.state.conversation: return JSONResponse( @@ -160,7 +160,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface): """ Detach the user's session. """ - await session_manager.detach_from_conversation(request.state.conversation) + await shared.session_manager.detach_from_conversation( + request.state.conversation + ) async def __call__(self, request: Request, call_next: Callable): if not self._should_attach(request): diff --git a/openhands/server/types.py b/openhands/server/types.py index cbf9389d2b..d241cc8995 100644 --- a/openhands/server/types.py +++ b/openhands/server/types.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from enum import Enum from typing import ClassVar, Protocol +from fastapi import FastAPI + class AppMode(Enum): OSS = 'oss' @@ -36,6 +38,11 @@ class OpenhandsConfigInterface(ABC): """Configure attributes for frontend""" raise NotImplementedError + @abstractmethod + def attach_middleware(self, api: FastAPI) -> None: + """Attach required middleware for the current environment""" + raise NotImplementedError + class MissingSettingsError(ValueError): """Raised when settings are missing or not found."""