mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Now using Dependency Injection to associate conversations with requests (#8863)
This commit is contained in:
parent
4aed3944cf
commit
c2a0e525de
@ -5,7 +5,6 @@ import socketio
|
||||
from openhands.server.app import app as base_app
|
||||
from openhands.server.listen_socket import sio
|
||||
from openhands.server.middleware import (
|
||||
AttachConversationMiddleware,
|
||||
CacheControlMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
@ -24,6 +23,5 @@ base_app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
|
||||
)
|
||||
base_app.middleware('http')(AttachConversationMiddleware(base_app))
|
||||
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
|
||||
|
||||
@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request, status
|
||||
from fastapi import Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
@ -12,10 +12,6 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from openhands.server import shared
|
||||
from openhands.server.types import SessionMiddlewareInterface
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
"""
|
||||
@ -136,72 +132,3 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
return False
|
||||
# Put Other non rate limited checks here
|
||||
return True
|
||||
|
||||
|
||||
class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
"""
|
||||
Determine if the middleware should attach a session for the given request.
|
||||
"""
|
||||
if request.method == 'OPTIONS':
|
||||
return False
|
||||
|
||||
conversation_id = ''
|
||||
if request.url.path.startswith('/api/conversation'):
|
||||
# FIXME: we should be able to use path_params
|
||||
path_parts = request.url.path.split('/')
|
||||
if len(path_parts) > 4:
|
||||
conversation_id = request.url.path.split('/')[3]
|
||||
if not conversation_id:
|
||||
return False
|
||||
|
||||
request.state.sid = conversation_id
|
||||
|
||||
return True
|
||||
|
||||
async def _attach_conversation(self, request: Request) -> JSONResponse | None:
|
||||
"""
|
||||
Attach the user's session based on the provided authentication token.
|
||||
"""
|
||||
user_id = await get_user_id(request)
|
||||
request.state.conversation = (
|
||||
await shared.conversation_manager.attach_to_conversation(
|
||||
request.state.sid, user_id
|
||||
)
|
||||
)
|
||||
if not request.state.conversation:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Session not found'},
|
||||
)
|
||||
return None
|
||||
|
||||
async def _detach_session(self, request: Request) -> None:
|
||||
"""
|
||||
Detach the user's session.
|
||||
"""
|
||||
await shared.conversation_manager.detach_from_conversation(
|
||||
request.state.conversation
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if not self._should_attach(request):
|
||||
return await call_next(request)
|
||||
|
||||
response = await self._attach_conversation(request)
|
||||
if response:
|
||||
return response
|
||||
|
||||
try:
|
||||
# Continue processing the request
|
||||
response = await call_next(request)
|
||||
finally:
|
||||
# Ensure the session is detached
|
||||
await self._detach_session(request)
|
||||
|
||||
return response
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -6,18 +6,20 @@ from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.shared import conversation_manager
|
||||
from openhands.server.utils import get_conversation
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.get('/config')
|
||||
async def get_remote_runtime_config(request: Request) -> JSONResponse:
|
||||
async def get_remote_runtime_config(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Retrieve the runtime configuration.
|
||||
|
||||
Currently, this is the session ID and runtime ID (if available).
|
||||
"""
|
||||
runtime = request.state.conversation.runtime
|
||||
runtime = conversation.runtime
|
||||
runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None
|
||||
session_id = runtime.sid if hasattr(runtime, 'sid') else None
|
||||
return JSONResponse(
|
||||
@ -29,7 +31,7 @@ async def get_remote_runtime_config(request: Request) -> JSONResponse:
|
||||
|
||||
|
||||
@app.get('/vscode-url')
|
||||
async def get_vscode_url(request: Request) -> JSONResponse:
|
||||
async def get_vscode_url(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Get the VSCode URL.
|
||||
|
||||
This endpoint allows getting the VSCode URL.
|
||||
@ -41,7 +43,7 @@ async def get_vscode_url(request: Request) -> JSONResponse:
|
||||
JSONResponse: A JSON response indicating the success of the operation.
|
||||
"""
|
||||
try:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
logger.debug(f'Runtime type: {type(runtime)}')
|
||||
logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}')
|
||||
return JSONResponse(
|
||||
@ -59,7 +61,7 @@ async def get_vscode_url(request: Request) -> JSONResponse:
|
||||
|
||||
|
||||
@app.get('/web-hosts')
|
||||
async def get_hosts(request: Request) -> JSONResponse:
|
||||
async def get_hosts(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Get the hosts used by the runtime.
|
||||
|
||||
This endpoint allows getting the hosts used by the runtime.
|
||||
@ -71,18 +73,7 @@ async def get_hosts(request: Request) -> JSONResponse:
|
||||
JSONResponse: A JSON response indicating the success of the operation.
|
||||
"""
|
||||
try:
|
||||
if not hasattr(request.state, 'conversation'):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={'error': 'No conversation found in request state'},
|
||||
)
|
||||
|
||||
if not hasattr(request.state.conversation, 'runtime'):
|
||||
return JSONResponse(
|
||||
status_code=500, content={'error': 'No runtime found in conversation'}
|
||||
)
|
||||
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
logger.debug(f'Runtime type: {type(runtime)}')
|
||||
logger.debug(f'Runtime hosts: {runtime.web_hosts}')
|
||||
return JSONResponse(status_code=200, content={'hosts': runtime.web_hosts})
|
||||
@ -99,12 +90,12 @@ async def get_hosts(request: Request) -> JSONResponse:
|
||||
|
||||
@app.get('/events')
|
||||
async def search_events(
|
||||
request: Request,
|
||||
start_id: int = 0,
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter: EventFilter | None = None,
|
||||
limit: int = 20,
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
):
|
||||
"""Search through the event stream with filtering and pagination.
|
||||
Args:
|
||||
@ -122,17 +113,13 @@ async def search_events(
|
||||
HTTPException: If conversation is not found
|
||||
ValueError: If limit is less than 1 or greater than 100
|
||||
"""
|
||||
if not request.state.conversation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail='ServerConversation not found'
|
||||
)
|
||||
if limit < 0 or limit > 100:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid limit'
|
||||
)
|
||||
|
||||
# Get matching events from the stream
|
||||
event_stream = request.state.conversation.event_stream
|
||||
event_stream = conversation.event_stream
|
||||
events = list(
|
||||
event_stream.search_events(
|
||||
start_id=start_id,
|
||||
@ -148,15 +135,15 @@ async def search_events(
|
||||
if has_more:
|
||||
events = events[:limit] # Remove the extra event
|
||||
|
||||
events = [event_to_dict(event) for event in events]
|
||||
events_json = [event_to_dict(event) for event in events]
|
||||
return {
|
||||
'events': events,
|
||||
'events': events_json,
|
||||
'has_more': has_more,
|
||||
}
|
||||
|
||||
|
||||
@app.post('/events')
|
||||
async def add_event(request: Request):
|
||||
async def add_event(request: Request, conversation: ServerConversation = Depends(get_conversation)):
|
||||
data = request.json()
|
||||
conversation_manager.send_to_event_stream(request.state.sid, data)
|
||||
conversation_manager.send_to_event_stream(conversation.sid, data)
|
||||
return JSONResponse({'success': True})
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Request, status
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -6,13 +6,15 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.post('/submit-feedback')
|
||||
async def submit_feedback(request: Request, conversation_id: str) -> JSONResponse:
|
||||
async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Submit user feedback.
|
||||
|
||||
This function stores the provided feedback data.
|
||||
@ -36,7 +38,7 @@ async def submit_feedback(request: Request, conversation_id: str) -> JSONRespons
|
||||
# and there is a function to handle the storage.
|
||||
body = await request.json()
|
||||
async_store = AsyncEventStoreWrapper(
|
||||
request.state.conversation.event_stream, filter_hidden=True
|
||||
conversation.event_stream, filter_hidden=True
|
||||
)
|
||||
trajectory = []
|
||||
async for event in async_store:
|
||||
|
||||
@ -32,9 +32,10 @@ from openhands.server.shared import (
|
||||
config,
|
||||
)
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.server.utils import get_conversation, get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
@ -48,7 +49,8 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_
|
||||
},
|
||||
)
|
||||
async def list_files(
|
||||
request: Request, path: str | None = None
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
path: str | None = None
|
||||
) -> list[str] | JSONResponse:
|
||||
"""List files in the specified path.
|
||||
|
||||
@ -70,13 +72,13 @@ async def list_files(
|
||||
Raises:
|
||||
HTTPException: If there's an error listing the files.
|
||||
"""
|
||||
if not request.state.conversation.runtime:
|
||||
if not conversation.runtime:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
try:
|
||||
file_list = await call_sync_from_async(runtime.list_files, path)
|
||||
except AgentRuntimeUnavailableError as e:
|
||||
@ -130,7 +132,7 @@ async def list_files(
|
||||
415: {'description': 'Unsupported media type', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def select_file(file: str, request: Request) -> FileResponse | JSONResponse:
|
||||
async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
|
||||
"""Retrieve the content of a specified file.
|
||||
|
||||
To select a file:
|
||||
@ -149,7 +151,7 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons
|
||||
Raises:
|
||||
HTTPException: If there's an error opening the file.
|
||||
"""
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
|
||||
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
|
||||
read_action = FileReadAction(file)
|
||||
@ -194,10 +196,10 @@ async def select_file(file: str, request: Request) -> FileResponse | JSONRespons
|
||||
500: {'description': 'Error zipping workspace', 'model': dict},
|
||||
},
|
||||
)
|
||||
def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
|
||||
def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
|
||||
try:
|
||||
logger.debug('Zipping workspace')
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
path = runtime.config.workspace_mount_path_in_sandbox
|
||||
try:
|
||||
zip_file_path = runtime.copy_from(path)
|
||||
@ -230,19 +232,15 @@ def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
|
||||
},
|
||||
)
|
||||
async def git_changes(
|
||||
request: Request,
|
||||
conversation_id: str,
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> list[dict[str, str]] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config,
|
||||
user_id,
|
||||
)
|
||||
runtime: Runtime = conversation.runtime
|
||||
|
||||
cwd = await get_cwd(
|
||||
conversation_store,
|
||||
conversation_id,
|
||||
conversation.sid,
|
||||
runtime.config.workspace_mount_path_in_sandbox,
|
||||
)
|
||||
logger.info(f'Getting git changes in {cwd}')
|
||||
@ -275,16 +273,15 @@ async def git_changes(
|
||||
responses={500: {'description': 'Error getting diff', 'model': dict}},
|
||||
)
|
||||
async def git_diff(
|
||||
request: Request,
|
||||
path: str,
|
||||
conversation_id: str,
|
||||
conversation_store: Any = Depends(get_conversation_store),
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
) -> dict[str, Any] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime: Runtime = conversation.runtime
|
||||
|
||||
cwd = await get_cwd(
|
||||
conversation_store,
|
||||
conversation_id,
|
||||
conversation.sid,
|
||||
runtime.config.workspace_mount_path_in_sandbox,
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
@ -7,12 +8,14 @@ from fastapi import (
|
||||
)
|
||||
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
|
||||
async def security_api(request: Request) -> Response:
|
||||
async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
|
||||
"""Catch-all route for security analyzer API requests.
|
||||
|
||||
Each request is handled directly to the security analyzer.
|
||||
@ -26,12 +29,12 @@ async def security_api(request: Request) -> Response:
|
||||
Raises:
|
||||
HTTPException: If the security analyzer is not initialized.
|
||||
"""
|
||||
if not request.state.conversation.security_analyzer:
|
||||
if not conversation.security_analyzer:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Security analyzer not initialized',
|
||||
)
|
||||
|
||||
return await request.state.conversation.security_analyzer.handle_api_request(
|
||||
return await conversation.security_analyzer.handle_api_request(
|
||||
request
|
||||
)
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter, Request, status
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
|
||||
from openhands.events.serialization import event_to_trajectory
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.get('/trajectory')
|
||||
async def get_trajectory(request: Request) -> JSONResponse:
|
||||
async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Get trajectory.
|
||||
|
||||
This function retrieves the current trajectory and returns it.
|
||||
@ -24,7 +26,7 @@ async def get_trajectory(request: Request) -> JSONResponse:
|
||||
"""
|
||||
try:
|
||||
async_store = AsyncEventStoreWrapper(
|
||||
request.state.conversation.event_stream, filter_hidden=True
|
||||
conversation.event_stream, filter_hidden=True
|
||||
)
|
||||
trajectory = []
|
||||
async for event in async_store:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from fastapi import Request
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from openhands.server.shared import ConversationStoreImpl, config
|
||||
from openhands.server.user_auth import get_user_auth
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
|
||||
|
||||
@ -16,3 +16,16 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
request.state.conversation_store = conversation_store
|
||||
return conversation_store
|
||||
|
||||
|
||||
async def get_conversation(
|
||||
conversation_id: str, user_id: str | None = Depends(get_user_id)
|
||||
):
|
||||
"""Grabs conversation id set by middleware. Adds the conversation_id to the openapi schema."""
|
||||
conversation = await conversation_manager.attach_to_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
try:
|
||||
yield conversation
|
||||
finally:
|
||||
await conversation_manager.detach_from_conversation(conversation)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user