import os
import tempfile
import threading
from pathlib import Path
from typing import Any
from zipfile import ZipFile

import httpcore
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

from openhands.core.config import OpenHandsConfig
from openhands.core.config.mcp_config import (
    MCPConfig,
    MCPSSEServerConfig,
    MCPStdioServerConfig,
)
from openhands.core.exceptions import (
    AgentRuntimeTimeoutError,
)
from openhands.events import EventStream
from openhands.events.action import (
    ActionConfirmationStatus,
    AgentThinkAction,
    BrowseInteractiveAction,
    BrowseURLAction,
    CmdRunAction,
    FileEditAction,
    FileReadAction,
    FileWriteAction,
    IPythonRunCellAction,
)
from openhands.events.action.action import Action
from openhands.events.action.files import FileEditSource
from openhands.events.action.mcp import MCPAction
from openhands.events.observation import (
    AgentThinkObservation,
    ErrorObservation,
    NullObservation,
    Observation,
    UserRejectObservation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
from openhands.runtime.utils.system_stats import update_last_execution_time
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit


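# Note: an httpx/httpcore RemoteProtocolError generally indicates that the server closed
# the connection mid-request (for example while the action execution server restarts),
# so such failures are assumed safe to retry below.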
def _is_retryable_error(exception):
    return isinstance(
        exception, (httpx.RemoteProtocolError, httpcore.RemoteProtocolError)
    )


class ActionExecutionClient(Runtime):
    """Base class for runtimes that interact with the action execution server.

    This class contains shared logic between DockerRuntime and RemoteRuntime
    for interacting with the HTTP server defined in action_execution_server.py.
    """

    def __init__(
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
        status_callback: Any | None = None,
        attach_to_existing: bool = False,
        headless_mode: bool = True,
        user_id: str | None = None,
        git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
    ):
        self.session = HttpSession()
        self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time
        self._runtime_closed: bool = False
        self._vscode_token: str | None = None  # initial dummy value
        self._last_updated_mcp_stdio_servers: list[MCPStdioServerConfig] = []
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,
            status_callback,
            attach_to_existing,
            headless_mode,
            user_id,
            git_provider_tokens,
        )

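    # Subclasses (e.g. DockerRuntime, RemoteRuntime) are expected to override this
    # property with the base URL of their action execution server.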
    @property
    def action_execution_server_url(self) -> str:
        raise NotImplementedError('Action execution server URL is not implemented')

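    # Requests are retried up to 5 times with exponential backoff (4-15s), but only for
    # the connection-drop errors matched by _is_retryable_error, and retrying stops
    # early if the application is shutting down (stop_if_should_exit).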
    @retry(
        retry=retry_if_exception(_is_retryable_error),
        stop=stop_after_attempt(5) | stop_if_should_exit(),
        wait=wait_exponential(multiplier=1, min=4, max=15),
    )
    def _send_action_server_request(
        self,
        method: str,
        url: str,
        **kwargs,
    ) -> httpx.Response:
        """Send a request to the action execution server.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to send the request to
            **kwargs: Additional arguments to pass to the underlying httpx request

        Returns:
            Response from the server

        Raises:
            AgentRuntimeError: If the request fails
        """
        return send_request(self.session, method, url, **kwargs)

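    # Lightweight health check against the server's /alive endpoint; failures are
    # expected to surface as exceptions raised by send_request.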
    def check_if_alive(self) -> None:
        response = self._send_action_server_request(
            'GET',
            f'{self.action_execution_server_url}/alive',
            timeout=5,
        )
        assert response.is_closed

    def list_files(self, path: str | None = None) -> list[str]:
        """List files in the sandbox.

        If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
        """
        try:
            data = {}
            if path is not None:
                data['path'] = path

            response = self._send_action_server_request(
                'POST',
                f'{self.action_execution_server_url}/list_files',
                json=data,
                timeout=10,
            )
            assert response.is_closed
            response_json = response.json()
            assert isinstance(response_json, list)
            return response_json
        except httpx.TimeoutException:
            raise TimeoutError('List files operation timed out')

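    # The temporary archive below is created with delete=False, so the caller is
    # presumably responsible for removing it once the contents have been extracted.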
    def copy_from(self, path: str) -> Path:
        """Zip all files in the sandbox under `path` and download the archive.

        Returns the path of the temporary zip file on the host.
        """
        try:
            params = {'path': path}
            with self.session.stream(
                'GET',
                f'{self.action_execution_server_url}/download_files',
                params=params,
                timeout=30,
            ) as response:
                with tempfile.NamedTemporaryFile(
                    suffix='.zip', delete=False
                ) as temp_file:
                    for chunk in response.iter_bytes():
                        temp_file.write(chunk)
                    temp_file.flush()
                    return Path(temp_file.name)
        except httpx.TimeoutException:
            raise TimeoutError('Copy operation timed out')

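    # Recursive copies are zipped on the host and uploaded to /upload_file in a single
    # request; single files are uploaded directly. Usage sketch (hypothetical paths):
    #   runtime.copy_to('/tmp/project', '/workspace', recursive=True)
    #   runtime.copy_to('/tmp/notes.txt', '/workspace')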
    def copy_to(
        self, host_src: str, sandbox_dest: str, recursive: bool = False
    ) -> None:
        if not os.path.exists(host_src):
            raise FileNotFoundError(f'Source file {host_src} does not exist')

        temp_zip_path: str | None = None  # Define temp_zip_path outside the try block

        try:
            params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
            file_to_upload = None
            upload_data = {}

            if recursive:
                # Create and write the zip file inside the try block
                with tempfile.NamedTemporaryFile(
                    suffix='.zip', delete=False
                ) as temp_zip:
                    temp_zip_path = temp_zip.name

                try:
                    with ZipFile(temp_zip_path, 'w') as zipf:
                        for root, _, files in os.walk(host_src):
                            for file in files:
                                file_path = os.path.join(root, file)
                                arcname = os.path.relpath(
                                    file_path, os.path.dirname(host_src)
                                )
                                zipf.write(file_path, arcname)

                    self.log(
                        'debug',
                        f'Opening temporary zip file for upload: {temp_zip_path}',
                    )
                    file_to_upload = open(temp_zip_path, 'rb')
                    upload_data = {'file': file_to_upload}
                except Exception as e:
                    # Ensure temp file is cleaned up if zipping fails
                    if temp_zip_path and os.path.exists(temp_zip_path):
                        os.unlink(temp_zip_path)
                    raise e  # Re-raise the exception after cleanup attempt
            else:
                file_to_upload = open(host_src, 'rb')
                upload_data = {'file': file_to_upload}

            params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}

            response = self._send_action_server_request(
                'POST',
                f'{self.action_execution_server_url}/upload_file',
                files=upload_data,
                params=params,
                timeout=300,
            )
            self.log(
                'debug',
                f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
            )
        finally:
            if file_to_upload:
                file_to_upload.close()

            # Cleanup the temporary zip file if it was created
            if temp_zip_path and os.path.exists(temp_zip_path):
                try:
                    os.unlink(temp_zip_path)
                except Exception as e:
                    self.log(
                        'error',
                        f'Failed to delete temporary zip file {temp_zip_path}: {e}',
                    )

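    # The VSCode connection token is fetched once from the server and cached on the
    # instance; an empty string is returned when VSCode is disabled, the runtime is
    # not initialized yet, or no token is available.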
    def get_vscode_token(self) -> str:
        if self.vscode_enabled and self.runtime_initialized:
            if self._vscode_token is not None:  # cached value
                return self._vscode_token
            response = self._send_action_server_request(
                'GET',
                f'{self.action_execution_server_url}/vscode/connection_token',
                timeout=10,
            )
            response_json = response.json()
            assert isinstance(response_json, dict)
            if response_json['token'] is None:
                return ''
            self._vscode_token = response_json['token']
            return response_json['token']
        else:
            return ''

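    # Core dispatch path: the action is serialized with event_to_dict and POSTed to
    # /execute_action while holding action_semaphore, and the JSON response is turned
    # back into an Observation. The request body is roughly (illustrative shape only):
    #   {'action': {'action': 'run', 'args': {'command': '...'}, ...}}
    # LLM-based file edits bypass the server call and go through llm_based_edit instead.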
    def send_action_for_execution(self, action: Action) -> Observation:
        if (
            isinstance(action, FileEditAction)
            and action.impl_source == FileEditSource.LLM_BASED_EDIT
        ):
            return self.llm_based_edit(action)

        # set timeout to default if not set
        if action.timeout is None:
            if isinstance(action, CmdRunAction) and action.blocking:
                raise RuntimeError('Blocking command with no timeout set')
            # We don't block the command if this is a default timeout action
            action.set_hard_timeout(self.config.sandbox.timeout, blocking=False)

        with self.action_semaphore:
            if not action.runnable:
                if isinstance(action, AgentThinkAction):
                    return AgentThinkObservation('Your thought has been logged.')
                return NullObservation('')
            if (
                hasattr(action, 'confirmation_state')
                and action.confirmation_state
                == ActionConfirmationStatus.AWAITING_CONFIRMATION
            ):
                return NullObservation('')
            action_type = action.action  # type: ignore[attr-defined]
            if action_type not in ACTION_TYPE_TO_CLASS:
                raise ValueError(f'Action {action_type} does not exist.')
            if not hasattr(self, action_type):
                return ErrorObservation(
                    f'Action {action_type} is not supported in the current runtime.',
                    error_id='AGENT_ERROR$BAD_ACTION',
                )
            if (
                getattr(action, 'confirmation_state', None)
                == ActionConfirmationStatus.REJECTED
            ):
                return UserRejectObservation(
                    'Action has been rejected by the user! Waiting for further user input.'
                )

            assert action.timeout is not None

            try:
                execution_action_body: dict[str, Any] = {
                    'action': event_to_dict(action),
                }
                response = self._send_action_server_request(
                    'POST',
                    f'{self.action_execution_server_url}/execute_action',
                    json=execution_action_body,
                    # wait a few more seconds to get the timeout error from client side
                    timeout=action.timeout + 5,
                )
                assert response.is_closed
                output = response.json()
                if getattr(action, 'hidden', False):
                    output.get('extras')['hidden'] = True
                obs = observation_from_dict(output)
                obs._cause = action.id  # type: ignore[attr-defined]
            except httpx.TimeoutException:
                raise AgentRuntimeTimeoutError(
                    f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
                )
            finally:
                update_last_execution_time()
            return obs

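    # Thin per-action wrappers: each supported action type needs a correspondingly
    # named method on the runtime (see the hasattr check in send_action_for_execution);
    # all of them delegate to the same execution path.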
    def run(self, action: CmdRunAction) -> Observation:
        return self.send_action_for_execution(action)

    def run_ipython(self, action: IPythonRunCellAction) -> Observation:
        return self.send_action_for_execution(action)

    def read(self, action: FileReadAction) -> Observation:
        return self.send_action_for_execution(action)

    def write(self, action: FileWriteAction) -> Observation:
        return self.send_action_for_execution(action)

    def edit(self, action: FileEditAction) -> Observation:
        return self.send_action_for_execution(action)

    def browse(self, action: BrowseURLAction) -> Observation:
        return self.send_action_for_execution(action)

    def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
        return self.send_action_for_execution(action)

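    # Keeps the action execution server's MCP configuration in sync: stdio servers that
    # have not been pushed yet (tracked in _last_updated_mcp_stdio_servers) are sent to
    # /update_mcp_server, and when any stdio servers exist the runtime's own /mcp/sse
    # endpoint is exposed as an additional SSE server.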
    def get_mcp_config(
        self, extra_stdio_servers: list[MCPStdioServerConfig] | None = None
    ) -> MCPConfig:
        import sys

        # Check if we're on Windows - MCP is disabled on Windows
        if sys.platform == 'win32':
            # Return empty MCP config on Windows
            self.log('debug', 'MCP is disabled on Windows, returning empty config')
            return MCPConfig(sse_servers=[], stdio_servers=[])

        # Add the runtime as another MCP server
        updated_mcp_config = self.config.mcp.model_copy()

        # Get current stdio servers
        current_stdio_servers: list[MCPStdioServerConfig] = list(
            updated_mcp_config.stdio_servers
        )
        if extra_stdio_servers:
            current_stdio_servers.extend(extra_stdio_servers)

        # Check if there are any new servers using the __eq__ operator
        new_servers = [
            server
            for server in current_stdio_servers
            if server not in self._last_updated_mcp_stdio_servers
        ]

        self.log(
            'debug',
            f'adding {len(new_servers)} new stdio servers to MCP config: {new_servers}',
        )

        # Only send update request if there are new servers
        if new_servers:
            # Use a union of current servers and last updated servers for the update
            # This ensures we don't lose any servers that might be missing from either list
            combined_servers = current_stdio_servers.copy()
            for server in self._last_updated_mcp_stdio_servers:
                if server not in combined_servers:
                    combined_servers.append(server)

            stdio_tools = [
                server.model_dump(mode='json') for server in combined_servers
            ]
            stdio_tools.sort(key=lambda x: x.get('name', ''))  # Sort by server name

            self.log(
                'debug',
                f'Updating MCP server with {len(new_servers)} new stdio servers (total: {len(combined_servers)})',
            )
            response = self._send_action_server_request(
                'POST',
                f'{self.action_execution_server_url}/update_mcp_server',
                json=stdio_tools,
                timeout=60,
            )
            result = response.json()
            if response.status_code != 200:
                self.log('warning', f'Failed to update MCP server: {response.text}')
            else:
                if result.get('router_error_log'):
                    self.log(
                        'warning',
                        f'Some MCP servers failed to be added: {result["router_error_log"]}',
                    )

                # Update our cached list with combined servers after successful update
                self._last_updated_mcp_stdio_servers = combined_servers.copy()
                self.log(
                    'debug',
                    f'Successfully updated MCP stdio servers, now tracking {len(combined_servers)} servers',
                )
            self.log(
                'info',
                f'Updated MCP config: {updated_mcp_config.sse_servers}',
            )
        else:
            self.log('debug', 'No new stdio servers to update')

        if len(self._last_updated_mcp_stdio_servers) > 0:
            # We should always include the runtime as an MCP server whenever there's > 0 stdio servers
            updated_mcp_config.sse_servers.append(
                MCPSSEServerConfig(
                    url=self.action_execution_server_url.rstrip('/') + '/mcp/sse',
                    api_key=self.session_api_key,
                )
            )

        return updated_mcp_config

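    # MCP tool calls create fresh clients from the current config for each invocation;
    # on Windows an ErrorObservation is returned immediately since MCP is disabled there.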
    async def call_tool_mcp(self, action: MCPAction) -> Observation:
        import sys

        from openhands.events.observation import ErrorObservation

        # Check if we're on Windows - MCP is disabled on Windows
        if sys.platform == 'win32':
            self.log('info', 'MCP functionality is disabled on Windows')
            return ErrorObservation('MCP functionality is not available on Windows')

        # Import here to avoid circular imports
        from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
        from openhands.mcp.utils import create_mcp_clients

        # Get the updated MCP config
        updated_mcp_config = self.get_mcp_config()
        self.log(
            'debug',
            f'Creating MCP clients with servers: {updated_mcp_config.sse_servers}',
        )

        # Create clients for this specific operation
        mcp_clients = await create_mcp_clients(
            updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid
        )

        # Call the tool and return the result
        # No need for try/finally since disconnect() is now just resetting state
        result = await call_tool_mcp_handler(mcp_clients, action)
        return result

    def close(self) -> None:
        # Make sure we don't close the session multiple times
        # Can happen in evaluation
        if self._runtime_closed:
            return
        self._runtime_closed = True
        self.session.close()