Refactor listen.py (#5281)

Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
This commit is contained in:
Robert Brennan
2024-11-26 12:57:24 -05:00
committed by GitHub
parent be6ca4a3ce
commit cac3b6d7f7
16 changed files with 1134 additions and 1042 deletions

56
openhands/server/app.py Normal file
View File

@@ -0,0 +1,56 @@
import warnings
with warnings.catch_warnings():
warnings.simplefilter('ignore')
from fastapi import (
FastAPI,
)
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.server.middleware import (
AttachSessionMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
NoCacheMiddleware,
RateLimitMiddleware,
)
from openhands.server.routes.auth import app as auth_api_router
from openhands.server.routes.conversation import app as conversation_api_router
from openhands.server.routes.feedback import app as feedback_api_router
from openhands.server.routes.files import app as files_api_router
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
app = FastAPI()
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
app.add_middleware(NoCacheMiddleware)
app.add_middleware(
RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1)
)
@app.get('/health')
async def health():
return 'OK'
app.include_router(auth_api_router)
app.include_router(public_api_router)
app.include_router(files_api_router)
app.include_router(conversation_api_router)
app.include_router(security_api_router)
app.include_router(feedback_api_router)
app.middleware('http')(AttachSessionMiddleware(app, target_router=files_api_router))
app.middleware('http')(
AttachSessionMiddleware(app, target_router=conversation_api_router)
)
app.middleware('http')(AttachSessionMiddleware(app, target_router=security_api_router))
app.middleware('http')(AttachSessionMiddleware(app, target_router=feedback_api_router))

View File

@@ -1,3 +0,0 @@
from openhands.server.auth.auth import get_sid_from_token, sign_token
__all__ = ['get_sid_from_token', 'sign_token']

View File

