From c2a0e525de5fc6f6dc70a8a31e3799af37708e16 Mon Sep 17 00:00:00 2001 From: tofarr Date: Tue, 3 Jun 2025 17:36:45 -0600 Subject: [PATCH] Now using Dependency Injection to associate conversations with requests (#8863) --- openhands/server/listen.py | 2 - openhands/server/middleware.py | 75 +------------------------ openhands/server/routes/conversation.py | 43 +++++--------- openhands/server/routes/feedback.py | 8 ++- openhands/server/routes/files.py | 37 ++++++------ openhands/server/routes/security.py | 9 ++- openhands/server/routes/trajectory.py | 8 ++- openhands/server/utils.py | 19 ++++++- 8 files changed, 65 insertions(+), 136 deletions(-) diff --git a/openhands/server/listen.py b/openhands/server/listen.py index c4da664897..e8ef7ce2ba 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -5,7 +5,6 @@ import socketio from openhands.server.app import app as base_app from openhands.server.listen_socket import sio from openhands.server.middleware import ( - AttachConversationMiddleware, CacheControlMiddleware, InMemoryRateLimiter, LocalhostCORSMiddleware, @@ -24,6 +23,5 @@ base_app.add_middleware( RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1), ) -base_app.middleware('http')(AttachConversationMiddleware(base_app)) app = socketio.ASGIApp(sio, other_asgi_app=base_app) diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index d1e6d22ff2..cc7f09b024 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -4,7 +4,7 @@ from collections import defaultdict from datetime import datetime, timedelta from urllib.parse import urlparse -from fastapi import Request, status +from fastapi import Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint @@ -12,10 +12,6 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import Response from starlette.types import ASGIApp -from openhands.server import shared -from openhands.server.types import SessionMiddlewareInterface -from openhands.server.user_auth import get_user_id - class LocalhostCORSMiddleware(CORSMiddleware): """ @@ -136,72 +132,3 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return False # Put Other non rate limited checks here return True - - -class AttachConversationMiddleware(SessionMiddlewareInterface): - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def _should_attach(self, request: Request) -> bool: - """ - Determine if the middleware should attach a session for the given request. - """ - if request.method == 'OPTIONS': - return False - - conversation_id = '' - if request.url.path.startswith('/api/conversation'): - # FIXME: we should be able to use path_params - path_parts = request.url.path.split('/') - if len(path_parts) > 4: - conversation_id = request.url.path.split('/')[3] - if not conversation_id: - return False - - request.state.sid = conversation_id - - return True - - async def _attach_conversation(self, request: Request) -> JSONResponse | None: - """ - Attach the user's session based on the provided authentication token. - """ - user_id = await get_user_id(request) - request.state.conversation = ( - await shared.conversation_manager.attach_to_conversation( - request.state.sid, user_id - ) - ) - if not request.state.conversation: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Session not found'}, - ) - return None - - async def _detach_session(self, request: Request) -> None: - """ - Detach the user's session. - """ - await shared.conversation_manager.detach_from_conversation( - request.state.conversation - ) - - async def __call__( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - if not self._should_attach(request): - return await call_next(request) - - response = await self._attach_conversation(request) - if response: - return response - - try: - # Continue processing the request - response = await call_next(request) - finally: - # Ensure the session is detached - await self._detach_session(request) - - return response diff --git a/openhands/server/routes/conversation.py b/openhands/server/routes/conversation.py index 5905c64fa8..5993c83446 100644 --- a/openhands/server/routes/conversation.py +++ b/openhands/server/routes/conversation.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import JSONResponse from openhands.core.logger import openhands_logger as logger @@ -6,18 +6,20 @@ from openhands.events.event_filter import EventFilter from openhands.events.serialization.event import event_to_dict from openhands.runtime.base import Runtime from openhands.server.dependencies import get_dependencies +from openhands.server.session.conversation import ServerConversation from openhands.server.shared import conversation_manager +from openhands.server.utils import get_conversation app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()) @app.get('/config') -async def get_remote_runtime_config(request: Request) -> JSONResponse: +async def get_remote_runtime_config(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse: """Retrieve the runtime configuration. Currently, this is the session ID and runtime ID (if available). """ - runtime = request.state.conversation.runtime + runtime = conversation.runtime runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None session_id = runtime.sid if hasattr(runtime, 'sid') else None return JSONResponse( @@ -29,7 +31,7 @@ async def get_remote_runtime_config(request: Request) -> JSONResponse: @app.get('/vscode-url') -async def get_vscode_url(request: Request) -> JSONResponse: +async def get_vscode_url(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse: """Get the VSCode URL. This endpoint allows getting the VSCode URL. @@ -41,7 +43,7 @@ async def get_vscode_url(request: Request) -> JSONResponse: JSONResponse: A JSON response indicating the success of the operation. """ try: - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime logger.debug(f'Runtime type: {type(runtime)}') logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}') return JSONResponse( @@ -59,7 +61,7 @@ async def get_vscode_url(request: Request) -> JSONResponse: @app.get('/web-hosts') -async def get_hosts(request: Request) -> JSONResponse: +async def get_hosts(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse: """Get the hosts used by the runtime. This endpoint allows getting the hosts used by the runtime. @@ -71,18 +73,7 @@ async def get_hosts(request: Request) -> JSONResponse: JSONResponse: A JSON response indicating the success of the operation. """ try: - if not hasattr(request.state, 'conversation'): - return JSONResponse( - status_code=500, - content={'error': 'No conversation found in request state'}, - ) - - if not hasattr(request.state.conversation, 'runtime'): - return JSONResponse( - status_code=500, content={'error': 'No runtime found in conversation'} - ) - - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime logger.debug(f'Runtime type: {type(runtime)}') logger.debug(f'Runtime hosts: {runtime.web_hosts}') return JSONResponse(status_code=200, content={'hosts': runtime.web_hosts}) @@ -99,12 +90,12 @@ async def get_hosts(request: Request) -> JSONResponse: @app.get('/events') async def search_events( - request: Request, start_id: int = 0, end_id: int | None = None, reverse: bool = False, filter: EventFilter | None = None, limit: int = 20, + conversation: ServerConversation = Depends(get_conversation), ): """Search through the event stream with filtering and pagination. Args: @@ -122,17 +113,13 @@ async def search_events( HTTPException: If conversation is not found ValueError: If limit is less than 1 or greater than 100 """ - if not request.state.conversation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail='ServerConversation not found' - ) if limit < 0 or limit > 100: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid limit' ) # Get matching events from the stream - event_stream = request.state.conversation.event_stream + event_stream = conversation.event_stream events = list( event_stream.search_events( start_id=start_id, @@ -148,15 +135,15 @@ async def search_events( if has_more: events = events[:limit] # Remove the extra event - events = [event_to_dict(event) for event in events] + events_json = [event_to_dict(event) for event in events] return { - 'events': events, + 'events': events_json, 'has_more': has_more, } @app.post('/events') -async def add_event(request: Request): +async def add_event(request: Request, conversation: ServerConversation = Depends(get_conversation)): data = request.json() - conversation_manager.send_to_event_stream(request.state.sid, data) + conversation_manager.send_to_event_stream(conversation.sid, data) return JSONResponse({'success': True}) diff --git a/openhands/server/routes/feedback.py b/openhands/server/routes/feedback.py index c042b609fd..88d9e01dac 100644 --- a/openhands/server/routes/feedback.py +++ b/openhands/server/routes/feedback.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Request, status +from fastapi import APIRouter, Depends, Request, status from fastapi.responses import JSONResponse from openhands.core.logger import openhands_logger as logger @@ -6,13 +6,15 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper from openhands.events.serialization import event_to_dict from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback from openhands.server.dependencies import get_dependencies +from openhands.server.utils import get_conversation from openhands.utils.async_utils import call_sync_from_async +from openhands.server.session.conversation import ServerConversation app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()) @app.post('/submit-feedback') -async def submit_feedback(request: Request, conversation_id: str) -> JSONResponse: +async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse: """Submit user feedback. This function stores the provided feedback data. @@ -36,7 +38,7 @@ async def submit_feedback(request: Request, conversation_id: str) -> JSONRespons # and there is a function to handle the storage. body = await request.json() async_store = AsyncEventStoreWrapper( - request.state.conversation.event_stream, filter_hidden=True + conversation.event_stream, filter_hidden=True ) trajectory = [] async for event in async_store: diff --git a/openhands/server/routes/files.py b/openhands/server/routes/files.py index 7771a895fa..d6ae4dd161 100644 --- a/openhands/server/routes/files.py +++ b/openhands/server/routes/files.py @@ -32,9 +32,10 @@ from openhands.server.shared import ( config, ) from openhands.server.user_auth import get_user_id -from openhands.server.utils import get_conversation_store +from openhands.server.utils import get_conversation, get_conversation_store from openhands.storage.conversation.conversation_store import ConversationStore from openhands.utils.async_utils import call_sync_from_async +from openhands.server.session.conversation import ServerConversation app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()) @@ -48,7 +49,8 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_ }, ) async def list_files( - request: Request, path: str | None = None + conversation: ServerConversation = Depends(get_conversation), + path: str | None = None ) -> list[str] | JSONResponse: """List files in the specified path. @@ -70,13 +72,13 @@ async def list_files( Raises: HTTPException: If there's an error listing the files. """ - if not request.state.conversation.runtime: + if not conversation.runtime: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, content={'error': 'Runtime not yet initialized'}, ) - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime try: file_list = await call_sync_from_async(runtime.list_files, path) except AgentRuntimeUnavailableError as e: @@ -130,7 +132,7 @@ async def list_files( 415: {'description': 'Unsupported media type', 'model': dict}, }, ) -async def select_file(file: str, request: Request) -> FileResponse | JSONResponse: +async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse: """Retrieve the content of a specified file. To select a file: @@ -149,7 +151,7 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons Raises: HTTPException: If there's an error opening the file. """ - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) read_action = FileReadAction(file) @@ -194,10 +196,10 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons 500: {'description': 'Error zipping workspace', 'model': dict}, }, ) -def zip_current_workspace(request: Request) -> FileResponse | JSONResponse: +def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse: try: logger.debug('Zipping workspace') - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime path = runtime.config.workspace_mount_path_in_sandbox try: zip_file_path = runtime.copy_from(path) @@ -230,19 +232,15 @@ def zip_current_workspace(request: Request) -> FileResponse | JSONResponse: }, ) async def git_changes( - request: Request, - conversation_id: str, + conversation: ServerConversation = Depends(get_conversation), + conversation_store: ConversationStore = Depends(get_conversation_store), user_id: str = Depends(get_user_id), ) -> list[dict[str, str]] | JSONResponse: - runtime: Runtime = request.state.conversation.runtime - conversation_store = await ConversationStoreImpl.get_instance( - config, - user_id, - ) + runtime: Runtime = conversation.runtime cwd = await get_cwd( conversation_store, - conversation_id, + conversation.sid, runtime.config.workspace_mount_path_in_sandbox, ) logger.info(f'Getting git changes in {cwd}') @@ -275,16 +273,15 @@ async def git_changes( responses={500: {'description': 'Error getting diff', 'model': dict}}, ) async def git_diff( - request: Request, path: str, - conversation_id: str, conversation_store: Any = Depends(get_conversation_store), + conversation: ServerConversation = Depends(get_conversation), ) -> dict[str, Any] | JSONResponse: - runtime: Runtime = request.state.conversation.runtime + runtime: Runtime = conversation.runtime cwd = await get_cwd( conversation_store, - conversation_id, + conversation.sid, runtime.config.workspace_mount_path_in_sandbox, ) diff --git a/openhands/server/routes/security.py b/openhands/server/routes/security.py index 035cb2f729..7cc16bf241 100644 --- a/openhands/server/routes/security.py +++ b/openhands/server/routes/security.py @@ -1,5 +1,6 @@ from fastapi import ( APIRouter, + Depends, HTTPException, Request, Response, @@ -7,12 +8,14 @@ from fastapi import ( ) from openhands.server.dependencies import get_dependencies +from openhands.server.utils import get_conversation +from openhands.server.session.conversation import ServerConversation app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()) @app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE']) -async def security_api(request: Request) -> Response: +async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response: """Catch-all route for security analyzer API requests. Each request is handled directly to the security analyzer. @@ -26,12 +29,12 @@ async def security_api(request: Request) -> Response: Raises: HTTPException: If the security analyzer is not initialized. """ - if not request.state.conversation.security_analyzer: + if not conversation.security_analyzer: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail='Security analyzer not initialized', ) - return await request.state.conversation.security_analyzer.handle_api_request( + return await conversation.security_analyzer.handle_api_request( request ) diff --git a/openhands/server/routes/trajectory.py b/openhands/server/routes/trajectory.py index 4a070daed7..3f3015e868 100644 --- a/openhands/server/routes/trajectory.py +++ b/openhands/server/routes/trajectory.py @@ -1,16 +1,18 @@ -from fastapi import APIRouter, Request, status +from fastapi import APIRouter, Depends, Request, status from fastapi.responses import JSONResponse from openhands.core.logger import openhands_logger as logger from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper from openhands.events.serialization import event_to_trajectory from openhands.server.dependencies import get_dependencies +from openhands.server.utils import get_conversation +from openhands.server.session.conversation import ServerConversation app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()) @app.get('/trajectory') -async def get_trajectory(request: Request) -> JSONResponse: +async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse: """Get trajectory. This function retrieves the current trajectory and returns it. @@ -24,7 +26,7 @@ async def get_trajectory(request: Request) -> JSONResponse: """ try: async_store = AsyncEventStoreWrapper( - request.state.conversation.event_stream, filter_hidden=True + conversation.event_stream, filter_hidden=True ) trajectory = [] async for event in async_store: diff --git a/openhands/server/utils.py b/openhands/server/utils.py index f24d250e4e..f977649806 100644 --- a/openhands/server/utils.py +++ b/openhands/server/utils.py @@ -1,7 +1,7 @@ -from fastapi import Request +from fastapi import Depends, Request -from openhands.server.shared import ConversationStoreImpl, config -from openhands.server.user_auth import get_user_auth +from openhands.server.shared import ConversationStoreImpl, config, conversation_manager +from openhands.server.user_auth import get_user_auth, get_user_id from openhands.storage.conversation.conversation_store import ConversationStore @@ -16,3 +16,16 @@ async def get_conversation_store(request: Request) -> ConversationStore | None: conversation_store = await ConversationStoreImpl.get_instance(config, user_id) request.state.conversation_store = conversation_store return conversation_store + + +async def get_conversation( + conversation_id: str, user_id: str | None = Depends(get_user_id) +): + """Grabs conversation id set by middleware. Adds the conversation_id to the openapi schema.""" + conversation = await conversation_manager.attach_to_conversation( + conversation_id, user_id + ) + try: + yield conversation + finally: + await conversation_manager.detach_from_conversation(conversation)