mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Rename OpenDevin to OpenHands (#3472)
* Replace OpenDevin with OpenHands * Update CONTRIBUTING.md * Update README.md * Update README.md * update poetry lock; move opendevin folder to openhands * fix env var * revert image references in docs * revert permissions * revert permissions --------- Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
102
openhands/server/README.md
Normal file
102
openhands/server/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# OpenHands Server
|
||||
|
||||
This is a WebSocket server that executes tasks using an agent.
|
||||
|
||||
## Recommended Prerequisites
|
||||
|
||||
- [Initialize the frontend code](../../frontend/README.md)
|
||||
- Install Python 3.12 (`brew install python` for those using homebrew)
|
||||
- Install pipx: (`brew install pipx` followed by `pipx ensurepath`)
|
||||
- Install poetry: (`pipx install poetry`)
|
||||
|
||||
## Install
|
||||
|
||||
First build a distribution of the frontend code (From the project root directory):
|
||||
```
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
cd ..
|
||||
```
|
||||
Next run `poetry shell` (So you don't have to repeat `poetry run`)
|
||||
|
||||
## Start the Server
|
||||
|
||||
```sh
|
||||
uvicorn openhands.server.listen:app --reload --port 3000
|
||||
```
|
||||
|
||||
## Test the Server
|
||||
|
||||
You can use [`websocat`](https://github.com/vi/websocat) to test the server.
|
||||
|
||||
```sh
|
||||
websocat ws://127.0.0.1:3000/ws
|
||||
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
|
||||
```
|
||||
|
||||
## Supported Environment Variables
|
||||
|
||||
```sh
|
||||
LLM_API_KEY=sk-... # Your OpenAI API Key
|
||||
LLM_MODEL=gpt-4o # Default model for the agent to use
|
||||
WORKSPACE_BASE=/path/to/your/workspace # Default absolute path to workspace
|
||||
```
|
||||
|
||||
## API Schema
|
||||
|
||||
There are two types of messages that can be sent to, or received from, the server:
|
||||
|
||||
* Actions
|
||||
* Observations
|
||||
|
||||
### Actions
|
||||
|
||||
An action has three parts:
|
||||
|
||||
* `action`: The action to be taken
|
||||
* `args`: The arguments for the action
|
||||
* `message`: A friendly message that can be put in the chat log
|
||||
|
||||
There are several kinds of actions. Their arguments are listed below.
|
||||
This list may grow over time.
|
||||
|
||||
* `initialize` - initializes the agent. Only sent by client.
|
||||
* `model` - the name of the model to use
|
||||
* `directory` - the path to the workspace
|
||||
* `agent_cls` - the class of the agent to use
|
||||
* `start` - starts a new development task. Only sent by the client.
|
||||
* `task` - the task to start
|
||||
* `read` - reads the content of a file.
|
||||
* `path` - the path of the file to read
|
||||
* `write` - writes the content to a file.
|
||||
* `path` - the path of the file to write
|
||||
* `content` - the content to write to the file
|
||||
* `run` - runs a command.
|
||||
* `command` - the command to run
|
||||
* `browse` - opens a web page.
|
||||
* `url` - the URL to open
|
||||
* `think` - Allows the agent to make a plan, set a goal, or record thoughts
|
||||
* `thought` - the thought to record
|
||||
* `finish` - agent signals that the task is completed
|
||||
|
||||
### Observations
|
||||
|
||||
An observation has four parts:
|
||||
|
||||
* `observation`: The observation type
|
||||
* `content`: A string representing the observed data
|
||||
* `extras`: additional structured data
|
||||
* `message`: A friendly message that can be put in the chat log
|
||||
|
||||
There are several kinds of observations. Their extras are listed below.
|
||||
This list may grow over time.
|
||||
|
||||
* `read` - the content of a file
|
||||
* `path` - the path of the file read
|
||||
* `browse` - the HTML content of a url
|
||||
* `url` - the URL opened
|
||||
* `run` - the output of a command
|
||||
* `command` - the command run
|
||||
* `exit_code` - the exit code of the command
|
||||
* `chat` - a message from the user
|
||||
0
openhands/server/__init__.py
Normal file
0
openhands/server/__init__.py
Normal file
3
openhands/server/auth/__init__.py
Normal file
3
openhands/server/auth/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .auth import get_sid_from_token, sign_token
|
||||
|
||||
__all__ = ['get_sid_from_token', 'sign_token']
|
||||
39
openhands/server/auth/auth.py
Normal file
39
openhands/server/auth/auth.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def get_sid_from_token(token: str, jwt_secret: str) -> str:
|
||||
"""Retrieves the session id from a JWT token.
|
||||
|
||||
Parameters:
|
||||
token (str): The JWT token from which the session id is to be extracted.
|
||||
|
||||
Returns:
|
||||
str: The session id if found and valid, otherwise an empty string.
|
||||
"""
|
||||
try:
|
||||
# Decode the JWT using the specified secret and algorithm
|
||||
payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
|
||||
|
||||
# Ensure the payload contains 'sid'
|
||||
if 'sid' in payload:
|
||||
return payload['sid']
|
||||
else:
|
||||
logger.error('SID not found in token')
|
||||
return ''
|
||||
except InvalidTokenError:
|
||||
logger.error('Invalid token')
|
||||
except Exception as e:
|
||||
logger.exception('Unexpected error decoding token: %s', e)
|
||||
return ''
|
||||
|
||||
|
||||
def sign_token(payload: dict[str, object], jwt_secret: str) -> str:
|
||||
"""Signs a JWT token."""
|
||||
# payload = {
|
||||
# "sid": sid,
|
||||
# # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
|
||||
# }
|
||||
return jwt.encode(payload, jwt_secret, algorithm='HS256')
|
||||
42
openhands/server/data_models/feedback.py
Normal file
42
openhands/server/data_models/feedback.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class FeedbackDataModel(BaseModel):
|
||||
version: str
|
||||
email: str
|
||||
token: str
|
||||
feedback: Literal['positive', 'negative']
|
||||
permissions: Literal['public', 'private']
|
||||
trajectory: list[dict[str, Any]]
|
||||
|
||||
|
||||
FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory'
|
||||
|
||||
|
||||
def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
|
||||
# Start logging
|
||||
display_feedback = feedback.model_dump()
|
||||
if 'trajectory' in display_feedback:
|
||||
display_feedback['trajectory'] = (
|
||||
f"elided [length: {len(display_feedback['trajectory'])}"
|
||||
)
|
||||
if 'token' in display_feedback:
|
||||
display_feedback['token'] = 'elided'
|
||||
logger.info(f'Got feedback: {display_feedback}')
|
||||
# Start actual request
|
||||
response = requests.post(
|
||||
FEEDBACK_URL,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
json=feedback.model_dump(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Failed to store feedback: {response.text}')
|
||||
response_data = json.loads(response.text)
|
||||
logger.info(f'Stored feedback: {response.text}')
|
||||
return response_data
|
||||
736
openhands/server/listen.py
Normal file
736
openhands/server/listen.py
Normal file
@@ -0,0 +1,736 @@
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
import requests
|
||||
|
||||
from openhands.security.options import SecurityAnalyzers
|
||||
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
status,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import HTTPBearer
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, load_app_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState # Add this import
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import bedrock
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.server.auth import get_sid_from_token, sign_token
|
||||
from openhands.server.session import SessionManager
|
||||
|
||||
config = load_app_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
session_manager = SessionManager(config, file_store)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=['http://localhost:3001'],
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
security_scheme = HTTPBearer()
|
||||
|
||||
|
||||
def load_file_upload_config() -> tuple[int, bool, list[str]]:
|
||||
"""Load file upload configuration from the config object.
|
||||
|
||||
This function retrieves the file upload settings from the global config object.
|
||||
It handles the following settings:
|
||||
- Maximum file size for uploads
|
||||
- Whether to restrict file types
|
||||
- List of allowed file extensions
|
||||
|
||||
It also performs sanity checks on the values to ensure they are valid and safe.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- max_file_size_mb (int): Maximum file size in MB. 0 means no limit.
|
||||
- restrict_file_types (bool): Whether file type restrictions are enabled.
|
||||
- allowed_extensions (set): Set of allowed file extensions.
|
||||
"""
|
||||
# Retrieve values from config
|
||||
max_file_size_mb = config.file_uploads_max_file_size_mb
|
||||
restrict_file_types = config.file_uploads_restrict_file_types
|
||||
allowed_extensions = config.file_uploads_allowed_extensions
|
||||
|
||||
# Sanity check for max_file_size_mb
|
||||
if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0:
|
||||
logger.warning(
|
||||
f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).'
|
||||
)
|
||||
max_file_size_mb = 0
|
||||
|
||||
# Sanity check for allowed_extensions
|
||||
if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions:
|
||||
logger.warning(
|
||||
f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].'
|
||||
)
|
||||
allowed_extensions = ['.*']
|
||||
else:
|
||||
# Ensure all extensions start with a dot and are lowercase
|
||||
allowed_extensions = [
|
||||
ext.lower() if ext.startswith('.') else f'.{ext.lower()}'
|
||||
for ext in allowed_extensions
|
||||
]
|
||||
|
||||
# If restrictions are disabled, allow all
|
||||
if not restrict_file_types:
|
||||
allowed_extensions = ['.*']
|
||||
|
||||
logger.debug(
|
||||
f'File upload config: max_size={max_file_size_mb}MB, '
|
||||
f'restrict_types={restrict_file_types}, '
|
||||
f'allowed_extensions={allowed_extensions}'
|
||||
)
|
||||
|
||||
return max_file_size_mb, restrict_file_types, allowed_extensions
|
||||
|
||||
|
||||
# Load configuration
|
||||
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
|
||||
|
||||
|
||||
def is_extension_allowed(filename):
|
||||
"""Check if the file extension is allowed based on the current configuration.
|
||||
|
||||
This function supports wildcards and files without extensions.
|
||||
The check is case-insensitive for extensions.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the file extension is allowed, False otherwise.
|
||||
"""
|
||||
if not RESTRICT_FILE_TYPES:
|
||||
return True
|
||||
|
||||
file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase
|
||||
return (
|
||||
'.*' in ALLOWED_EXTENSIONS
|
||||
or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS)
|
||||
or (file_ext == '' and '.' in ALLOWED_EXTENSIONS)
|
||||
)
|
||||
|
||||
|
||||
@app.middleware('http')
|
||||
async def attach_session(request: Request, call_next):
|
||||
"""Middleware to attach session information to the request.
|
||||
|
||||
This middleware checks for the Authorization header, validates the token,
|
||||
and attaches the corresponding session to the request state.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
call_next (Callable): The next middleware or route handler in the chain.
|
||||
|
||||
Returns:
|
||||
Response: The response from the next middleware or route handler.
|
||||
"""
|
||||
if request.url.path.startswith('/api/options/') or not request.url.path.startswith(
|
||||
'/api/'
|
||||
):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
if not request.headers.get('Authorization'):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Missing Authorization header'},
|
||||
)
|
||||
|
||||
auth_token = request.headers.get('Authorization')
|
||||
if 'Bearer' in auth_token:
|
||||
auth_token = auth_token.split('Bearer')[1].strip()
|
||||
|
||||
request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
|
||||
if request.state.sid == '':
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Invalid token'},
|
||||
)
|
||||
|
||||
request.state.session = session_manager.get_session(request.state.sid)
|
||||
if request.state.session is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Session not found'},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
|
||||
Once connected, the client can send various actions:
|
||||
- Initialize the agent:
|
||||
session management, and event streaming.
|
||||
```json
|
||||
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
|
||||
|
||||
Args:
|
||||
```
|
||||
websocket (WebSocket): The WebSocket connection object.
|
||||
- Start a new development task:
|
||||
```json
|
||||
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
|
||||
```
|
||||
- Send a message:
|
||||
```json
|
||||
{"action": "message", "args": {"content": "Hello, how are you?", "images_urls": ["base64_url1", "base64_url2"]}}
|
||||
```
|
||||
- Write contents to a file:
|
||||
```json
|
||||
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
|
||||
```
|
||||
- Read the contents of a file:
|
||||
```json
|
||||
{"action": "read", "args": {"path": "./greetings.txt"}}
|
||||
```
|
||||
- Run a command:
|
||||
```json
|
||||
{"action": "run", "args": {"command": "ls -l", "thought": "", "is_confirmed": "confirmed"}}
|
||||
```
|
||||
- Run an IPython command:
|
||||
```json
|
||||
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
|
||||
```
|
||||
- Open a web page:
|
||||
```json
|
||||
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
|
||||
```
|
||||
- Add a task to the root_task:
|
||||
```json
|
||||
{"action": "add_task", "args": {"task": "Implement feature X"}}
|
||||
```
|
||||
- Update a task in the root_task:
|
||||
```json
|
||||
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
|
||||
```
|
||||
- Change the agent's state:
|
||||
```json
|
||||
{"action": "change_agent_state", "args": {"state": "paused"}}
|
||||
```
|
||||
- Finish the task:
|
||||
```json
|
||||
{"action": "finish", "args": {}}
|
||||
```
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if websocket.query_params.get('token'):
|
||||
token = websocket.query_params.get('token')
|
||||
sid = get_sid_from_token(token, config.jwt_secret)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
await websocket.close()
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
await websocket.send_json({'token': token, 'status': 'ok'})
|
||||
|
||||
latest_event_id = -1
|
||||
if websocket.query_params.get('latest_event_id'):
|
||||
latest_event_id = int(websocket.query_params.get('latest_event_id'))
|
||||
for event in session.agent_session.event_stream.get_events(
|
||||
start_id=latest_event_id + 1
|
||||
):
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
):
|
||||
continue
|
||||
await websocket.send_json(event_to_dict(event))
|
||||
|
||||
await session.loop_recv()
|
||||
|
||||
|
||||
@app.get('/api/options/models')
|
||||
async def get_litellm_models() -> list[str]:
|
||||
"""
|
||||
Get all models supported by LiteLLM.
|
||||
|
||||
This function combines models from litellm and Bedrock, removing any
|
||||
error-prone Bedrock models.
|
||||
|
||||
To get the models:
|
||||
```sh
|
||||
curl http://localhost:3000/api/litellm-models
|
||||
```
|
||||
|
||||
Returns:
|
||||
list: A sorted list of unique model names.
|
||||
"""
|
||||
litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
|
||||
litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
|
||||
litellm_model_list
|
||||
)
|
||||
# TODO: for bedrock, this is using the default config
|
||||
llm_config: LLMConfig = config.get_llm_config()
|
||||
bedrock_model_list = []
|
||||
if (
|
||||
llm_config.aws_region_name
|
||||
and llm_config.aws_access_key_id
|
||||
and llm_config.aws_secret_access_key
|
||||
):
|
||||
bedrock_model_list = bedrock.list_foundation_models(
|
||||
llm_config.aws_region_name,
|
||||
llm_config.aws_access_key_id,
|
||||
llm_config.aws_secret_access_key,
|
||||
)
|
||||
model_list = litellm_model_list_without_bedrock + bedrock_model_list
|
||||
for llm_config in config.llms.values():
|
||||
ollama_base_url = llm_config.ollama_base_url
|
||||
if llm_config.model.startswith('ollama'):
|
||||
if not ollama_base_url:
|
||||
ollama_base_url = llm_config.base_url
|
||||
if ollama_base_url:
|
||||
ollama_url = ollama_base_url.strip('/') + '/api/tags'
|
||||
try:
|
||||
ollama_models_list = requests.get(ollama_url, timeout=3).json()[
|
||||
'models'
|
||||
]
|
||||
for model in ollama_models_list:
|
||||
model_list.append('ollama/' + model['name'])
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
|
||||
|
||||
return list(sorted(set(model_list)))
|
||||
|
||||
|
||||
@app.get('/api/options/agents')
|
||||
async def get_agents():
|
||||
"""Get all agents supported by LiteLLM.
|
||||
|
||||
To get the agents:
|
||||
```sh
|
||||
curl http://localhost:3000/api/agents
|
||||
```
|
||||
|
||||
Returns:
|
||||
list: A sorted list of agent names.
|
||||
"""
|
||||
agents = sorted(Agent.list_agents())
|
||||
return agents
|
||||
|
||||
|
||||
@app.get('/api/options/security-analyzers')
|
||||
async def get_security_analyzers():
|
||||
"""Get all supported security analyzers.
|
||||
|
||||
To get the security analyzers:
|
||||
```sh
|
||||
curl http://localhost:3000/api/security-analyzers
|
||||
```
|
||||
|
||||
Returns:
|
||||
list: A sorted list of security analyzer names.
|
||||
"""
|
||||
return sorted(SecurityAnalyzers.keys())
|
||||
|
||||
|
||||
@app.get('/api/list-files')
|
||||
async def list_files(request: Request, path: str | None = None):
|
||||
"""List files in the specified path.
|
||||
|
||||
This function retrieves a list of files from the agent's runtime file store,
|
||||
excluding certain system and hidden files/directories.
|
||||
|
||||
To list files:
|
||||
```sh
|
||||
curl http://localhost:3000/api/list-files
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
path (str, optional): The path to list files from. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: A list of file names in the specified path.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error listing the files.
|
||||
"""
|
||||
if not request.state.session.agent_session.runtime:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
file_list = await runtime.list_files(path)
|
||||
return file_list
|
||||
|
||||
|
||||
@app.get('/api/select-file')
|
||||
async def select_file(file: str, request: Request):
|
||||
"""Retrieve the content of a specified file.
|
||||
|
||||
To select a file:
|
||||
```sh
|
||||
curl http://localhost:3000/api/select-file?file=<file_path>
|
||||
```
|
||||
|
||||
Args:
|
||||
file (str): The path of the file to be retrieved.
|
||||
Expect path to be absolute inside the runtime.
|
||||
request (Request): The incoming request object.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the file content.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error opening the file.
|
||||
"""
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
|
||||
# convert file to an absolute path inside the runtime
|
||||
if not os.path.isabs(file):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'File path must be absolute'},
|
||||
)
|
||||
|
||||
read_action = FileReadAction(file)
|
||||
observation = await runtime.run_action(read_action)
|
||||
|
||||
if isinstance(observation, FileReadObservation):
|
||||
content = observation.content
|
||||
return {'code': content}
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
logger.error(f'Error opening file {file}: {observation}', exc_info=False)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': f'Error opening file: {observation}'},
|
||||
)
|
||||
|
||||
|
||||
def sanitize_filename(filename):
|
||||
"""Sanitize the filename to prevent directory traversal"""
|
||||
# Remove any directory components
|
||||
filename = os.path.basename(filename)
|
||||
# Remove any non-alphanumeric characters except for .-_
|
||||
filename = re.sub(r'[^\w\-_\.]', '', filename)
|
||||
# Limit the filename length
|
||||
max_length = 255
|
||||
if len(filename) > max_length:
|
||||
name, ext = os.path.splitext(filename)
|
||||
filename = name[: max_length - len(ext)] + ext
|
||||
return filename
|
||||
|
||||
|
||||
@app.post('/api/upload-files')
|
||||
async def upload_file(request: Request, files: list[UploadFile]):
|
||||
"""Upload a list of files to the workspace.
|
||||
|
||||
To upload a files:
|
||||
```sh
|
||||
curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/upload-files
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
files (list[UploadFile]): A list of files to be uploaded.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the success of the upload operation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error saving the files.
|
||||
"""
|
||||
try:
|
||||
uploaded_files = []
|
||||
skipped_files = []
|
||||
for file in files:
|
||||
safe_filename = sanitize_filename(file.filename)
|
||||
file_contents = await file.read()
|
||||
|
||||
if (
|
||||
MAX_FILE_SIZE_MB > 0
|
||||
and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
|
||||
):
|
||||
skipped_files.append(
|
||||
{
|
||||
'name': safe_filename,
|
||||
'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if not is_extension_allowed(safe_filename):
|
||||
skipped_files.append(
|
||||
{'name': safe_filename, 'reason': 'File type not allowed'}
|
||||
)
|
||||
continue
|
||||
|
||||
# copy the file to the runtime
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_file_path = os.path.join(tmp_dir, safe_filename)
|
||||
with open(tmp_file_path, 'wb') as tmp_file:
|
||||
tmp_file.write(file_contents)
|
||||
tmp_file.flush()
|
||||
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
await runtime.copy_to(
|
||||
tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
|
||||
)
|
||||
uploaded_files.append(safe_filename)
|
||||
|
||||
response_content = {
|
||||
'message': 'File upload process completed',
|
||||
'uploaded_files': uploaded_files,
|
||||
'skipped_files': skipped_files,
|
||||
}
|
||||
|
||||
if not uploaded_files and skipped_files:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
**response_content,
|
||||
'error': 'No files were uploaded successfully',
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error during file upload: {e}', exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
'error': f'Error during file upload: {str(e)}',
|
||||
'uploaded_files': [],
|
||||
'skipped_files': [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.post('/api/submit-feedback')
|
||||
async def submit_feedback(request: Request, feedback: FeedbackDataModel):
|
||||
"""Submit user feedback.
|
||||
|
||||
This function stores the provided feedback data.
|
||||
|
||||
To submit feedback:
|
||||
```sh
|
||||
curl -X POST -F "email=test@example.com" -F "token=abc" -F "feedback=positive" -F "permissions=private" -F "trajectory={}" http://localhost:3000/api/submit-feedback
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
feedback (FeedbackDataModel): The feedback data to be stored.
|
||||
|
||||
Returns:
|
||||
dict: The stored feedback data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error submitting the feedback.
|
||||
"""
|
||||
# Assuming the storage service is already configured in the backend
|
||||
# and there is a function to handle the storage.
|
||||
try:
|
||||
feedback_data = store_feedback(feedback)
|
||||
return JSONResponse(status_code=200, content=feedback_data)
|
||||
except Exception as e:
|
||||
logger.error(f'Error submitting feedback: {e}')
|
||||
return JSONResponse(
|
||||
status_code=500, content={'error': 'Failed to submit feedback'}
|
||||
)
|
||||
|
||||
|
||||
@app.get('/api/root_task')
|
||||
def get_root_task(request: Request):
|
||||
"""Retrieve the root task of the current agent session.
|
||||
|
||||
To get the root_task:
|
||||
```sh
|
||||
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
|
||||
Returns:
|
||||
dict: The root task data if available.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the root task is not available.
|
||||
"""
|
||||
controller = request.state.session.agent_session.controller
|
||||
if controller is not None:
|
||||
state = controller.get_state()
|
||||
if state:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=state.root_task.to_dict(),
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@app.get('/api/defaults')
|
||||
async def appconfig_defaults():
|
||||
"""Retrieve the default configuration settings.
|
||||
|
||||
To get the default configurations:
|
||||
```sh
|
||||
curl http://localhost:3000/api/defaults
|
||||
```
|
||||
|
||||
Returns:
|
||||
dict: The default configuration settings.
|
||||
"""
|
||||
return config.defaults_dict
|
||||
|
||||
|
||||
@app.post('/api/save-file')
|
||||
async def save_file(request: Request):
|
||||
"""Save a file to the agent's runtime file store.
|
||||
|
||||
This endpoint allows saving a file when the agent is in a paused, finished,
|
||||
or awaiting user input state. It checks the agent's state before proceeding
|
||||
with the file save operation.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming FastAPI request object.
|
||||
|
||||
Returns:
|
||||
JSONResponse: A JSON response indicating the success of the operation.
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 403 error if the agent is not in an allowed state for editing.
|
||||
- 400 error if the file path or content is missing.
|
||||
- 500 error if there's an unexpected error during the save operation.
|
||||
"""
|
||||
try:
|
||||
# Get the agent's current state
|
||||
controller = request.state.session.agent_session.controller
|
||||
agent_state = controller.get_agent_state()
|
||||
|
||||
# Check if the agent is in an allowed state for editing
|
||||
if agent_state not in [
|
||||
AgentState.INIT,
|
||||
AgentState.PAUSED,
|
||||
AgentState.FINISHED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input',
|
||||
)
|
||||
|
||||
# Extract file path and content from the request
|
||||
data = await request.json()
|
||||
file_path = data.get('filePath')
|
||||
content = data.get('content')
|
||||
|
||||
# Validate the presence of required data
|
||||
if not file_path or content is None:
|
||||
raise HTTPException(status_code=400, detail='Missing filePath or content')
|
||||
|
||||
# Make sure file_path is abs
|
||||
if not os.path.isabs(file_path):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'File path must be absolute'},
|
||||
)
|
||||
|
||||
# Save the file to the agent's runtime file store
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
write_action = FileWriteAction(file_path, content)
|
||||
observation = await runtime.run_action(write_action)
|
||||
|
||||
if isinstance(observation, FileWriteObservation):
|
||||
return JSONResponse(
|
||||
status_code=200, content={'message': 'File saved successfully'}
|
||||
)
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={'error': f'Failed to save file: {observation}'},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={'error': f'Unexpected observation: {observation}'},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log the error and return a 500 response
|
||||
logger.error(f'Error saving file: {e}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f'Error saving file: {e}')
|
||||
|
||||
|
||||
@app.route('/api/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
|
||||
async def security_api(request: Request):
|
||||
"""Catch-all route for security analyzer API requests.
|
||||
|
||||
Each request is handled directly to the security analyzer.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming FastAPI request object.
|
||||
|
||||
Returns:
|
||||
Any: The response from the security analyzer.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the security analyzer is not initialized.
|
||||
"""
|
||||
if not request.state.session.agent_session.security_analyzer:
|
||||
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
|
||||
|
||||
return (
|
||||
await request.state.session.agent_session.security_analyzer.handle_api_request(
|
||||
request
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
app.mount('/', StaticFiles(directory='./frontend/dist', html=True), name='dist')
|
||||
10
openhands/server/mock/README.md
Normal file
10
openhands/server/mock/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# OpenHands mock server
|
||||
This is a simple mock server to facilitate development in the frontend.
|
||||
|
||||
## Start the Server
|
||||
Follow the instructions in the README to install dependencies. Then run:
|
||||
```
|
||||
python listen.py
|
||||
```
|
||||
|
||||
Then open the frontend to connect to the mock server. It will simply reply to every received message.
|
||||
60
openhands/server/mock/listen.py
Normal file
60
openhands/server/mock/listen.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, WebSocket
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
# send message to mock connection
|
||||
await websocket.send_json(
|
||||
{'action': ActionType.INIT, 'message': 'Control loop started.'}
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# receive message
|
||||
data = await websocket.receive_json()
|
||||
print(f'Received message: {data}')
|
||||
|
||||
# send mock response to client
|
||||
response = {'message': f'receive {data}'}
|
||||
await websocket.send_json(response)
|
||||
print(f'Sent message: {response}')
|
||||
except Exception as e:
|
||||
print(f'WebSocket Error: {e}')
|
||||
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return {'message': 'This is a mock server'}
|
||||
|
||||
|
||||
@app.get('/api/options/models')
|
||||
def read_llm_models():
|
||||
return [
|
||||
'gpt-4',
|
||||
'gpt-4-turbo-preview',
|
||||
'gpt-4-0314',
|
||||
'gpt-4-0613',
|
||||
]
|
||||
|
||||
|
||||
@app.get('/api/options/agents')
|
||||
def read_llm_agents():
|
||||
return [
|
||||
'CodeActAgent',
|
||||
'PlannerAgent',
|
||||
]
|
||||
|
||||
|
||||
@app.get('/api/list-files')
|
||||
def refresh_files():
|
||||
return ['hello_world.py']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host='127.0.0.1', port=3000)
|
||||
4
openhands/server/session/__init__.py
Normal file
4
openhands/server/session/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .manager import SessionManager
|
||||
from .session import Session
|
||||
|
||||
__all__ = ['Session', 'SessionManager']
|
||||
138
openhands/server/session/agent.py
Normal file
138
openhands/server/session/agent.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""Represents a session with an agent.
|
||||
|
||||
Attributes:
|
||||
controller: The AgentController instance for controlling the agent.
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
file_store: FileStore
|
||||
controller: AgentController | None = None
|
||||
runtime: Runtime | None = None
|
||||
security_analyzer: SecurityAnalyzer | None = None
|
||||
_closed: bool = False
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
"""Initializes a new instance of the Session class."""
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.file_store = file_store
|
||||
|
||||
async def start(
|
||||
self,
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
):
|
||||
"""Starts the agent session.
|
||||
|
||||
Args:
|
||||
start_event: The start event data (optional).
|
||||
"""
|
||||
if self.controller or self.runtime:
|
||||
raise Exception(
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
await self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(runtime_name, config, agent)
|
||||
await self._create_controller(
|
||||
agent,
|
||||
config.security.confirmation_mode,
|
||||
max_iterations,
|
||||
max_budget_per_task=max_budget_per_task,
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self._closed:
|
||||
return
|
||||
if self.controller is not None:
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid, self.file_store)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
await self.runtime.close()
|
||||
if self.security_analyzer is not None:
|
||||
await self.security_analyzer.close()
|
||||
self._closed = True
|
||||
|
||||
async def _create_security_analyzer(self, security_analyzer: str | None):
|
||||
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions."""
|
||||
logger.info(f'Using security analyzer: {security_analyzer}')
|
||||
if security_analyzer:
|
||||
self.security_analyzer = options.SecurityAnalyzers.get(
|
||||
security_analyzer, SecurityAnalyzer
|
||||
)(self.event_stream)
|
||||
|
||||
async def _create_runtime(self, runtime_name: str, config: AppConfig, agent: Agent):
|
||||
"""Creates a runtime instance."""
|
||||
if self.runtime is not None:
|
||||
raise Exception('Runtime already created')
|
||||
|
||||
logger.info(f'Using runtime: {runtime_name}')
|
||||
runtime_cls = get_runtime_cls(runtime_name)
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
)
|
||||
await self.runtime.ainit()
|
||||
|
||||
async def _create_controller(
|
||||
self,
|
||||
agent: Agent,
|
||||
confirmation_mode: bool,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
):
|
||||
"""Creates an AgentController instance."""
|
||||
if self.controller is not None:
|
||||
raise Exception('Controller already created')
|
||||
if self.runtime is None:
|
||||
raise Exception('Runtime must be initialized before the agent controller')
|
||||
|
||||
logger.info(f'Agents: {agent_configs}')
|
||||
logger.info(f'Creating agent {agent.name} using LLM {agent.llm.config.model}')
|
||||
|
||||
self.controller = AgentController(
|
||||
sid=self.sid,
|
||||
event_stream=self.event_stream,
|
||||
agent=agent,
|
||||
max_iterations=int(max_iterations),
|
||||
max_budget_per_task=max_budget_per_task,
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
confirmation_mode=confirmation_mode,
|
||||
# AgentSession is designed to communicate with the frontend, so we don't want to
|
||||
# run the agent in headless mode.
|
||||
headless_mode=False,
|
||||
)
|
||||
try:
|
||||
agent_state = State.restore_from_session(self.sid, self.file_store)
|
||||
self.controller.set_initial_state(
|
||||
agent_state, max_iterations, confirmation_mode
|
||||
)
|
||||
logger.info(f'Restored agent state from session, sid: {self.sid}')
|
||||
except Exception as e:
|
||||
logger.info(f'Error restoring state: {e}')
|
||||
70
openhands/server/session/manager.py
Normal file
70
openhands/server/session/manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
from .session import Session
|
||||
|
||||
|
||||
class SessionManager:
|
||||
_sessions: dict[str, Session] = {}
|
||||
cleanup_interval: int = 300
|
||||
session_timeout: int = 600
|
||||
|
||||
def __init__(self, config: AppConfig, file_store: FileStore):
|
||||
asyncio.create_task(self._cleanup_sessions())
|
||||
self.config = config
|
||||
self.file_store = file_store
|
||||
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
if sid in self._sessions:
|
||||
asyncio.create_task(self._sessions[sid].close())
|
||||
self._sessions[sid] = Session(
|
||||
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
|
||||
)
|
||||
return self._sessions[sid]
|
||||
|
||||
def get_session(self, sid: str) -> Session | None:
|
||||
if sid not in self._sessions:
|
||||
return None
|
||||
return self._sessions.get(sid)
|
||||
|
||||
async def send(self, sid: str, data: dict[str, object]) -> bool:
|
||||
"""Sends data to the client."""
|
||||
if sid not in self._sessions:
|
||||
return False
|
||||
return await self._sessions[sid].send(data)
|
||||
|
||||
async def send_error(self, sid: str, message: str) -> bool:
|
||||
"""Sends an error message to the client."""
|
||||
return await self.send(sid, {'error': True, 'message': message})
|
||||
|
||||
async def send_message(self, sid: str, message: str) -> bool:
|
||||
"""Sends a message to the client."""
|
||||
return await self.send(sid, {'message': message})
|
||||
|
||||
async def _cleanup_sessions(self):
|
||||
while True:
|
||||
current_time = time.time()
|
||||
session_ids_to_remove = []
|
||||
for sid, session in list(self._sessions.items()):
|
||||
# if session inactive for a long time, remove it
|
||||
if (
|
||||
not session.is_alive
|
||||
and current_time - session.last_active_ts > self.session_timeout
|
||||
):
|
||||
session_ids_to_remove.append(sid)
|
||||
|
||||
for sid in session_ids_to_remove:
|
||||
to_del_session: Session | None = self._sessions.pop(sid, None)
|
||||
if to_del_session is not None:
|
||||
await to_del_session.close()
|
||||
logger.info(
|
||||
f'Session {sid} and related resource have been removed due to inactivity.'
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
193
openhands/server/session/session.py
Normal file
193
openhands/server/session/session.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.core.schema.action import ActionType
|
||||
from openhands.core.schema.config import ConfigType
|
||||
from openhands.events.action import ChangeAgentStateAction, MessageAction, NullAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
from .agent import AgentSession
|
||||
|
||||
DEL_DELT_SEC = 60 * 60 * 5
|
||||
|
||||
|
||||
class Session:
|
||||
sid: str
|
||||
websocket: WebSocket | None
|
||||
last_active_ts: int = 0
|
||||
is_alive: bool = True
|
||||
agent_session: AgentSession
|
||||
|
||||
def __init__(
|
||||
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
|
||||
):
|
||||
self.sid = sid
|
||||
self.websocket = ws
|
||||
self.last_active_ts = int(time.time())
|
||||
self.agent_session = AgentSession(sid, file_store)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def close(self):
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
|
||||
async def loop_recv(self):
|
||||
try:
|
||||
if self.websocket is None:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
data = await self.websocket.receive_json()
|
||||
except ValueError:
|
||||
await self.send_error('Invalid JSON')
|
||||
continue
|
||||
await self.dispatch(data)
|
||||
except WebSocketDisconnect:
|
||||
await self.close()
|
||||
logger.info('WebSocket disconnected, sid: %s', self.sid)
|
||||
except RuntimeError as e:
|
||||
await self.close()
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
|
||||
)
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
|
||||
)
|
||||
# Extract the agent-relevant arguments from the request
|
||||
args = {
|
||||
key: value for key, value in data.get('args', {}).items() if value != ''
|
||||
}
|
||||
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
|
||||
self.config.security.confirmation_mode = args.get(
|
||||
ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
|
||||
)
|
||||
self.config.security.security_analyzer = data.get('args', {}).get(
|
||||
ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
|
||||
)
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
|
||||
# override default LLM config
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = args.get(
|
||||
ConfigType.LLM_MODEL, default_llm_config.model
|
||||
)
|
||||
default_llm_config.api_key = args.get(
|
||||
ConfigType.LLM_API_KEY, default_llm_config.api_key
|
||||
)
|
||||
default_llm_config.base_url = args.get(
|
||||
ConfigType.LLM_BASE_URL, default_llm_config.base_url
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
|
||||
agent_config = self.config.get_agent_config(agent_cls)
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
|
||||
# Create the agent session
|
||||
try:
|
||||
await self.agent_session.start(
|
||||
runtime_name=self.config.runtime,
|
||||
config=self.config,
|
||||
agent=agent,
|
||||
max_iterations=max_iterations,
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
await self.send_error(
|
||||
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
||||
)
|
||||
return
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback function for agent events.
|
||||
|
||||
Args:
|
||||
event: The agent event (Observation or Action).
|
||||
"""
|
||||
if isinstance(event, NullAction):
|
||||
return
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
if event.source == EventSource.AGENT:
|
||||
logger.info('Server event')
|
||||
await self.send(event_to_dict(event))
|
||||
elif event.source == EventSource.USER and isinstance(
|
||||
event, CmdOutputObservation
|
||||
):
|
||||
await self.send(event_to_dict(event))
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
action = data.get('action', '')
|
||||
if action == ActionType.INIT:
|
||||
await self._initialize_agent(data)
|
||||
return
|
||||
event = event_from_dict(data.copy())
|
||||
# This checks if the model supports images
|
||||
if isinstance(event, MessageAction) and event.images_urls:
|
||||
controller = self.agent_session.controller
|
||||
if controller and not controller.agent.llm.supports_vision():
|
||||
await self.send_error(
|
||||
'Model does not support image upload, change to a different model or try without an image.'
|
||||
)
|
||||
return
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]) -> bool:
|
||||
try:
|
||||
if self.websocket is None or not self.is_alive:
|
||||
return False
|
||||
await self.websocket.send_json(data)
|
||||
await asyncio.sleep(0.001) # This flushes the data to the client
|
||||
self.last_active_ts = int(time.time())
|
||||
return True
|
||||
except WebSocketDisconnect:
|
||||
self.is_alive = False
|
||||
return False
|
||||
|
||||
async def send_error(self, message: str) -> bool:
|
||||
"""Sends an error message to the client."""
|
||||
return await self.send({'error': True, 'message': message})
|
||||
|
||||
async def send_message(self, message: str) -> bool:
|
||||
"""Sends a message to the client."""
|
||||
return await self.send({'message': message})
|
||||
|
||||
def update_connection(self, ws: WebSocket):
|
||||
self.websocket = ws
|
||||
self.is_alive = True
|
||||
self.last_active_ts = int(time.time())
|
||||
|
||||
def load_from_data(self, data: dict) -> bool:
|
||||
self.last_active_ts = data.get('last_active_ts', 0)
|
||||
if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
|
||||
return False
|
||||
self.is_alive = data.get('is_alive', False)
|
||||
return True
|
||||
Reference in New Issue
Block a user