@@ -0,0 +1,111 @@
import os
import re
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import config as shared_config
FILES_TO_IGNORE = [
'.git/',
'.DS_Store',
'node_modules/',
'__pycache__/',
]
def sanitize_filename(filename):
"""Sanitize the filename to prevent directory traversal"""
# Remove any directory components
filename = os.path.basename(filename)
# Remove any non-alphanumeric characters except for .-_
filename = re.sub(r'[^\w\-_\.]', '', filename)
# Limit the filename length
max_length = 255
if len(filename) > max_length:
name, ext = os.path.splitext(filename)
filename = name[: max_length - len(ext)] + ext
return filename
def load_file_upload_config(
config: AppConfig = shared_config,
) -> tuple[int, bool, list[str]]:
"""Load file upload configuration from the config object.
This function retrieves the file upload settings from the global config object.
It handles the following settings:
- Maximum file size for uploads
- Whether to restrict file types
- List of allowed file extensions
It also performs sanity checks on the values to ensure they are valid and safe.
Returns:
tuple: A tuple containing:
- max_file_size_mb (int): Maximum file size in MB. 0 means no limit.
- restrict_file_types (bool): Whether file type restrictions are enabled.
- allowed_extensions (set): Set of allowed file extensions.
"""
# Retrieve values from config
max_file_size_mb = config.file_uploads_max_file_size_mb
restrict_file_types = config.file_uploads_restrict_file_types
allowed_extensions = config.file_uploads_allowed_extensions
# Sanity check for max_file_size_mb
if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0:
logger.warning(
f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).'
)
max_file_size_mb = 0
# Sanity check for allowed_extensions
if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions:
logger.warning(
f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].'
)
allowed_extensions = ['.*']
else:
# Ensure all extensions start with a dot and are lowercase
allowed_extensions = [
ext.lower() if ext.startswith('.') else f'.{ext.lower()}'
for ext in allowed_extensions
]
# If restrictions are disabled, allow all
if not restrict_file_types:
allowed_extensions = ['.*']
logger.debug(
f'File upload config: max_size={max_file_size_mb}MB, '
f'restrict_types={restrict_file_types}, '
f'allowed_extensions={allowed_extensions}'
)
return max_file_size_mb, restrict_file_types, allowed_extensions
# Load configuration
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
def is_extension_allowed(filename):
"""Check if the file extension is allowed based on the current configuration.
This function supports wildcards and files without extensions.
The check is case-insensitive for extensions.
Args:
filename (str): The name of the file to check.
Returns:
bool: True if the file extension is allowed, False otherwise.
"""
if not RESTRICT_FILE_TYPES:
return True
file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase
return (
'.*' in ALLOWED_EXTENSIONS
or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS)
or (file_ext == '' and '.' in ALLOWED_EXTENSIONS)
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,21 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable
from urllib.parse import urlparse
from fastapi import Request
import jwt
from fastapi import APIRouter, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_sid_from_token
from openhands.server.github_utils import UserVerifier
from openhands.server.shared import config, session_manager
class LocalhostCORSMiddleware(CORSMiddleware):
"""
@@ -100,3 +107,71 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
headers={'Retry-After': '1'},
)
return await call_next(request)
class AttachSessionMiddleware:
def __init__(self, app, target_router: APIRouter):
self.app = app
self.target_router = target_router
self.target_paths = {route.path for route in target_router.routes}
async def __call__(self, request: Request, call_next: Callable):
do_attach = False
if request.url.path in self.target_paths:
do_attach = True
if request.method == 'OPTIONS':
do_attach = False
if not do_attach:
return await call_next(request)
user_verifier = UserVerifier()
if user_verifier.is_active():
signed_token = request.cookies.get('github_auth')
if not signed_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authenticated'},
)
try:
jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
except Exception as e:
logger.warning(f'Invalid token: {e}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Invalid token'},
)
if not request.headers.get('Authorization'):
logger.warning('Missing Authorization header')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Missing Authorization header'},
)
auth_token = request.headers.get('Authorization')
if 'Bearer' in auth_token:
auth_token = auth_token.split('Bearer')[1].strip()
request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
if request.state.sid == '':
logger.warning('Invalid token')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Invalid token'},
)
request.state.conversation = await session_manager.attach_to_conversation(
request.state.sid
)
if request.state.conversation is None:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
)
try:
response = await call_next(request)
finally:
await session_manager.detach_from_conversation(request.state.conversation)
return response

View File

@@ -0,0 +1,100 @@
import time
import warnings
import requests
from openhands.server.github_utils import (
GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET,
authenticate_github_user,
)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
from fastapi import (
APIRouter,
Request,
status,
)
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import sign_token
from openhands.server.shared import config
app = APIRouter(prefix='/api')
class AuthCode(BaseModel):
code: str
@app.post('/github/callback')
def github_callback(auth_code: AuthCode):
# Prepare data for the token exchange request
data = {
'client_id': GITHUB_CLIENT_ID,
'client_secret': GITHUB_CLIENT_SECRET,
'code': auth_code.code,
}
logger.debug('Exchanging code for GitHub token')
headers = {'Accept': 'application/json'}
response = requests.post(
'https://github.com/login/oauth/access_token', data=data, headers=headers
)
if response.status_code != 200:
logger.error(f'Failed to exchange code for token: {response.text}')
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'Failed to exchange code for token'},
)
token_response = response.json()
if 'access_token' not in token_response:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'No access token in response'},
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={'access_token': token_response['access_token']},
)
@app.post('/authenticate')
async def authenticate(request: Request):
token = request.headers.get('X-GitHub-Token')
if not await authenticate_github_user(token):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authorized via GitHub waitlist'},
)
# Create a signed JWT token with 1-hour expiration
cookie_data = {
'github_token': token,
'exp': int(time.time()) + 3600, # 1 hour expiration
}
signed_token = sign_token(cookie_data, config.jwt_secret)
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
)
# Set secure cookie with signed token
response.set_cookie(
key='github_auth',
value=signed_token,
max_age=3600, # 1 hour in seconds
httponly=True,
secure=True,
samesite='strict',
)
return response

View File

@@ -0,0 +1,106 @@
from fastapi import APIRouter, HTTPException, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.base import Runtime
app = APIRouter(prefix='/api')
@app.get('/conversation')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.
Currently, this is the session ID and runtime ID (if available).
"""
runtime = request.state.conversation.runtime
runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None
session_id = runtime.sid if hasattr(runtime, 'sid') else None
return JSONResponse(
content={
'runtime_id': runtime_id,
'session_id': session_id,
}
)
@app.get('/vscode-url')
async def get_vscode_url(request: Request):
"""Get the VSCode URL.
This endpoint allows getting the VSCode URL.
Args:
request (Request): The incoming FastAPI request object.
Returns:
JSONResponse: A JSON response indicating the success of the operation.
"""
try:
runtime: Runtime = request.state.conversation.runtime
logger.debug(f'Runtime type: {type(runtime)}')
logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}')
return JSONResponse(status_code=200, content={'vscode_url': runtime.vscode_url})
except Exception as e:
logger.error(f'Error getting VSCode URL: {e}', exc_info=True)
return JSONResponse(
status_code=500,
content={
'vscode_url': None,
'error': f'Error getting VSCode URL: {e}',
},
)
@app.get('/events/search')
async def search_events(
request: Request,
query: str | None = None,
start_id: int = 0,
limit: int = 20,
event_type: str | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
):
"""Search through the event stream with filtering and pagination.
Args:
request (Request): The incoming request object
query (str, optional): Text to search for in event content
start_id (int): Starting ID in the event stream. Defaults to 0
limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
Returns:
dict: Dictionary containing:
- events: List of matching events
- has_more: Whether there are more matching events after this batch
Raises:
HTTPException: If conversation is not found
ValueError: If limit is less than 1 or greater than 100
"""
if not request.state.conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found'
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream
matching_events = event_stream.get_matching_events(
query=query,
event_type=event_type,
source=source,
start_date=start_date,
end_date=end_date,
start_id=start_id,
limit=limit + 1, # Get one extra to check if there are more
)
# Check if there are more events
has_more = len(matching_events) > limit
if has_more:
matching_events = matching_events[:limit] # Remove the extra event
return {
'events': matching_events,
'has_more': has_more,
}

