Now using Dependency Injection to associate conversations with requests (#8863)

This commit is contained in:
tofarr 2025-06-03 17:36:45 -06:00 committed by GitHub
parent 4aed3944cf
commit c2a0e525de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 65 additions and 136 deletions

View File

@ -5,7 +5,6 @@ import socketio
from openhands.server.app import app as base_app
from openhands.server.listen_socket import sio
from openhands.server.middleware import (
AttachConversationMiddleware,
CacheControlMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
@ -24,6 +23,5 @@ base_app.add_middleware(
RateLimitMiddleware,
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
)
base_app.middleware('http')(AttachConversationMiddleware(base_app))
app = socketio.ASGIApp(sio, other_asgi_app=base_app)

View File

@ -4,7 +4,7 @@ from collections import defaultdict
from datetime import datetime, timedelta
from urllib.parse import urlparse
from fastapi import Request, status
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@ -12,10 +12,6 @@ from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import ASGIApp
from openhands.server import shared
from openhands.server.types import SessionMiddlewareInterface
from openhands.server.user_auth import get_user_id
class LocalhostCORSMiddleware(CORSMiddleware):
"""
@ -136,72 +132,3 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return False
# Put Other non rate limited checks here
return True
class AttachConversationMiddleware(SessionMiddlewareInterface):
def __init__(self, app: ASGIApp) -> None:
self.app = app
def _should_attach(self, request: Request) -> bool:
"""
Determine if the middleware should attach a session for the given request.
"""
if request.method == 'OPTIONS':
return False
conversation_id = ''
if request.url.path.startswith('/api/conversation'):
# FIXME: we should be able to use path_params
path_parts = request.url.path.split('/')
if len(path_parts) > 4:
conversation_id = request.url.path.split('/')[3]
if not conversation_id:
return False
request.state.sid = conversation_id
return True
async def _attach_conversation(self, request: Request) -> JSONResponse | None:
"""
Attach the user's session based on the provided authentication token.
"""
user_id = await get_user_id(request)
request.state.conversation = (
await shared.conversation_manager.attach_to_conversation(
request.state.sid, user_id
)
)
if not request.state.conversation:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
)
return None
async def _detach_session(self, request: Request) -> None:
"""
Detach the user's session.
"""
await shared.conversation_manager.detach_from_conversation(
request.state.conversation
)
async def __call__(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if not self._should_attach(request):
return await call_next(request)
response = await self._attach_conversation(request)
if response:
return response
try:
# Continue processing the request
response = await call_next(request)
finally:
# Ensure the session is detached
await self._detach_session(request)
return response

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Request, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
@ -6,18 +6,20 @@ from openhands.events.event_filter import EventFilter
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.server.dependencies import get_dependencies
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import conversation_manager
from openhands.server.utils import get_conversation
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
@app.get('/config')
async def get_remote_runtime_config(request: Request) -> JSONResponse:
async def get_remote_runtime_config(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
"""Retrieve the runtime configuration.
Currently, this is the session ID and runtime ID (if available).
"""
runtime = request.state.conversation.runtime
runtime = conversation.runtime
runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None
session_id = runtime.sid if hasattr(runtime, 'sid') else None
return JSONResponse(
@ -29,7 +31,7 @@ async def get_remote_runtime_config(request: Request) -> JSONResponse:
@app.get('/vscode-url')
async def get_vscode_url(request: Request) -> JSONResponse:
async def get_vscode_url(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
"""Get the VSCode URL.
This endpoint allows getting the VSCode URL.
@ -41,7 +43,7 @@ async def get_vscode_url(request: Request) -> JSONResponse:
JSONResponse: A JSON response indicating the success of the operation.
"""
try:
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
logger.debug(f'Runtime type: {type(runtime)}')
logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}')
return JSONResponse(
@ -59,7 +61,7 @@ async def get_vscode_url(request: Request) -> JSONResponse:
@app.get('/web-hosts')
async def get_hosts(request: Request) -> JSONResponse:
async def get_hosts(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
"""Get the hosts used by the runtime.
This endpoint allows getting the hosts used by the runtime.
@ -71,18 +73,7 @@ async def get_hosts(request: Request) -> JSONResponse:
JSONResponse: A JSON response indicating the success of the operation.
"""
try:
if not hasattr(request.state, 'conversation'):
return JSONResponse(
status_code=500,
content={'error': 'No conversation found in request state'},
)
if not hasattr(request.state.conversation, 'runtime'):
return JSONResponse(
status_code=500, content={'error': 'No runtime found in conversation'}
)
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
logger.debug(f'Runtime type: {type(runtime)}')
logger.debug(f'Runtime hosts: {runtime.web_hosts}')
return JSONResponse(status_code=200, content={'hosts': runtime.web_hosts})
@ -99,12 +90,12 @@ async def get_hosts(request: Request) -> JSONResponse:
@app.get('/events')
async def search_events(
request: Request,
start_id: int = 0,
end_id: int | None = None,
reverse: bool = False,
filter: EventFilter | None = None,
limit: int = 20,
conversation: ServerConversation = Depends(get_conversation),
):
"""Search through the event stream with filtering and pagination.
Args:
@ -122,17 +113,13 @@ async def search_events(
HTTPException: If conversation is not found
ValueError: If limit is less than 1 or greater than 100
"""
if not request.state.conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='ServerConversation not found'
)
if limit < 0 or limit > 100:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid limit'
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream
event_stream = conversation.event_stream
events = list(
event_stream.search_events(
start_id=start_id,
@ -148,15 +135,15 @@ async def search_events(
if has_more:
events = events[:limit] # Remove the extra event
events = [event_to_dict(event) for event in events]
events_json = [event_to_dict(event) for event in events]
return {
'events': events,
'events': events_json,
'has_more': has_more,
}
@app.post('/events')
async def add_event(request: Request):
async def add_event(request: Request, conversation: ServerConversation = Depends(get_conversation)):
data = request.json()
conversation_manager.send_to_event_stream(request.state.sid, data)
conversation_manager.send_to_event_stream(conversation.sid, data)
return JSONResponse({'success': True})

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Request, status
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
@ -6,13 +6,15 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_dict
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
@app.post('/submit-feedback')
async def submit_feedback(request: Request, conversation_id: str) -> JSONResponse:
async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
"""Submit user feedback.
This function stores the provided feedback data.
@ -36,7 +38,7 @@ async def submit_feedback(request: Request, conversation_id: str) -> JSONRespons
# and there is a function to handle the storage.
body = await request.json()
async_store = AsyncEventStoreWrapper(
request.state.conversation.event_stream, filter_hidden=True
conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_store:

View File

@ -32,9 +32,10 @@ from openhands.server.shared import (
config,
)
from openhands.server.user_auth import get_user_id
from openhands.server.utils import get_conversation_store
from openhands.server.utils import get_conversation, get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
@ -48,7 +49,8 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_
},
)
async def list_files(
request: Request, path: str | None = None
conversation: ServerConversation = Depends(get_conversation),
path: str | None = None
) -> list[str] | JSONResponse:
"""List files in the specified path.
@ -70,13 +72,13 @@ async def list_files(
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.conversation.runtime:
if not conversation.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
try:
file_list = await call_sync_from_async(runtime.list_files, path)
except AgentRuntimeUnavailableError as e:
@ -130,7 +132,7 @@ async def list_files(
415: {'description': 'Unsupported media type', 'model': dict},
},
)
async def select_file(file: str, request: Request) -> FileResponse | JSONResponse:
async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
"""Retrieve the content of a specified file.
To select a file:
@ -149,7 +151,7 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons
Raises:
HTTPException: If there's an error opening the file.
"""
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
@ -194,10 +196,10 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons
500: {'description': 'Error zipping workspace', 'model': dict},
},
)
def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
try:
logger.debug('Zipping workspace')
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
path = runtime.config.workspace_mount_path_in_sandbox
try:
zip_file_path = runtime.copy_from(path)
@ -230,19 +232,15 @@ def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
},
)
async def git_changes(
request: Request,
conversation_id: str,
conversation: ServerConversation = Depends(get_conversation),
conversation_store: ConversationStore = Depends(get_conversation_store),
user_id: str = Depends(get_user_id),
) -> list[dict[str, str]] | JSONResponse:
runtime: Runtime = request.state.conversation.runtime
conversation_store = await ConversationStoreImpl.get_instance(
config,
user_id,
)
runtime: Runtime = conversation.runtime
cwd = await get_cwd(
conversation_store,
conversation_id,
conversation.sid,
runtime.config.workspace_mount_path_in_sandbox,
)
logger.info(f'Getting git changes in {cwd}')
@ -275,16 +273,15 @@ async def git_changes(
responses={500: {'description': 'Error getting diff', 'model': dict}},
)
async def git_diff(
request: Request,
path: str,
conversation_id: str,
conversation_store: Any = Depends(get_conversation_store),
conversation: ServerConversation = Depends(get_conversation),
) -> dict[str, Any] | JSONResponse:
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = conversation.runtime
cwd = await get_cwd(
conversation_store,
conversation_id,
conversation.sid,
runtime.config.workspace_mount_path_in_sandbox,
)

View File

@ -1,5 +1,6 @@
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
@ -7,12 +8,14 @@ from fastapi import (
)
from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def security_api(request: Request) -> Response:
async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
"""Catch-all route for security analyzer API requests.
Each request is handled directly to the security analyzer.
@ -26,12 +29,12 @@ async def security_api(request: Request) -> Response:
Raises:
HTTPException: If the security analyzer is not initialized.
"""
if not request.state.conversation.security_analyzer:
if not conversation.security_analyzer:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Security analyzer not initialized',
)
return await request.state.conversation.security_analyzer.handle_api_request(
return await conversation.security_analyzer.handle_api_request(
request
)

View File

@ -1,16 +1,18 @@
from fastapi import APIRouter, Request, status
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_trajectory
from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
@app.get('/trajectory')
async def get_trajectory(request: Request) -> JSONResponse:
async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
"""Get trajectory.
This function retrieves the current trajectory and returns it.
@ -24,7 +26,7 @@ async def get_trajectory(request: Request) -> JSONResponse:
"""
try:
async_store = AsyncEventStoreWrapper(
request.state.conversation.event_stream, filter_hidden=True
conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_store:

View File

@ -1,7 +1,7 @@
from fastapi import Request
from fastapi import Depends, Request
from openhands.server.shared import ConversationStoreImpl, config
from openhands.server.user_auth import get_user_auth
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth import get_user_auth, get_user_id
from openhands.storage.conversation.conversation_store import ConversationStore
@ -16,3 +16,16 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
request.state.conversation_store = conversation_store
return conversation_store
async def get_conversation(
conversation_id: str, user_id: str | None = Depends(get_user_id)
):
"""Grabs conversation id set by middleware. Adds the conversation_id to the openapi schema."""
conversation = await conversation_manager.attach_to_conversation(
conversation_id, user_id
)
try:
yield conversation
finally:
await conversation_manager.detach_from_conversation(conversation)