diff --git a/openhands/server/app.py b/openhands/server/app.py new file mode 100644 index 0000000000..33f9766fe6 --- /dev/null +++ b/openhands/server/app.py @@ -0,0 +1,56 @@ +import warnings + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + +from fastapi import ( + FastAPI, +) + +import openhands.agenthub # noqa F401 (we import this to get the agents registered) +from openhands.server.middleware import ( + AttachSessionMiddleware, + InMemoryRateLimiter, + LocalhostCORSMiddleware, + NoCacheMiddleware, + RateLimitMiddleware, +) +from openhands.server.routes.auth import app as auth_api_router +from openhands.server.routes.conversation import app as conversation_api_router +from openhands.server.routes.feedback import app as feedback_api_router +from openhands.server.routes.files import app as files_api_router +from openhands.server.routes.public import app as public_api_router +from openhands.server.routes.security import app as security_api_router + +app = FastAPI() +app.add_middleware( + LocalhostCORSMiddleware, + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], +) + +app.add_middleware(NoCacheMiddleware) +app.add_middleware( + RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1) +) + + +@app.get('/health') +async def health(): + return 'OK' + + +app.include_router(auth_api_router) +app.include_router(public_api_router) +app.include_router(files_api_router) +app.include_router(conversation_api_router) +app.include_router(security_api_router) +app.include_router(feedback_api_router) + +app.middleware('http')(AttachSessionMiddleware(app, target_router=files_api_router)) +app.middleware('http')( + AttachSessionMiddleware(app, target_router=conversation_api_router) +) +app.middleware('http')(AttachSessionMiddleware(app, target_router=security_api_router)) +app.middleware('http')(AttachSessionMiddleware(app, target_router=feedback_api_router)) diff --git a/openhands/server/auth/auth.py 
b/openhands/server/auth.py similarity index 100% rename from openhands/server/auth/auth.py rename to openhands/server/auth.py diff --git a/openhands/server/auth/__init__.py b/openhands/server/auth/__init__.py deleted file mode 100644 index 0fe3ddd8cc..0000000000 --- a/openhands/server/auth/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from openhands.server.auth.auth import get_sid_from_token, sign_token - -__all__ = ['get_sid_from_token', 'sign_token'] diff --git a/openhands/server/file_config.py b/openhands/server/file_config.py new file mode 100644 index 0000000000..06e8ce20ee --- /dev/null +++ b/openhands/server/file_config.py @@ -0,0 +1,111 @@ +import os +import re + +from openhands.core.config import AppConfig +from openhands.core.logger import openhands_logger as logger +from openhands.server.shared import config as shared_config + +FILES_TO_IGNORE = [ + '.git/', + '.DS_Store', + 'node_modules/', + '__pycache__/', +] + + +def sanitize_filename(filename): + """Sanitize the filename to prevent directory traversal""" + # Remove any directory components + filename = os.path.basename(filename) + # Remove any non-alphanumeric characters except for .-_ + filename = re.sub(r'[^\w\-_\.]', '', filename) + # Limit the filename length + max_length = 255 + if len(filename) > max_length: + name, ext = os.path.splitext(filename) + filename = name[: max_length - len(ext)] + ext + return filename + + +def load_file_upload_config( + config: AppConfig = shared_config, +) -> tuple[int, bool, list[str]]: + """Load file upload configuration from the config object. + + This function retrieves the file upload settings from the global config object. + It handles the following settings: + - Maximum file size for uploads + - Whether to restrict file types + - List of allowed file extensions + + It also performs sanity checks on the values to ensure they are valid and safe. + + Returns: + tuple: A tuple containing: + - max_file_size_mb (int): Maximum file size in MB. 0 means no limit. 
+ - restrict_file_types (bool): Whether file type restrictions are enabled. + - allowed_extensions (set): Set of allowed file extensions. + """ + # Retrieve values from config + max_file_size_mb = config.file_uploads_max_file_size_mb + restrict_file_types = config.file_uploads_restrict_file_types + allowed_extensions = config.file_uploads_allowed_extensions + + # Sanity check for max_file_size_mb + if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0: + logger.warning( + f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).' + ) + max_file_size_mb = 0 + + # Sanity check for allowed_extensions + if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions: + logger.warning( + f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].' + ) + allowed_extensions = ['.*'] + else: + # Ensure all extensions start with a dot and are lowercase + allowed_extensions = [ + ext.lower() if ext.startswith('.') else f'.{ext.lower()}' + for ext in allowed_extensions + ] + + # If restrictions are disabled, allow all + if not restrict_file_types: + allowed_extensions = ['.*'] + + logger.debug( + f'File upload config: max_size={max_file_size_mb}MB, ' + f'restrict_types={restrict_file_types}, ' + f'allowed_extensions={allowed_extensions}' + ) + + return max_file_size_mb, restrict_file_types, allowed_extensions + + +# Load configuration +MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config() + + +def is_extension_allowed(filename): + """Check if the file extension is allowed based on the current configuration. + + This function supports wildcards and files without extensions. + The check is case-insensitive for extensions. + + Args: + filename (str): The name of the file to check. + + Returns: + bool: True if the file extension is allowed, False otherwise. 
+ """ + if not RESTRICT_FILE_TYPES: + return True + + file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase + return ( + '.*' in ALLOWED_EXTENSIONS + or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS) + or (file_ext == '' and '.' in ALLOWED_EXTENSIONS) + ) diff --git a/openhands/server/listen.py b/openhands/server/listen.py index ce173ba295..3ad6fd1b88 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -1,1031 +1,11 @@ -import os -import re -import tempfile -import time -import warnings -from contextlib import asynccontextmanager - -import jwt -import requests import socketio -from pathspec import PathSpec -from pathspec.patterns import GitWildMatchPattern -from openhands.core.schema.action import ActionType -from openhands.security.options import SecurityAnalyzers -from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback -from openhands.server.github_utils import ( - GITHUB_CLIENT_ID, - GITHUB_CLIENT_SECRET, - UserVerifier, - authenticate_github_user, -) -from openhands.storage import get_file_store -from openhands.utils.async_utils import call_sync_from_async +from openhands.server.app import app as base_app +from openhands.server.socket import sio +from openhands.server.static import SPAStaticFiles -with warnings.catch_warnings(): - warnings.simplefilter('ignore') - import litellm - -from dotenv import load_dotenv -from fastapi import ( - BackgroundTasks, - FastAPI, - HTTPException, - Request, - UploadFile, - status, -) -from fastapi.responses import FileResponse, JSONResponse -from fastapi.security import HTTPBearer -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel - -import openhands.agenthub # noqa F401 (we import this to get the agents registered) -from openhands.controller.agent import Agent -from openhands.core.config import LLMConfig, load_app_config -from openhands.core.logger import openhands_logger as logger -from openhands.events.action import ( - 
FileReadAction, - FileWriteAction, - NullAction, -) -from openhands.events.observation import ( - ErrorObservation, - FileReadObservation, - FileWriteObservation, - NullObservation, -) -from openhands.events.serialization import event_to_dict -from openhands.events.stream import AsyncEventStreamWrapper -from openhands.llm import bedrock -from openhands.runtime.base import Runtime, RuntimeUnavailableError -from openhands.server.auth.auth import get_sid_from_token, sign_token -from openhands.server.middleware import ( - InMemoryRateLimiter, - LocalhostCORSMiddleware, - NoCacheMiddleware, - RateLimitMiddleware, -) -from openhands.server.session import SessionManager - -load_dotenv() - -config = load_app_config() -file_store = get_file_store(config.file_store, config.file_store_path) -client_manager = None -redis_host = os.environ.get('REDIS_HOST') -if redis_host: - client_manager = socketio.AsyncRedisManager( - f'redis://{redis_host}', - redis_options={'password': os.environ.get('REDIS_PASSWORD')}, - ) -sio = socketio.AsyncServer( - async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager -) -session_manager = SessionManager(sio, config, file_store) - - -@asynccontextmanager -async def _lifespan(app: FastAPI): - async with session_manager: - yield - - -app = FastAPI(lifespan=_lifespan) -app.add_middleware( - LocalhostCORSMiddleware, - allow_credentials=True, - allow_methods=['*'], - allow_headers=['*'], +base_app.mount( + '/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist' ) - -app.add_middleware(NoCacheMiddleware) -app.add_middleware( - RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1) -) - - -@app.get('/health') -async def health(): - return 'OK' - - -security_scheme = HTTPBearer() - - -def load_file_upload_config() -> tuple[int, bool, list[str]]: - """Load file upload configuration from the config object. - - This function retrieves the file upload settings from the global config object. 
- It handles the following settings: - - Maximum file size for uploads - - Whether to restrict file types - - List of allowed file extensions - - It also performs sanity checks on the values to ensure they are valid and safe. - - Returns: - tuple: A tuple containing: - - max_file_size_mb (int): Maximum file size in MB. 0 means no limit. - - restrict_file_types (bool): Whether file type restrictions are enabled. - - allowed_extensions (set): Set of allowed file extensions. - """ - # Retrieve values from config - max_file_size_mb = config.file_uploads_max_file_size_mb - restrict_file_types = config.file_uploads_restrict_file_types - allowed_extensions = config.file_uploads_allowed_extensions - - # Sanity check for max_file_size_mb - if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0: - logger.warning( - f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).' - ) - max_file_size_mb = 0 - - # Sanity check for allowed_extensions - if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions: - logger.warning( - f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].' - ) - allowed_extensions = ['.*'] - else: - # Ensure all extensions start with a dot and are lowercase - allowed_extensions = [ - ext.lower() if ext.startswith('.') else f'.{ext.lower()}' - for ext in allowed_extensions - ] - - # If restrictions are disabled, allow all - if not restrict_file_types: - allowed_extensions = ['.*'] - - logger.debug( - f'File upload config: max_size={max_file_size_mb}MB, ' - f'restrict_types={restrict_file_types}, ' - f'allowed_extensions={allowed_extensions}' - ) - - return max_file_size_mb, restrict_file_types, allowed_extensions - - -# Load configuration -MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config() - - -def is_extension_allowed(filename): - """Check if the file extension is allowed based on the current configuration. 
- - This function supports wildcards and files without extensions. - The check is case-insensitive for extensions. - - Args: - filename (str): The name of the file to check. - - Returns: - bool: True if the file extension is allowed, False otherwise. - """ - if not RESTRICT_FILE_TYPES: - return True - - file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase - return ( - '.*' in ALLOWED_EXTENSIONS - or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS) - or (file_ext == '' and '.' in ALLOWED_EXTENSIONS) - ) - - -@app.middleware('http') -async def attach_session(request: Request, call_next): - """Middleware to attach session information to the request. - - This middleware checks for the Authorization header, validates the token, - and attaches the corresponding session to the request state. - - Args: - request (Request): The incoming request object. - call_next (Callable): The next middleware or route handler in the chain. - - Returns: - Response: The response from the next middleware or route handler. 
- """ - non_authed_paths = [ - '/api/options/', - '/api/github/callback', - '/api/authenticate', - ] - if any( - request.url.path.startswith(path) for path in non_authed_paths - ) or not request.url.path.startswith('/api/'): - response = await call_next(request) - return response - - # Bypass authentication for OPTIONS requests (preflight) - if request.method == 'OPTIONS': - response = await call_next(request) - return response - - user_verifier = UserVerifier() - if user_verifier.is_active(): - signed_token = request.cookies.get('github_auth') - if not signed_token: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Not authenticated'}, - ) - try: - jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256']) - except Exception as e: - logger.warning(f'Invalid token: {e}') - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Invalid token'}, - ) - - if not request.headers.get('Authorization'): - logger.warning('Missing Authorization header') - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Missing Authorization header'}, - ) - - auth_token = request.headers.get('Authorization') - if 'Bearer' in auth_token: - auth_token = auth_token.split('Bearer')[1].strip() - - request.state.sid = get_sid_from_token(auth_token, config.jwt_secret) - if request.state.sid == '': - logger.warning('Invalid token') - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Invalid token'}, - ) - - request.state.conversation = await session_manager.attach_to_conversation( - request.state.sid - ) - if not request.state.conversation: - logger.error(f'Runtime not found for session: {request.state.sid}') - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Session not found'}, - ) - try: - response = await call_next(request) - finally: - await session_manager.detach_from_conversation(request.state.conversation) - 
return response - - -@app.get('/api/events/search') -async def search_events( - request: Request, - query: str | None = None, - start_id: int = 0, - limit: int = 20, - event_type: str | None = None, - source: str | None = None, - start_date: str | None = None, - end_date: str | None = None, -): - """Search through the event stream with filtering and pagination. - - Args: - request (Request): The incoming request object - query (str, optional): Text to search for in event content - start_id (int): Starting ID in the event stream. Defaults to 0 - limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20 - event_type (str, optional): Filter by event type (e.g., "FileReadAction") - source (str, optional): Filter by event source - start_date (str, optional): Filter events after this date (ISO format) - end_date (str, optional): Filter events before this date (ISO format) - - Returns: - dict: Dictionary containing: - - events: List of matching events - - has_more: Whether there are more matching events after this batch - - Raises: - HTTPException: If conversation is not found - ValueError: If limit is less than 1 or greater than 100 - """ - if not request.state.conversation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found' - ) - - # Get matching events from the stream - event_stream = request.state.conversation.event_stream - matching_events = event_stream.get_matching_events( - query=query, - event_type=event_type, - source=source, - start_date=start_date, - end_date=end_date, - start_id=start_id, - limit=limit + 1, # Get one extra to check if there are more - ) - - # Check if there are more events - has_more = len(matching_events) > limit - if has_more: - matching_events = matching_events[:limit] # Remove the extra event - - return { - 'events': matching_events, - 'has_more': has_more, - } - - -@app.get('/api/options/models') -async def get_litellm_models() -> list[str]: - """ - Get all 
models supported by LiteLLM. - - This function combines models from litellm and Bedrock, removing any - error-prone Bedrock models. - - To get the models: - ```sh - curl http://localhost:3000/api/litellm-models - ``` - - Returns: - list: A sorted list of unique model names. - """ - litellm_model_list = litellm.model_list + list(litellm.model_cost.keys()) - litellm_model_list_without_bedrock = bedrock.remove_error_modelId( - litellm_model_list - ) - # TODO: for bedrock, this is using the default config - llm_config: LLMConfig = config.get_llm_config() - bedrock_model_list = [] - if ( - llm_config.aws_region_name - and llm_config.aws_access_key_id - and llm_config.aws_secret_access_key - ): - bedrock_model_list = bedrock.list_foundation_models( - llm_config.aws_region_name, - llm_config.aws_access_key_id, - llm_config.aws_secret_access_key, - ) - model_list = litellm_model_list_without_bedrock + bedrock_model_list - for llm_config in config.llms.values(): - ollama_base_url = llm_config.ollama_base_url - if llm_config.model.startswith('ollama'): - if not ollama_base_url: - ollama_base_url = llm_config.base_url - if ollama_base_url: - ollama_url = ollama_base_url.strip('/') + '/api/tags' - try: - ollama_models_list = requests.get(ollama_url, timeout=3).json()[ - 'models' - ] - for model in ollama_models_list: - model_list.append('ollama/' + model['name']) - break - except requests.exceptions.RequestException as e: - logger.error(f'Error getting OLLAMA models: {e}', exc_info=True) - - return list(sorted(set(model_list))) - - -@app.get('/api/options/agents') -async def get_agents(): - """Get all agents supported by LiteLLM. - - To get the agents: - ```sh - curl http://localhost:3000/api/agents - ``` - - Returns: - list: A sorted list of agent names. - """ - agents = sorted(Agent.list_agents()) - return agents - - -@app.get('/api/options/security-analyzers') -async def get_security_analyzers(): - """Get all supported security analyzers. 
- - To get the security analyzers: - ```sh - curl http://localhost:3000/api/security-analyzers - ``` - - Returns: - list: A sorted list of security analyzer names. - """ - return sorted(SecurityAnalyzers.keys()) - - -FILES_TO_IGNORE = [ - '.git/', - '.DS_Store', - 'node_modules/', - '__pycache__/', -] - - -@app.get('/api/list-files') -async def list_files(request: Request, path: str | None = None): - """List files in the specified path. - - This function retrieves a list of files from the agent's runtime file store, - excluding certain system and hidden files/directories. - - To list files: - ```sh - curl http://localhost:3000/api/list-files - ``` - - Args: - request (Request): The incoming request object. - path (str, optional): The path to list files from. Defaults to None. - - Returns: - list: A list of file names in the specified path. - - Raises: - HTTPException: If there's an error listing the files. - """ - if not request.state.conversation.runtime: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Runtime not yet initialized'}, - ) - - runtime: Runtime = request.state.conversation.runtime - try: - file_list = await call_sync_from_async(runtime.list_files, path) - except RuntimeUnavailableError as e: - logger.error(f'Error listing files: {e}', exc_info=True) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': f'Error listing files: {e}'}, - ) - if path: - file_list = [os.path.join(path, f) for f in file_list] - - file_list = [f for f in file_list if f not in FILES_TO_IGNORE] - - async def filter_for_gitignore(file_list, base_path): - gitignore_path = os.path.join(base_path, '.gitignore') - try: - read_action = FileReadAction(gitignore_path) - observation = await call_sync_from_async(runtime.run_action, read_action) - spec = PathSpec.from_lines( - GitWildMatchPattern, observation.content.splitlines() - ) - except Exception as e: - logger.warning(e) - return file_list - file_list = 
[entry for entry in file_list if not spec.match_file(entry)] - return file_list - - try: - file_list = await filter_for_gitignore(file_list, '') - except RuntimeUnavailableError as e: - logger.error(f'Error filtering files: {e}', exc_info=True) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': f'Error filtering files: {e}'}, - ) - - return file_list - - -@app.get('/api/select-file') -async def select_file(file: str, request: Request): - """Retrieve the content of a specified file. - - To select a file: - ```sh - curl http://localhost:3000/api/select-file?file= - ``` - - Args: - file (str): The path of the file to be retrieved. - Expect path to be absolute inside the runtime. - request (Request): The incoming request object. - - Returns: - dict: A dictionary containing the file content. - - Raises: - HTTPException: If there's an error opening the file. - """ - runtime: Runtime = request.state.conversation.runtime - - file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) - read_action = FileReadAction(file) - try: - observation = await call_sync_from_async(runtime.run_action, read_action) - except RuntimeUnavailableError as e: - logger.error(f'Error opening file {file}: {e}', exc_info=True) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': f'Error opening file: {e}'}, - ) - - if isinstance(observation, FileReadObservation): - content = observation.content - return {'code': content} - elif isinstance(observation, ErrorObservation): - logger.error(f'Error opening file {file}: {observation}', exc_info=False) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': f'Error opening file: {observation}'}, - ) - - -def sanitize_filename(filename): - """Sanitize the filename to prevent directory traversal""" - # Remove any directory components - filename = os.path.basename(filename) - # Remove any non-alphanumeric characters 
except for .-_ - filename = re.sub(r'[^\w\-_\.]', '', filename) - # Limit the filename length - max_length = 255 - if len(filename) > max_length: - name, ext = os.path.splitext(filename) - filename = name[: max_length - len(ext)] + ext - return filename - - -@app.get('/api/conversation') -async def get_remote_runtime_config(request: Request): - """Retrieve the remote runtime configuration. - - Currently, this is the runtime ID. - """ - runtime = request.state.conversation.runtime - runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None - session_id = runtime.sid if hasattr(runtime, 'sid') else None - return JSONResponse( - content={ - 'runtime_id': runtime_id, - 'session_id': session_id, - } - ) - - -@app.post('/api/upload-files') -async def upload_file(request: Request, files: list[UploadFile]): - """Upload a list of files to the workspace. - - To upload a files: - ```sh - curl -X POST -F "file=@" -F "file=@" http://localhost:3000/api/upload-files - ``` - - Args: - request (Request): The incoming request object. - files (list[UploadFile]): A list of files to be uploaded. - - Returns: - dict: A message indicating the success of the upload operation. - - Raises: - HTTPException: If there's an error saving the files. 
- """ - try: - uploaded_files = [] - skipped_files = [] - for file in files: - safe_filename = sanitize_filename(file.filename) - file_contents = await file.read() - - if ( - MAX_FILE_SIZE_MB > 0 - and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024 - ): - skipped_files.append( - { - 'name': safe_filename, - 'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB', - } - ) - continue - - if not is_extension_allowed(safe_filename): - skipped_files.append( - {'name': safe_filename, 'reason': 'File type not allowed'} - ) - continue - - # copy the file to the runtime - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file_path = os.path.join(tmp_dir, safe_filename) - with open(tmp_file_path, 'wb') as tmp_file: - tmp_file.write(file_contents) - tmp_file.flush() - - runtime: Runtime = request.state.conversation.runtime - try: - await call_sync_from_async( - runtime.copy_to, - tmp_file_path, - runtime.config.workspace_mount_path_in_sandbox, - ) - except RuntimeUnavailableError as e: - logger.error( - f'Error saving file {safe_filename}: {e}', exc_info=True - ) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': f'Error saving file: {e}'}, - ) - uploaded_files.append(safe_filename) - - response_content = { - 'message': 'File upload process completed', - 'uploaded_files': uploaded_files, - 'skipped_files': skipped_files, - } - - if not uploaded_files and skipped_files: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={ - **response_content, - 'error': 'No files were uploaded successfully', - }, - ) - - return JSONResponse(status_code=status.HTTP_200_OK, content=response_content) - - except Exception as e: - logger.error(f'Error during file upload: {e}', exc_info=True) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={ - 'error': f'Error during file upload: {str(e)}', - 'uploaded_files': [], - 'skipped_files': [], - }, - ) - - 
-@app.post('/api/submit-feedback') -async def submit_feedback(request: Request): - """Submit user feedback. - - This function stores the provided feedback data. - - To submit feedback: - ```sh - curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:" - ``` - - Args: - request (Request): The incoming request object. - feedback (FeedbackDataModel): The feedback data to be stored. - - Returns: - dict: The stored feedback data. - - Raises: - HTTPException: If there's an error submitting the feedback. - """ - # Assuming the storage service is already configured in the backend - # and there is a function to handle the storage. - body = await request.json() - async_stream = AsyncEventStreamWrapper( - request.state.conversation.event_stream, filter_hidden=True - ) - trajectory = [] - async for event in async_stream: - trajectory.append(event_to_dict(event)) - feedback = FeedbackDataModel( - email=body.get('email', ''), - version=body.get('version', ''), - permissions=body.get('permissions', 'private'), - polarity=body.get('polarity', ''), - feedback=body.get('polarity', ''), - trajectory=trajectory, - ) - try: - feedback_data = await call_sync_from_async(store_feedback, feedback) - return JSONResponse(status_code=200, content=feedback_data) - except Exception as e: - logger.error(f'Error submitting feedback: {e}') - return JSONResponse( - status_code=500, content={'error': 'Failed to submit feedback'} - ) - - -@app.get('/api/defaults') -async def appconfig_defaults(): - """Retrieve the default configuration settings. - - To get the default configurations: - ```sh - curl http://localhost:3000/api/defaults - ``` - - Returns: - dict: The default configuration settings. - """ - return config.defaults_dict - - -@app.post('/api/save-file') -async def save_file(request: Request): - """Save a file to the agent's runtime file store. - - This endpoint allows saving a file when the agent is in a paused, finished, - or awaiting user input state. 
It checks the agent's state before proceeding - with the file save operation. - - Args: - request (Request): The incoming FastAPI request object. - - Returns: - JSONResponse: A JSON response indicating the success of the operation. - - Raises: - HTTPException: - - 403 error if the agent is not in an allowed state for editing. - - 400 error if the file path or content is missing. - - 500 error if there's an unexpected error during the save operation. - """ - try: - # Extract file path and content from the request - data = await request.json() - file_path = data.get('filePath') - content = data.get('content') - - # Validate the presence of required data - if not file_path or content is None: - raise HTTPException(status_code=400, detail='Missing filePath or content') - - # Save the file to the agent's runtime file store - runtime: Runtime = request.state.conversation.runtime - file_path = os.path.join( - runtime.config.workspace_mount_path_in_sandbox, file_path - ) - write_action = FileWriteAction(file_path, content) - try: - observation = await call_sync_from_async(runtime.run_action, write_action) - except RuntimeUnavailableError as e: - logger.error(f'Error saving file: {e}', exc_info=True) - return JSONResponse( - status_code=500, - content={'error': f'Error saving file: {e}'}, - ) - - if isinstance(observation, FileWriteObservation): - return JSONResponse( - status_code=200, content={'message': 'File saved successfully'} - ) - elif isinstance(observation, ErrorObservation): - return JSONResponse( - status_code=500, - content={'error': f'Failed to save file: {observation}'}, - ) - else: - return JSONResponse( - status_code=500, - content={'error': f'Unexpected observation: {observation}'}, - ) - except Exception as e: - # Log the error and return a 500 response - logger.error(f'Error saving file: {e}', exc_info=True) - raise HTTPException(status_code=500, detail=f'Error saving file: {e}') - - -@app.route('/api/security/{path:path}', methods=['GET', 'POST', 'PUT', 
'DELETE']) -async def security_api(request: Request): - """Catch-all route for security analyzer API requests. - - Each request is handled directly to the security analyzer. - - Args: - request (Request): The incoming FastAPI request object. - - Returns: - Any: The response from the security analyzer. - - Raises: - HTTPException: If the security analyzer is not initialized. - """ - if not request.state.conversation.security_analyzer: - raise HTTPException(status_code=404, detail='Security analyzer not initialized') - - return await request.state.conversation.security_analyzer.handle_api_request( - request - ) - - -@app.get('/api/zip-directory') -async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks): - try: - logger.debug('Zipping workspace') - runtime: Runtime = request.state.conversation.runtime - path = runtime.config.workspace_mount_path_in_sandbox - try: - zip_file = await call_sync_from_async(runtime.copy_from, path) - except RuntimeUnavailableError as e: - logger.error(f'Error zipping workspace: {e}', exc_info=True) - return JSONResponse( - status_code=500, - content={'error': f'Error zipping workspace: {e}'}, - ) - response = FileResponse( - path=zip_file, - filename='workspace.zip', - media_type='application/x-zip-compressed', - ) - - # This will execute after the response is sent (So the file is not deleted before being sent) - background_tasks.add_task(zip_file.unlink) - - return response - except Exception as e: - logger.error(f'Error zipping workspace: {e}', exc_info=True) - raise HTTPException( - status_code=500, - detail='Failed to zip workspace', - ) - - -class AuthCode(BaseModel): - code: str - - -@app.post('/api/github/callback') -def github_callback(auth_code: AuthCode): - # Prepare data for the token exchange request - data = { - 'client_id': GITHUB_CLIENT_ID, - 'client_secret': GITHUB_CLIENT_SECRET, - 'code': auth_code.code, - } - - logger.debug('Exchanging code for GitHub token') - - headers = {'Accept': 
'application/json'} - response = requests.post( - 'https://github.com/login/oauth/access_token', data=data, headers=headers - ) - - if response.status_code != 200: - logger.error(f'Failed to exchange code for token: {response.text}') - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={'error': 'Failed to exchange code for token'}, - ) - - token_response = response.json() - - if 'access_token' not in token_response: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={'error': 'No access token in response'}, - ) - - return JSONResponse( - status_code=status.HTTP_200_OK, - content={'access_token': token_response['access_token']}, - ) - - -@app.post('/api/authenticate') -async def authenticate(request: Request): - token = request.headers.get('X-GitHub-Token') - if not await authenticate_github_user(token): - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Not authorized via GitHub waitlist'}, - ) - - # Create a signed JWT token with 1-hour expiration - cookie_data = { - 'github_token': token, - 'exp': int(time.time()) + 3600, # 1 hour expiration - } - signed_token = sign_token(cookie_data, config.jwt_secret) - - response = JSONResponse( - status_code=status.HTTP_200_OK, content={'message': 'User authenticated'} - ) - - # Set secure cookie with signed token - response.set_cookie( - key='github_auth', - value=signed_token, - max_age=3600, # 1 hour in seconds - httponly=True, - secure=True, - samesite='strict', - ) - return response - - -@app.get('/api/vscode-url') -async def get_vscode_url(request: Request): - """Get the VSCode URL. - - This endpoint allows getting the VSCode URL. - - Args: - request (Request): The incoming FastAPI request object. - - Returns: - JSONResponse: A JSON response indicating the success of the operation. 
- """ - try: - runtime: Runtime = request.state.conversation.runtime - logger.debug(f'Runtime type: {type(runtime)}') - logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}') - return JSONResponse(status_code=200, content={'vscode_url': runtime.vscode_url}) - except Exception as e: - logger.error(f'Error getting VSCode URL: {e}', exc_info=True) - return JSONResponse( - status_code=500, - content={ - 'vscode_url': None, - 'error': f'Error getting VSCode URL: {e}', - }, - ) - - -class SPAStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): - try: - return await super().get_response(path, scope) - except Exception: - # FIXME: just making this HTTPException doesn't work for some reason - return await super().get_response('index.html', scope) - - -app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist') - -app = socketio.ASGIApp(sio, other_asgi_app=app) - - -@sio.event -async def connect(connection_id: str, environ): - logger.info(f'sio:connect: {connection_id}') - - -@sio.event -async def oh_action(connection_id: str, data: dict): - # If it's an init, we do it here. 
- action = data.get('action', '') - if action == ActionType.INIT: - await init_connection(connection_id, data) - return - - logger.info(f'sio:oh_action:{connection_id}') - await session_manager.send_to_event_stream(connection_id, data) - - -async def init_connection(connection_id: str, data: dict): - gh_token = data.pop('github_token', None) - if not await authenticate_github_user(gh_token): - raise RuntimeError(status.WS_1008_POLICY_VIOLATION) - - token = data.pop('token', None) - if token: - sid = get_sid_from_token(token, config.jwt_secret) - if sid == '': - await sio.send({'error': 'Invalid token', 'error_code': 401}) - return - logger.info(f'Existing session: {sid}') - else: - sid = connection_id - logger.info(f'New session: {sid}') - - token = sign_token({'sid': sid}, config.jwt_secret) - await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id) - - latest_event_id = int(data.pop('latest_event_id', -1)) - - # The session in question should exist, but may not actually be running locally... 
- event_stream = await session_manager.init_or_join_session(sid, connection_id, data) - - # Send events - async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1) - async for event in async_stream: - if isinstance( - event, - ( - NullAction, - NullObservation, - ), - ): - continue - await sio.emit('oh_event', event_to_dict(event), to=connection_id) - - -@sio.event -async def disconnect(connection_id: str): - logger.info(f'sio:disconnect:{connection_id}') - await session_manager.disconnect_from_session(connection_id) +app = socketio.ASGIApp(sio, other_asgi_app=base_app) diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 624dac62b4..803887471e 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -1,14 +1,21 @@ import asyncio from collections import defaultdict from datetime import datetime, timedelta +from typing import Callable from urllib.parse import urlparse -from fastapi import Request +import jwt +from fastapi import APIRouter, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp +from openhands.core.logger import openhands_logger as logger +from openhands.server.auth import get_sid_from_token +from openhands.server.github_utils import UserVerifier +from openhands.server.shared import config, session_manager + class LocalhostCORSMiddleware(CORSMiddleware): """ @@ -100,3 +107,71 @@ class RateLimitMiddleware(BaseHTTPMiddleware): headers={'Retry-After': '1'}, ) return await call_next(request) + + +class AttachSessionMiddleware: + def __init__(self, app, target_router: APIRouter): + self.app = app + self.target_router = target_router + self.target_paths = {route.path for route in target_router.routes} + + async def __call__(self, request: Request, call_next: Callable): + do_attach = False + if request.url.path in self.target_paths: + 
do_attach = True + + if request.method == 'OPTIONS': + do_attach = False + + if not do_attach: + return await call_next(request) + + user_verifier = UserVerifier() + if user_verifier.is_active(): + signed_token = request.cookies.get('github_auth') + if not signed_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Not authenticated'}, + ) + try: + jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256']) + except Exception as e: + logger.warning(f'Invalid token: {e}') + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Invalid token'}, + ) + + if not request.headers.get('Authorization'): + logger.warning('Missing Authorization header') + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Missing Authorization header'}, + ) + + auth_token = request.headers.get('Authorization') + if 'Bearer' in auth_token: + auth_token = auth_token.split('Bearer')[1].strip() + + request.state.sid = get_sid_from_token(auth_token, config.jwt_secret) + if request.state.sid == '': + logger.warning('Invalid token') + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Invalid token'}, + ) + + request.state.conversation = await session_manager.attach_to_conversation( + request.state.sid + ) + if request.state.conversation is None: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'Session not found'}, + ) + try: + response = await call_next(request) + finally: + await session_manager.detach_from_conversation(request.state.conversation) + return response diff --git a/openhands/server/routes/auth.py b/openhands/server/routes/auth.py new file mode 100644 index 0000000000..67151f7e96 --- /dev/null +++ b/openhands/server/routes/auth.py @@ -0,0 +1,100 @@ +import time +import warnings + +import requests + +from openhands.server.github_utils import ( + GITHUB_CLIENT_ID, + GITHUB_CLIENT_SECRET, + 
authenticate_github_user, +) + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + +from fastapi import ( + APIRouter, + Request, + status, +) +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from openhands.core.logger import openhands_logger as logger +from openhands.server.auth import sign_token +from openhands.server.shared import config + +app = APIRouter(prefix='/api') + + +class AuthCode(BaseModel): + code: str + + +@app.post('/github/callback') +def github_callback(auth_code: AuthCode): + # Prepare data for the token exchange request + data = { + 'client_id': GITHUB_CLIENT_ID, + 'client_secret': GITHUB_CLIENT_SECRET, + 'code': auth_code.code, + } + + logger.debug('Exchanging code for GitHub token') + + headers = {'Accept': 'application/json'} + response = requests.post( + 'https://github.com/login/oauth/access_token', data=data, headers=headers + ) + + if response.status_code != 200: + logger.error(f'Failed to exchange code for token: {response.text}') + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={'error': 'Failed to exchange code for token'}, + ) + + token_response = response.json() + + if 'access_token' not in token_response: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={'error': 'No access token in response'}, + ) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'access_token': token_response['access_token']}, + ) + + +@app.post('/authenticate') +async def authenticate(request: Request): + token = request.headers.get('X-GitHub-Token') + if not await authenticate_github_user(token): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Not authorized via GitHub waitlist'}, + ) + + # Create a signed JWT token with 1-hour expiration + cookie_data = { + 'github_token': token, + 'exp': int(time.time()) + 3600, # 1 hour expiration + } + signed_token = sign_token(cookie_data, config.jwt_secret) + + 
response = JSONResponse( + status_code=status.HTTP_200_OK, content={'message': 'User authenticated'} + ) + + # Set secure cookie with signed token + response.set_cookie( + key='github_auth', + value=signed_token, + max_age=3600, # 1 hour in seconds + httponly=True, + secure=True, + samesite='strict', + ) + return response diff --git a/openhands/server/routes/conversation.py b/openhands/server/routes/conversation.py new file mode 100644 index 0000000000..a47eed68ef --- /dev/null +++ b/openhands/server/routes/conversation.py @@ -0,0 +1,106 @@ +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse + +from openhands.core.logger import openhands_logger as logger +from openhands.runtime.base import Runtime + +app = APIRouter(prefix='/api') + + +@app.get('/conversation') +async def get_remote_runtime_config(request: Request): + """Retrieve the runtime configuration. + + Currently, this is the session ID and runtime ID (if available). + """ + runtime = request.state.conversation.runtime + runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None + session_id = runtime.sid if hasattr(runtime, 'sid') else None + return JSONResponse( + content={ + 'runtime_id': runtime_id, + 'session_id': session_id, + } + ) + + +@app.get('/vscode-url') +async def get_vscode_url(request: Request): + """Get the VSCode URL. + + This endpoint allows getting the VSCode URL. + + Args: + request (Request): The incoming FastAPI request object. + + Returns: + JSONResponse: A JSON response indicating the success of the operation. 
+ """ + try: + runtime: Runtime = request.state.conversation.runtime + logger.debug(f'Runtime type: {type(runtime)}') + logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}') + return JSONResponse(status_code=200, content={'vscode_url': runtime.vscode_url}) + except Exception as e: + logger.error(f'Error getting VSCode URL: {e}', exc_info=True) + return JSONResponse( + status_code=500, + content={ + 'vscode_url': None, + 'error': f'Error getting VSCode URL: {e}', + }, + ) + + +@app.get('/events/search') +async def search_events( + request: Request, + query: str | None = None, + start_id: int = 0, + limit: int = 20, + event_type: str | None = None, + source: str | None = None, + start_date: str | None = None, + end_date: str | None = None, +): + """Search through the event stream with filtering and pagination. + Args: + request (Request): The incoming request object + query (str, optional): Text to search for in event content + start_id (int): Starting ID in the event stream. Defaults to 0 + limit (int): Maximum number of events to return. Must be between 1 and 100. 
Defaults to 20 + event_type (str, optional): Filter by event type (e.g., "FileReadAction") + source (str, optional): Filter by event source + start_date (str, optional): Filter events after this date (ISO format) + end_date (str, optional): Filter events before this date (ISO format) + Returns: + dict: Dictionary containing: + - events: List of matching events + - has_more: Whether there are more matching events after this batch + Raises: + HTTPException: If conversation is not found + ValueError: If limit is less than 1 or greater than 100 + """ + if not request.state.conversation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found' + ) + # Get matching events from the stream + event_stream = request.state.conversation.event_stream + matching_events = event_stream.get_matching_events( + query=query, + event_type=event_type, + source=source, + start_date=start_date, + end_date=end_date, + start_id=start_id, + limit=limit + 1, # Get one extra to check if there are more + ) + # Check if there are more events + has_more = len(matching_events) > limit + if has_more: + matching_events = matching_events[:limit] # Remove the extra event + return { + 'events': matching_events, + 'has_more': has_more, + } diff --git a/openhands/server/routes/feedback.py b/openhands/server/routes/feedback.py new file mode 100644 index 0000000000..8489ec84e6 --- /dev/null +++ b/openhands/server/routes/feedback.py @@ -0,0 +1,74 @@ +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from openhands.core.logger import openhands_logger as logger +from openhands.events.serialization import event_to_dict +from openhands.events.stream import AsyncEventStreamWrapper +from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback +from openhands.server.shared import config +from openhands.utils.async_utils import call_sync_from_async + +app = APIRouter(prefix='/api') + + +@app.post('/submit-feedback') 
+async def submit_feedback(request: Request): + """Submit user feedback. + + This function stores the provided feedback data. + + To submit feedback: + ```sh + curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:" + ``` + + Args: + request (Request): The incoming request object. + feedback (FeedbackDataModel): The feedback data to be stored. + + Returns: + dict: The stored feedback data. + + Raises: + HTTPException: If there's an error submitting the feedback. + """ + # Assuming the storage service is already configured in the backend + # and there is a function to handle the storage. + body = await request.json() + async_stream = AsyncEventStreamWrapper( + request.state.conversation.event_stream, filter_hidden=True + ) + trajectory = [] + async for event in async_stream: + trajectory.append(event_to_dict(event)) + feedback = FeedbackDataModel( + email=body.get('email', ''), + version=body.get('version', ''), + permissions=body.get('permissions', 'private'), + polarity=body.get('polarity', ''), + feedback=body.get('feedback', ''), + trajectory=trajectory, + ) + try: + feedback_data = await call_sync_from_async(store_feedback, feedback) + return JSONResponse(status_code=200, content=feedback_data) + except Exception as e: + logger.error(f'Error submitting feedback: {e}') + return JSONResponse( + status_code=500, content={'error': 'Failed to submit feedback'} + ) + + +@app.get('/defaults') +async def appconfig_defaults(): + """Retrieve the default configuration settings. + + To get the default configurations: + ```sh + curl http://localhost:3000/api/defaults + ``` + + Returns: + dict: The default configuration settings. 
+ """ + return config.defaults_dict diff --git a/openhands/server/routes/files.py b/openhands/server/routes/files.py new file mode 100644 index 0000000000..c2d37350c8 --- /dev/null +++ b/openhands/server/routes/files.py @@ -0,0 +1,341 @@ +import os +import tempfile + +from fastapi import ( + APIRouter, + BackgroundTasks, + HTTPException, + Request, + UploadFile, + status, +) +from fastapi.responses import FileResponse, JSONResponse +from pathspec import PathSpec +from pathspec.patterns import GitWildMatchPattern + +from openhands.core.logger import openhands_logger as logger +from openhands.events.action import ( + FileReadAction, + FileWriteAction, +) +from openhands.events.observation import ( + ErrorObservation, + FileReadObservation, + FileWriteObservation, +) +from openhands.runtime.base import Runtime, RuntimeUnavailableError +from openhands.server.file_config import ( + FILES_TO_IGNORE, + MAX_FILE_SIZE_MB, + is_extension_allowed, + sanitize_filename, +) +from openhands.utils.async_utils import call_sync_from_async + +app = APIRouter(prefix='/api') + + +@app.get('/list-files') +async def list_files(request: Request, path: str | None = None): + """List files in the specified path. + + This function retrieves a list of files from the agent's runtime file store, + excluding certain system and hidden files/directories. + + To list files: + ```sh + curl http://localhost:3000/api/list-files + ``` + + Args: + request (Request): The incoming request object. + path (str, optional): The path to list files from. Defaults to None. + + Returns: + list: A list of file names in the specified path. + + Raises: + HTTPException: If there's an error listing the files. 
+ """ + if not request.state.conversation.runtime: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'Runtime not yet initialized'}, + ) + + runtime: Runtime = request.state.conversation.runtime + try: + file_list = await call_sync_from_async(runtime.list_files, path) + except RuntimeUnavailableError as e: + logger.error(f'Error listing files: {e}', exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error listing files: {e}'}, + ) + if path: + file_list = [os.path.join(path, f) for f in file_list] + + file_list = [f for f in file_list if f not in FILES_TO_IGNORE] + + async def filter_for_gitignore(file_list, base_path): + gitignore_path = os.path.join(base_path, '.gitignore') + try: + read_action = FileReadAction(gitignore_path) + observation = await call_sync_from_async(runtime.run_action, read_action) + spec = PathSpec.from_lines( + GitWildMatchPattern, observation.content.splitlines() + ) + except Exception as e: + logger.warning(e) + return file_list + file_list = [entry for entry in file_list if not spec.match_file(entry)] + return file_list + + try: + file_list = await filter_for_gitignore(file_list, '') + except RuntimeUnavailableError as e: + logger.error(f'Error filtering files: {e}', exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error filtering files: {e}'}, + ) + + return file_list + + +@app.get('/select-file') +async def select_file(file: str, request: Request): + """Retrieve the content of a specified file. + + To select a file: + ```sh + curl http://localhost:3000/api/select-file?file= + ``` + + Args: + file (str): The path of the file to be retrieved. + Expect path to be absolute inside the runtime. + request (Request): The incoming request object. + + Returns: + dict: A dictionary containing the file content. + + Raises: + HTTPException: If there's an error opening the file. 
+ """ + runtime: Runtime = request.state.conversation.runtime + + file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) + read_action = FileReadAction(file) + try: + observation = await call_sync_from_async(runtime.run_action, read_action) + except RuntimeUnavailableError as e: + logger.error(f'Error opening file {file}: {e}', exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error opening file: {e}'}, + ) + + if isinstance(observation, FileReadObservation): + content = observation.content + return {'code': content} + elif isinstance(observation, ErrorObservation): + logger.error(f'Error opening file {file}: {observation}', exc_info=False) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error opening file: {observation}'}, + ) + + +@app.post('/upload-files') +async def upload_file(request: Request, files: list[UploadFile]): + """Upload a list of files to the workspace. + + To upload a files: + ```sh + curl -X POST -F "file=@" -F "file=@" http://localhost:3000/api/upload-files + ``` + + Args: + request (Request): The incoming request object. + files (list[UploadFile]): A list of files to be uploaded. + + Returns: + dict: A message indicating the success of the upload operation. + + Raises: + HTTPException: If there's an error saving the files. 
+ """ + try: + uploaded_files = [] + skipped_files = [] + for file in files: + safe_filename = sanitize_filename(file.filename) + file_contents = await file.read() + + if ( + MAX_FILE_SIZE_MB > 0 + and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024 + ): + skipped_files.append( + { + 'name': safe_filename, + 'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB', + } + ) + continue + + if not is_extension_allowed(safe_filename): + skipped_files.append( + {'name': safe_filename, 'reason': 'File type not allowed'} + ) + continue + + # copy the file to the runtime + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file_path = os.path.join(tmp_dir, safe_filename) + with open(tmp_file_path, 'wb') as tmp_file: + tmp_file.write(file_contents) + tmp_file.flush() + + runtime: Runtime = request.state.conversation.runtime + try: + await call_sync_from_async( + runtime.copy_to, + tmp_file_path, + runtime.config.workspace_mount_path_in_sandbox, + ) + except RuntimeUnavailableError as e: + logger.error( + f'Error saving file {safe_filename}: {e}', exc_info=True + ) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error saving file: {e}'}, + ) + uploaded_files.append(safe_filename) + + response_content = { + 'message': 'File upload process completed', + 'uploaded_files': uploaded_files, + 'skipped_files': skipped_files, + } + + if not uploaded_files and skipped_files: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + **response_content, + 'error': 'No files were uploaded successfully', + }, + ) + + return JSONResponse(status_code=status.HTTP_200_OK, content=response_content) + + except Exception as e: + logger.error(f'Error during file upload: {e}', exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + 'error': f'Error during file upload: {str(e)}', + 'uploaded_files': [], + 'skipped_files': [], + }, + ) + + +@app.post('/save-file') 
+async def save_file(request: Request): + """Save a file to the agent's runtime file store. + + This endpoint allows saving a file when the agent is in a paused, finished, + or awaiting user input state. It checks the agent's state before proceeding + with the file save operation. + + Args: + request (Request): The incoming FastAPI request object. + + Returns: + JSONResponse: A JSON response indicating the success of the operation. + + Raises: + HTTPException: + - 403 error if the agent is not in an allowed state for editing. + - 400 error if the file path or content is missing. + - 500 error if there's an unexpected error during the save operation. + """ + try: + # Extract file path and content from the request + data = await request.json() + file_path = data.get('filePath') + content = data.get('content') + + # Validate the presence of required data + if not file_path or content is None: + raise HTTPException(status_code=400, detail='Missing filePath or content') + + # Save the file to the agent's runtime file store + runtime: Runtime = request.state.conversation.runtime + file_path = os.path.join( + runtime.config.workspace_mount_path_in_sandbox, file_path + ) + write_action = FileWriteAction(file_path, content) + try: + observation = await call_sync_from_async(runtime.run_action, write_action) + except RuntimeUnavailableError as e: + logger.error(f'Error saving file: {e}', exc_info=True) + return JSONResponse( + status_code=500, + content={'error': f'Error saving file: {e}'}, + ) + + if isinstance(observation, FileWriteObservation): + return JSONResponse( + status_code=200, content={'message': 'File saved successfully'} + ) + elif isinstance(observation, ErrorObservation): + return JSONResponse( + status_code=500, + content={'error': f'Failed to save file: {observation}'}, + ) + else: + return JSONResponse( + status_code=500, + content={'error': f'Unexpected observation: {observation}'}, + ) + except Exception as e: + # Log the error and return a 500 response + 
logger.error(f'Error saving file: {e}', exc_info=True) + raise HTTPException(status_code=500, detail=f'Error saving file: {e}') + + +@app.get('/zip-directory') +async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks): + try: + logger.debug('Zipping workspace') + runtime: Runtime = request.state.conversation.runtime + path = runtime.config.workspace_mount_path_in_sandbox + try: + zip_file = await call_sync_from_async(runtime.copy_from, path) + except RuntimeUnavailableError as e: + logger.error(f'Error zipping workspace: {e}', exc_info=True) + return JSONResponse( + status_code=500, + content={'error': f'Error zipping workspace: {e}'}, + ) + response = FileResponse( + path=zip_file, + filename='workspace.zip', + media_type='application/x-zip-compressed', + ) + + # This will execute after the response is sent (So the file is not deleted before being sent) + background_tasks.add_task(zip_file.unlink) + + return response + except Exception as e: + logger.error(f'Error zipping workspace: {e}', exc_info=True) + raise HTTPException( + status_code=500, + detail='Failed to zip workspace', + ) diff --git a/openhands/server/routes/public.py b/openhands/server/routes/public.py new file mode 100644 index 0000000000..dae4278078 --- /dev/null +++ b/openhands/server/routes/public.py @@ -0,0 +1,106 @@ +import warnings + +import requests + +from openhands.security.options import SecurityAnalyzers + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + import litellm + +from fastapi import ( + APIRouter, +) + +from openhands.controller.agent import Agent +from openhands.core.config import LLMConfig +from openhands.core.logger import openhands_logger as logger +from openhands.llm import bedrock +from openhands.server.shared import config + +app = APIRouter(prefix='/api/options') + + +@app.get('/models') +async def get_litellm_models() -> list[str]: + """ + Get all models supported by LiteLLM. 
+ + This function combines models from litellm and Bedrock, removing any + error-prone Bedrock models. + + To get the models: + ```sh + curl http://localhost:3000/api/litellm-models + ``` + + Returns: + list: A sorted list of unique model names. + """ + litellm_model_list = litellm.model_list + list(litellm.model_cost.keys()) + litellm_model_list_without_bedrock = bedrock.remove_error_modelId( + litellm_model_list + ) + # TODO: for bedrock, this is using the default config + llm_config: LLMConfig = config.get_llm_config() + bedrock_model_list = [] + if ( + llm_config.aws_region_name + and llm_config.aws_access_key_id + and llm_config.aws_secret_access_key + ): + bedrock_model_list = bedrock.list_foundation_models( + llm_config.aws_region_name, + llm_config.aws_access_key_id, + llm_config.aws_secret_access_key, + ) + model_list = litellm_model_list_without_bedrock + bedrock_model_list + for llm_config in config.llms.values(): + ollama_base_url = llm_config.ollama_base_url + if llm_config.model.startswith('ollama'): + if not ollama_base_url: + ollama_base_url = llm_config.base_url + if ollama_base_url: + ollama_url = ollama_base_url.strip('/') + '/api/tags' + try: + ollama_models_list = requests.get(ollama_url, timeout=3).json()[ + 'models' + ] + for model in ollama_models_list: + model_list.append('ollama/' + model['name']) + break + except requests.exceptions.RequestException as e: + logger.error(f'Error getting OLLAMA models: {e}', exc_info=True) + + return list(sorted(set(model_list))) + + +@app.get('/agents') +async def get_agents(): + """Get all agents supported by LiteLLM. + + To get the agents: + ```sh + curl http://localhost:3000/api/agents + ``` + + Returns: + list: A sorted list of agent names. + """ + agents = sorted(Agent.list_agents()) + return agents + + +@app.get('/security-analyzers') +async def get_security_analyzers(): + """Get all supported security analyzers. 
+ + To get the security analyzers: + ```sh + curl http://localhost:3000/api/security-analyzers + ``` + + Returns: + list: A sorted list of security analyzer names. + """ + return sorted(SecurityAnalyzers.keys()) diff --git a/openhands/server/routes/security.py b/openhands/server/routes/security.py new file mode 100644 index 0000000000..f65fcce6aa --- /dev/null +++ b/openhands/server/routes/security.py @@ -0,0 +1,30 @@ +from fastapi import ( + APIRouter, + HTTPException, + Request, +) + +app = APIRouter(prefix='/api') + + +@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE']) +async def security_api(request: Request): + """Catch-all route for security analyzer API requests. + + Each request is handled directly to the security analyzer. + + Args: + request (Request): The incoming FastAPI request object. + + Returns: + Any: The response from the security analyzer. + + Raises: + HTTPException: If the security analyzer is not initialized. + """ + if not request.state.conversation.security_analyzer: + raise HTTPException(status_code=404, detail='Security analyzer not initialized') + + return await request.state.conversation.security_analyzer.handle_api_request( + request + ) diff --git a/openhands/server/shared.py b/openhands/server/shared.py new file mode 100644 index 0000000000..a7cc5c87c0 --- /dev/null +++ b/openhands/server/shared.py @@ -0,0 +1,28 @@ +import os + +import socketio +from dotenv import load_dotenv + +from openhands.core.config import load_app_config +from openhands.server.session import SessionManager +from openhands.storage import get_file_store + +load_dotenv() + +config = load_app_config() +file_store = get_file_store(config.file_store, config.file_store_path) + +client_manager = None +redis_host = os.environ.get('REDIS_HOST') +if redis_host: + client_manager = socketio.AsyncRedisManager( + f'redis://{redis_host}', + redis_options={'password': os.environ.get('REDIS_PASSWORD')}, + ) + + +sio = socketio.AsyncServer( + 
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager +) + +session_manager = SessionManager(sio, config, file_store) diff --git a/openhands/server/socket.py b/openhands/server/socket.py new file mode 100644 index 0000000000..9c66a1d555 --- /dev/null +++ b/openhands/server/socket.py @@ -0,0 +1,76 @@ +from fastapi import status + +from openhands.core.logger import openhands_logger as logger +from openhands.core.schema.action import ActionType +from openhands.events.action import ( + NullAction, +) +from openhands.events.observation import ( + NullObservation, +) +from openhands.events.serialization import event_to_dict +from openhands.events.stream import AsyncEventStreamWrapper +from openhands.server.auth import get_sid_from_token, sign_token +from openhands.server.github_utils import authenticate_github_user +from openhands.server.shared import config, session_manager, sio + + +@sio.event +async def connect(connection_id: str, environ): + logger.info(f'sio:connect: {connection_id}') + + +@sio.event +async def oh_action(connection_id: str, data: dict): + # If it's an init, we do it here. 
+ action = data.get('action', '') + if action == ActionType.INIT: + await init_connection(connection_id, data) + return + + logger.info(f'sio:oh_action:{connection_id}') + await session_manager.send_to_event_stream(connection_id, data) + + +async def init_connection(connection_id: str, data: dict): + gh_token = data.pop('github_token', None) + if not await authenticate_github_user(gh_token): + raise RuntimeError(status.WS_1008_POLICY_VIOLATION) + + token = data.pop('token', None) + if token: + sid = get_sid_from_token(token, config.jwt_secret) + if sid == '': + await sio.send({'error': 'Invalid token', 'error_code': 401}) + return + logger.info(f'Existing session: {sid}') + else: + sid = connection_id + logger.info(f'New session: {sid}') + + token = sign_token({'sid': sid}, config.jwt_secret) + await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id) + + latest_event_id = int(data.pop('latest_event_id', -1)) + + # The session in question should exist, but may not actually be running locally... 
+ event_stream = await session_manager.init_or_join_session(sid, connection_id, data) + + # Send events + async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1) + async for event in async_stream: + if isinstance( + event, + ( + NullAction, + NullObservation, + ), + ): + continue + await sio.emit('oh_event', event_to_dict(event), to=connection_id) + + +@sio.event +async def disconnect(connection_id: str): + logger.info(f'sio:disconnect:{connection_id}') + await session_manager.disconnect_from_session(connection_id) diff --git a/openhands/server/static.py b/openhands/server/static.py new file mode 100644 index 0000000000..ca7eb36c9b --- /dev/null +++ b/openhands/server/static.py @@ -0,0 +1,10 @@ +from fastapi.staticfiles import StaticFiles + + +class SPAStaticFiles(StaticFiles): + async def get_response(self, path: str, scope): + try: + return await super().get_response(path, scope) + except Exception: + # FIXME: just making this HTTPException doesn't work for some reason + return await super().get_response('index.html', scope) diff --git a/tests/unit/test_listen.py b/tests/unit/test_listen.py index 54a1302b26..f19be8aedb 100644 --- a/tests/unit/test_listen.py +++ b/tests/unit/test_listen.py @@ -19,7 +19,10 @@ class MockStaticFiles: with patch('openhands.server.session.SessionManager', MockSessionManager), patch( 'fastapi.staticfiles.StaticFiles', MockStaticFiles ): - from openhands.server.listen import is_extension_allowed, load_file_upload_config + from openhands.server.file_config import ( + is_extension_allowed, + load_file_upload_config, + ) def test_load_file_upload_config(): @@ -28,12 +31,11 @@ def test_load_file_upload_config(): file_uploads_restrict_file_types=True, file_uploads_allowed_extensions=['.txt', '.pdf'], ) - with patch('openhands.server.listen.config', config): - max_size, restrict_types, allowed_extensions = load_file_upload_config() + max_size, restrict_types, allowed_extensions = load_file_upload_config(config) - assert 
max_size == 10 - assert restrict_types is True - assert set(allowed_extensions) == {'.txt', '.pdf'} + assert max_size == 10 + assert restrict_types is True + assert set(allowed_extensions) == {'.txt', '.pdf'} def test_load_file_upload_config_invalid_max_size(): @@ -42,7 +44,7 @@ def test_load_file_upload_config_invalid_max_size(): file_uploads_restrict_file_types=False, file_uploads_allowed_extensions=[], ) - with patch('openhands.server.listen.config', config): - max_size, restrict_types, allowed_extensions = load_file_upload_config() + with patch('openhands.server.shared.config', config): + max_size, restrict_types, allowed_extensions = load_file_upload_config(config) assert max_size == 0 # Should default to 0 when invalid @@ -51,8 +53,8 @@ def test_is_extension_allowed(): - with patch('openhands.server.listen.RESTRICT_FILE_TYPES', True), patch( - 'openhands.server.listen.ALLOWED_EXTENSIONS', ['.txt', '.pdf'] + with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch( + 'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf'] ): assert is_extension_allowed('file.txt') assert is_extension_allowed('file.pdf') @@ -61,7 +63,7 @@ def test_is_extension_allowed_no_restrictions(): - with patch('openhands.server.listen.RESTRICT_FILE_TYPES', False): + with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', False): assert is_extension_allowed('file.txt') assert is_extension_allowed('file.pdf') assert is_extension_allowed('file.doc') @@ -69,8 +71,8 @@ def test_is_extension_allowed_wildcard(): - with patch('openhands.server.listen.RESTRICT_FILE_TYPES', True), patch( - 'openhands.server.listen.ALLOWED_EXTENSIONS', ['.*'] + with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch( + 'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*'] ): assert is_extension_allowed('file.txt') assert is_extension_allowed('file.pdf')