View File

@@ -0,0 +1,74 @@
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.shared import config
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api')
@app.post('/submit-feedback')
async def submit_feedback(request: Request):
"""Submit user feedback.
This function stores the provided feedback data.
To submit feedback:
```sh
curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:"
```
Args:
request (Request): The incoming request object.
feedback (FeedbackDataModel): The feedback data to be stored.
Returns:
dict: The stored feedback data.
Raises:
HTTPException: If there's an error submitting the feedback.
"""
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_stream:
trajectory.append(event_to_dict(event))
feedback = FeedbackDataModel(
email=body.get('email', ''),
version=body.get('version', ''),
permissions=body.get('permissions', 'private'),
polarity=body.get('polarity', ''),
feedback=body.get('polarity', ''),
trajectory=trajectory,
)
try:
feedback_data = await call_sync_from_async(store_feedback, feedback)
return JSONResponse(status_code=200, content=feedback_data)
except Exception as e:
logger.error(f'Error submitting feedback: {e}')
return JSONResponse(
status_code=500, content={'error': 'Failed to submit feedback'}
)
@app.get('/api/defaults')
async def appconfig_defaults():
"""Retrieve the default configuration settings.
To get the default configurations:
```sh
curl http://localhost:3000/api/defaults
```
Returns:
dict: The default configuration settings.
"""
return config.defaults_dict

View File

@@ -0,0 +1,341 @@
import os
import tempfile
from fastapi import (
APIRouter,
BackgroundTasks,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.responses import FileResponse, JSONResponse
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
FileReadAction,
FileWriteAction,
)
from openhands.events.observation import (
ErrorObservation,
FileReadObservation,
FileWriteObservation,
)
from openhands.runtime.base import Runtime, RuntimeUnavailableError
from openhands.server.file_config import (
FILES_TO_IGNORE,
MAX_FILE_SIZE_MB,
is_extension_allowed,
sanitize_filename,
)
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api')
@app.get('/list-files')
async def list_files(request: Request, path: str | None = None):
"""List files in the specified path.
This function retrieves a list of files from the agent's runtime file store,
excluding certain system and hidden files/directories.
To list files:
```sh
curl http://localhost:3000/api/list-files
```
Args:
request (Request): The incoming request object.
path (str, optional): The path to list files from. Defaults to None.
Returns:
list: A list of file names in the specified path.
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.conversation.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.conversation.runtime
try:
file_list = await call_sync_from_async(runtime.list_files, path)
except RuntimeUnavailableError as e:
logger.error(f'Error listing files: {e}', exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error listing files: {e}'},
)
if path:
file_list = [os.path.join(path, f) for f in file_list]
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
async def filter_for_gitignore(file_list, base_path):
gitignore_path = os.path.join(base_path, '.gitignore')
try:
read_action = FileReadAction(gitignore_path)
observation = await call_sync_from_async(runtime.run_action, read_action)
spec = PathSpec.from_lines(
GitWildMatchPattern, observation.content.splitlines()
)
except Exception as e:
logger.warning(e)
return file_list
file_list = [entry for entry in file_list if not spec.match_file(entry)]
return file_list
try:
file_list = await filter_for_gitignore(file_list, '')
except RuntimeUnavailableError as e:
logger.error(f'Error filtering files: {e}', exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error filtering files: {e}'},
)
return file_list
@app.get('/select-file')
async def select_file(file: str, request: Request):
"""Retrieve the content of a specified file.
To select a file:
```sh
curl http://localhost:3000/api/select-file?file=<file_path>
```
Args:
file (str): The path of the file to be retrieved.
Expect path to be absolute inside the runtime.
request (Request): The incoming request object.
Returns:
dict: A dictionary containing the file content.
Raises:
HTTPException: If there's an error opening the file.
"""
runtime: Runtime = request.state.conversation.runtime
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
try:
observation = await call_sync_from_async(runtime.run_action, read_action)
except RuntimeUnavailableError as e:
logger.error(f'Error opening file {file}: {e}', exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error opening file: {e}'},
)
if isinstance(observation, FileReadObservation):
content = observation.content
return {'code': content}
elif isinstance(observation, ErrorObservation):
logger.error(f'Error opening file {file}: {observation}', exc_info=False)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error opening file: {observation}'},
)
@app.post('/upload-files')
async def upload_file(request: Request, files: list[UploadFile]):
"""Upload a list of files to the workspace.
To upload a files:
```sh
curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/upload-files
```
Args:
request (Request): The incoming request object.
files (list[UploadFile]): A list of files to be uploaded.
Returns:
dict: A message indicating the success of the upload operation.
Raises:
HTTPException: If there's an error saving the files.
"""
try:
uploaded_files = []
skipped_files = []
for file in files:
safe_filename = sanitize_filename(file.filename)
file_contents = await file.read()
if (
MAX_FILE_SIZE_MB > 0
and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
):
skipped_files.append(
{
'name': safe_filename,
'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
}
)
continue
if not is_extension_allowed(safe_filename):
skipped_files.append(
{'name': safe_filename, 'reason': 'File type not allowed'}
)
continue
# copy the file to the runtime
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_file_path = os.path.join(tmp_dir, safe_filename)
with open(tmp_file_path, 'wb') as tmp_file:
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.conversation.runtime
try:
await call_sync_from_async(
runtime.copy_to,
tmp_file_path,
runtime.config.workspace_mount_path_in_sandbox,
)
except RuntimeUnavailableError as e:
logger.error(
f'Error saving file {safe_filename}: {e}', exc_info=True
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error saving file: {e}'},
)
uploaded_files.append(safe_filename)
response_content = {
'message': 'File upload process completed',
'uploaded_files': uploaded_files,
'skipped_files': skipped_files,
}
if not uploaded_files and skipped_files:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
**response_content,
'error': 'No files were uploaded successfully',
},
)
return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)
except Exception as e:
logger.error(f'Error during file upload: {e}', exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
'error': f'Error during file upload: {str(e)}',
'uploaded_files': [],
'skipped_files': [],
},
)
@app.post('/save-file')
async def save_file(request: Request):
"""Save a file to the agent's runtime file store.
This endpoint allows saving a file when the agent is in a paused, finished,
or awaiting user input state. It checks the agent's state before proceeding
with the file save operation.
Args:
request (Request): The incoming FastAPI request object.
Returns:
JSONResponse: A JSON response indicating the success of the operation.
Raises:
HTTPException:
- 403 error if the agent is not in an allowed state for editing.
- 400 error if the file path or content is missing.
- 500 error if there's an unexpected error during the save operation.
"""
try:
# Extract file path and content from the request
data = await request.json()
file_path = data.get('filePath')
content = data.get('content')
# Validate the presence of required data
if not file_path or content is None:
raise HTTPException(status_code=400, detail='Missing filePath or content')
# Save the file to the agent's runtime file store
runtime: Runtime = request.state.conversation.runtime
file_path = os.path.join(
runtime.config.workspace_mount_path_in_sandbox, file_path
)
write_action = FileWriteAction(file_path, content)
try:
observation = await call_sync_from_async(runtime.run_action, write_action)
except RuntimeUnavailableError as e:
logger.error(f'Error saving file: {e}', exc_info=True)
return JSONResponse(
status_code=500,
content={'error': f'Error saving file: {e}'},
)
if isinstance(observation, FileWriteObservation):
return JSONResponse(
status_code=200, content={'message': 'File saved successfully'}
)
elif isinstance(observation, ErrorObservation):
return JSONResponse(
status_code=500,
content={'error': f'Failed to save file: {observation}'},
)
else:
return JSONResponse(
status_code=500,
content={'error': f'Unexpected observation: {observation}'},
)
except Exception as e:
# Log the error and return a 500 response
logger.error(f'Error saving file: {e}', exc_info=True)
raise HTTPException(status_code=500, detail=f'Error saving file: {e}')
@app.get('/zip-directory')
async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
try:
logger.debug('Zipping workspace')
runtime: Runtime = request.state.conversation.runtime
path = runtime.config.workspace_mount_path_in_sandbox
try:
zip_file = await call_sync_from_async(runtime.copy_from, path)
except RuntimeUnavailableError as e:
logger.error(f'Error zipping workspace: {e}', exc_info=True)
return JSONResponse(
status_code=500,
content={'error': f'Error zipping workspace: {e}'},
)
response = FileResponse(
path=zip_file,
filename='workspace.zip',
media_type='application/x-zip-compressed',
)
# This will execute after the response is sent (So the file is not deleted before being sent)
background_tasks.add_task(zip_file.unlink)
return response
except Exception as e:
logger.error(f'Error zipping workspace: {e}', exc_info=True)
raise HTTPException(
status_code=500,
detail='Failed to zip workspace',
)

View File

@@ -0,0 +1,106 @@
import warnings
import requests
from openhands.security.options import SecurityAnalyzers
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from fastapi import (
APIRouter,
)
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm import bedrock
from openhands.server.shared import config
app = APIRouter(prefix='/api/options')
@app.get('/models')
async def get_litellm_models() -> list[str]:
"""
Get all models supported by LiteLLM.
This function combines models from litellm and Bedrock, removing any
error-prone Bedrock models.
To get the models:
```sh
curl http://localhost:3000/api/litellm-models
```
Returns:
list: A sorted list of unique model names.
"""
litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
litellm_model_list
)
# TODO: for bedrock, this is using the default config
llm_config: LLMConfig = config.get_llm_config()
bedrock_model_list = []
if (
llm_config.aws_region_name
and llm_config.aws_access_key_id
and llm_config.aws_secret_access_key
):
bedrock_model_list = bedrock.list_foundation_models(
llm_config.aws_region_name,
llm_config.aws_access_key_id,
llm_config.aws_secret_access_key,
)
model_list = litellm_model_list_without_bedrock + bedrock_model_list
for llm_config in config.llms.values():
ollama_base_url = llm_config.ollama_base_url
if llm_config.model.startswith('ollama'):
if not ollama_base_url:
ollama_base_url = llm_config.base_url
if ollama_base_url:
ollama_url = ollama_base_url.strip('/') + '/api/tags'
try:
ollama_models_list = requests.get(ollama_url, timeout=3).json()[
'models'
]
for model in ollama_models_list:
model_list.append('ollama/' + model['name'])
break
except requests.exceptions.RequestException as e:
logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
return list(sorted(set(model_list)))
@app.get('/agents')
async def get_agents():
"""Get all agents supported by LiteLLM.
To get the agents:
```sh
curl http://localhost:3000/api/agents
```
Returns:
list: A sorted list of agent names.
"""
agents = sorted(Agent.list_agents())
return agents
@app.get('/security-analyzers')
async def get_security_analyzers():
"""Get all supported security analyzers.
To get the security analyzers:
```sh
curl http://localhost:3000/api/security-analyzers
```
Returns:
list: A sorted list of security analyzer names.
"""
return sorted(SecurityAnalyzers.keys())

View File

@@ -0,0 +1,30 @@
from fastapi import (
APIRouter,
HTTPException,
Request,
)
app = APIRouter(prefix='/api')
@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def security_api(request: Request):
"""Catch-all route for security analyzer API requests.
Each request is handled directly to the security analyzer.
Args:
request (Request): The incoming FastAPI request object.
Returns:
Any: The response from the security analyzer.
Raises:
HTTPException: If the security analyzer is not initialized.
"""
if not request.state.conversation.security_analyzer:
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
return await request.state.conversation.security_analyzer.handle_api_request(
request
)

View File

@@ -0,0 +1,28 @@
import os
import socketio
from dotenv import load_dotenv
from openhands.core.config import load_app_config
from openhands.server.session import SessionManager
from openhands.storage import get_file_store
load_dotenv()
config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
client_manager = None
redis_host = os.environ.get('REDIS_HOST')
if redis_host:
client_manager = socketio.AsyncRedisManager(
f'redis://{redis_host}',
redis_options={'password': os.environ.get('REDIS_PASSWORD')},
)
sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager
)
session_manager = SessionManager(sio, config, file_store)

View File

@@ -0,0 +1,76 @@
from fastapi import status
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.action import ActionType
from openhands.events.action import (
NullAction,
)
from openhands.events.observation import (
NullObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.auth import get_sid_from_token, sign_token
from openhands.server.github_utils import authenticate_github_user
from openhands.server.shared import config, session_manager, sio
@sio.event
async def connect(connection_id: str, environ):
logger.info(f'sio:connect: {connection_id}')
@sio.event
async def oh_action(connection_id: str, data: dict):
# If it's an init, we do it here.
action = data.get('action', '')
if action == ActionType.INIT:
await init_connection(connection_id, data)
return
logger.info(f'sio:oh_action:{connection_id}')
await session_manager.send_to_event_stream(connection_id, data)
async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('github_token', None)
if not await authenticate_github_user(gh_token):
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
token = data.pop('token', None)
if token:
sid = get_sid_from_token(token, config.jwt_secret)
if sid == '':
await sio.send({'error': 'Invalid token', 'error_code': 401})
return
logger.info(f'Existing session: {sid}')
else:
sid = connection_id
logger.info(f'New session: {sid}')
token = sign_token({'sid': sid}, config.jwt_secret)
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
latest_event_id = int(data.pop('latest_event_id', -1))
# The session in question should exist, but may not actually be running locally...
event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
# Send events
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
),
):
continue
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
@sio.event
async def disconnect(connection_id: str):
logger.info(f'sio:disconnect:{connection_id}')
await session_manager.disconnect_from_session(connection_id)

View File

@@ -0,0 +1,10 @@
from fastapi.staticfiles import StaticFiles
class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
except Exception:
# FIXME: just making this HTTPException doesn't work for some reason
return await super().get_response('index.html', scope)

View File

@@ -19,7 +19,10 @@ class MockStaticFiles:
with patch('openhands.server.session.SessionManager', MockSessionManager), patch(
'fastapi.staticfiles.StaticFiles', MockStaticFiles
):
from openhands.server.listen import is_extension_allowed, load_file_upload_config
from openhands.server.file_config import (
is_extension_allowed,
load_file_upload_config,
)
def test_load_file_upload_config():
@@ -28,12 +31,11 @@ def test_load_file_upload_config():
file_uploads_restrict_file_types=True,
file_uploads_allowed_extensions=['.txt', '.pdf'],
)
with patch('openhands.server.listen.config', config):
max_size, restrict_types, allowed_extensions = load_file_upload_config()
max_size, restrict_types, allowed_extensions = load_file_upload_config(config)
assert max_size == 10
assert restrict_types is True
assert set(allowed_extensions) == {'.txt', '.pdf'}
assert max_size == 10
assert restrict_types is True
assert set(allowed_extensions) == {'.txt', '.pdf'}
def test_load_file_upload_config_invalid_max_size():
@@ -42,7 +44,7 @@ def test_load_file_upload_config_invalid_max_size():
file_uploads_restrict_file_types=False,
file_uploads_allowed_extensions=[],
)
with patch('openhands.server.listen.config', config):
with patch('openhands.server.shared.config', config):
max_size, restrict_types, allowed_extensions = load_file_upload_config()
assert max_size == 0 # Should default to 0 when invalid
@@ -51,8 +53,8 @@ def test_load_file_upload_config_invalid_max_size():
def test_is_extension_allowed():
with patch('openhands.server.listen.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.listen.ALLOWED_EXTENSIONS', ['.txt', '.pdf']
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']
):
assert is_extension_allowed('file.txt')
assert is_extension_allowed('file.pdf')
@@ -61,7 +63,7 @@ def test_is_extension_allowed():
def test_is_extension_allowed_no_restrictions():
with patch('openhands.server.listen.RESTRICT_FILE_TYPES', False):
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', False):
assert is_extension_allowed('file.txt')
assert is_extension_allowed('file.pdf')
assert is_extension_allowed('file.doc')
@@ -69,8 +71,8 @@ def test_is_extension_allowed_no_restrictions():
def test_is_extension_allowed_wildcard():
with patch('openhands.server.listen.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.listen.ALLOWED_EXTENSIONS', ['.*']
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']
):
assert is_extension_allowed('file.txt')
assert is_extension_allowed('file.pdf')