mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add MCP support for CLI (#9519)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
parent
45ac6b839c
commit
fbd9280239
@ -153,6 +153,7 @@ You can use the following commands whenever the prompt (`>`) is displayed:
|
||||
| `/new` | Start a new conversation |
|
||||
| `/settings` | View and modify current LLM/agent settings |
|
||||
| `/resume` | Resume the agent if paused |
|
||||
| `/mcp` | Manage MCP server configuration and view connection errors |
|
||||
|
||||
#### Settings and Configuration
|
||||
|
||||
@ -162,7 +163,7 @@ follow the prompts:
|
||||
- **Basic settings**: Choose a model/provider and enter your API key.
|
||||
- **Advanced settings**: Set custom endpoints, enable or disable confirmation mode, and configure memory condensation.
|
||||
|
||||
Settings can also be managed via the `config.toml` file.
|
||||
Settings can also be managed via the `config.toml` file in the current directory or `~/.openhands/config.toml`.
|
||||
|
||||
#### Repository Initialization
|
||||
|
||||
@ -174,6 +175,41 @@ project details and structure. Use this when onboarding the agent to a new codeb
|
||||
You can pause the agent while it is running by pressing `Ctrl-P`. To continue the conversation after pausing, simply
|
||||
type `/resume` at the prompt.
|
||||
|
||||
#### MCP Server Management
|
||||
|
||||
To configure Model Context Protocol (MCP) servers, you can refer to the documentation on [MCP servers](../mcp) and use the `/mcp` command in the CLI. This command provides an interactive interface for managing Model Context Protocol (MCP) servers:
|
||||
|
||||
- **List configured servers**: View all currently configured MCP servers (SSE, Stdio, and SHTTP)
|
||||
- **Add new server**: Interactively add a new MCP server with guided prompts
|
||||
- **Remove server**: Remove an existing MCP server from your configuration
|
||||
- **View errors**: Display any connection errors that occurred during MCP server startup
|
||||
|
||||
This command modifies your `~/.openhands/config.toml` file and will prompt you to restart OpenHands for changes to take effect.
|
||||
|
||||
To enable the [Tavily MCP server](https://github.com/tavily-ai/tavily-mcp) search engine, you can set the `search_api_key` under the `[core]` section in the `~/.openhands/config.toml` file.
|
||||
|
||||
##### Example of the `config.toml` file with MCP server configuration:
|
||||
|
||||
```toml
|
||||
[core]
|
||||
search_api_key = "tvly-your-api-key-here"
|
||||
|
||||
[mcp]
|
||||
stdio_servers = [
|
||||
{name="fetch", command="uvx", args=["mcp-server-fetch"]},
|
||||
]
|
||||
|
||||
sse_servers = [
|
||||
# Basic SSE server with just a URL
|
||||
"http://example.com:8080/sse",
|
||||
]
|
||||
|
||||
shttp_servers = [
|
||||
# Streamable HTTP server with API key authentication
|
||||
{url="https://secure-example.com/mcp", api_key="your-api-key"}
|
||||
]
|
||||
```
|
||||
|
||||
## Tips and Troubleshooting
|
||||
|
||||
- Use `/help` at any time to see the list of available commands.
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit import print_formatted_text
|
||||
import toml
|
||||
from prompt_toolkit import HTML, print_formatted_text
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from prompt_toolkit.shortcuts import clear, print_container
|
||||
from prompt_toolkit.widgets import Frame, TextArea
|
||||
from pydantic import ValidationError
|
||||
|
||||
from openhands.cli.settings import (
|
||||
display_settings,
|
||||
@ -14,9 +20,12 @@ from openhands.cli.tui import (
|
||||
COLOR_GREY,
|
||||
UsageMetrics,
|
||||
cli_confirm,
|
||||
create_prompt_session,
|
||||
display_help,
|
||||
display_mcp_errors,
|
||||
display_shutdown_message,
|
||||
display_status,
|
||||
read_prompt_input,
|
||||
)
|
||||
from openhands.cli.utils import (
|
||||
add_local_config_trusted_dir,
|
||||
@ -27,6 +36,11 @@ from openhands.cli.utils import (
|
||||
from openhands.core.config import (
|
||||
OpenHandsConfig,
|
||||
)
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.core.schema.exit_reason import ExitReason
|
||||
from openhands.events import EventSource
|
||||
@ -38,6 +52,72 @@ from openhands.events.stream import EventStream
|
||||
from openhands.storage.settings.file_settings_store import FileSettingsStore
|
||||
|
||||
|
||||
async def collect_input(config: OpenHandsConfig, prompt_text: str) -> str | None:
|
||||
"""Collect user input with cancellation support.
|
||||
|
||||
Args:
|
||||
config: OpenHands configuration
|
||||
prompt_text: Text to display to user
|
||||
|
||||
Returns:
|
||||
str | None: User input string, or None if user cancelled
|
||||
"""
|
||||
print_formatted_text(prompt_text, end=' ')
|
||||
user_input = await read_prompt_input(config, '', multiline=False)
|
||||
|
||||
# Check for cancellation
|
||||
if user_input.strip().lower() in ['/exit', '/cancel', 'cancel']:
|
||||
return None
|
||||
|
||||
return user_input.strip()
|
||||
|
||||
|
||||
def restart_cli() -> None:
|
||||
"""Restart the CLI by replacing the current process."""
|
||||
print_formatted_text('🔄 Restarting OpenHands CLI...')
|
||||
|
||||
# Get the current Python executable and script arguments
|
||||
python_executable = sys.executable
|
||||
script_args = sys.argv
|
||||
|
||||
# Use os.execv to replace the current process
|
||||
# This preserves the original command line arguments
|
||||
try:
|
||||
os.execv(python_executable, [python_executable] + script_args)
|
||||
except Exception as e:
|
||||
print_formatted_text(f'❌ Failed to restart CLI: {e}')
|
||||
print_formatted_text(
|
||||
'Please restart OpenHands manually for changes to take effect.'
|
||||
)
|
||||
|
||||
|
||||
async def prompt_for_restart(config: OpenHandsConfig) -> bool:
|
||||
"""Prompt user if they want to restart the CLI and return their choice."""
|
||||
print_formatted_text('📝 MCP server configuration updated successfully!')
|
||||
print_formatted_text('The changes will take effect after restarting OpenHands.')
|
||||
|
||||
prompt_session = create_prompt_session(config)
|
||||
|
||||
while True:
|
||||
try:
|
||||
with patch_stdout():
|
||||
response = await prompt_session.prompt_async(
|
||||
HTML(
|
||||
'<gold>Would you like to restart OpenHands now? (y/n): </gold>'
|
||||
)
|
||||
)
|
||||
response = response.strip().lower() if response else ''
|
||||
|
||||
if response in ['y', 'yes']:
|
||||
return True
|
||||
elif response in ['n', 'no']:
|
||||
return False
|
||||
else:
|
||||
print_formatted_text('Please enter "y" for yes or "n" for no.')
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
async def handle_commands(
|
||||
command: str,
|
||||
event_stream: EventStream,
|
||||
@ -79,6 +159,8 @@ async def handle_commands(
|
||||
await handle_settings_command(config, settings_store)
|
||||
elif command == '/resume':
|
||||
close_repl, new_session_requested = await handle_resume_command(event_stream)
|
||||
elif command == '/mcp':
|
||||
await handle_mcp_command(config)
|
||||
else:
|
||||
close_repl = True
|
||||
action = MessageAction(content=command)
|
||||
@ -327,3 +409,432 @@ def check_folder_security_agreement(config: OpenHandsConfig, current_dir: str) -
|
||||
return confirm
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def handle_mcp_command(config: OpenHandsConfig) -> None:
|
||||
"""Handle MCP command with interactive menu."""
|
||||
action = cli_confirm(
|
||||
config,
|
||||
'MCP Server Configuration',
|
||||
[
|
||||
'List configured servers',
|
||||
'Add new server',
|
||||
'Remove server',
|
||||
'View errors',
|
||||
'Go back',
|
||||
],
|
||||
)
|
||||
|
||||
if action == 0: # List
|
||||
display_mcp_servers(config)
|
||||
elif action == 1: # Add
|
||||
await add_mcp_server(config)
|
||||
elif action == 2: # Remove
|
||||
await remove_mcp_server(config)
|
||||
elif action == 3: # View errors
|
||||
handle_mcp_errors_command()
|
||||
# action == 4 is "Go back", do nothing
|
||||
|
||||
|
||||
def display_mcp_servers(config: OpenHandsConfig) -> None:
|
||||
"""Display MCP server configuration information."""
|
||||
mcp_config = config.mcp
|
||||
|
||||
# Count the different types of servers
|
||||
sse_count = len(mcp_config.sse_servers)
|
||||
stdio_count = len(mcp_config.stdio_servers)
|
||||
shttp_count = len(mcp_config.shttp_servers)
|
||||
total_count = sse_count + stdio_count + shttp_count
|
||||
|
||||
if total_count == 0:
|
||||
print_formatted_text(
|
||||
'No custom MCP servers configured. See the documentation to learn more:\n'
|
||||
' https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers'
|
||||
)
|
||||
else:
|
||||
print_formatted_text(
|
||||
f'Configured MCP servers:\n'
|
||||
f' • SSE servers: {sse_count}\n'
|
||||
f' • Stdio servers: {stdio_count}\n'
|
||||
f' • SHTTP servers: {shttp_count}\n'
|
||||
f' • Total: {total_count}'
|
||||
)
|
||||
|
||||
# Show details for each type if they exist
|
||||
if sse_count > 0:
|
||||
print_formatted_text('SSE Servers:')
|
||||
for idx, sse_server in enumerate(mcp_config.sse_servers, 1):
|
||||
print_formatted_text(f' {idx}. {sse_server.url}')
|
||||
print_formatted_text('')
|
||||
|
||||
if stdio_count > 0:
|
||||
print_formatted_text('Stdio Servers:')
|
||||
for idx, stdio_server in enumerate(mcp_config.stdio_servers, 1):
|
||||
print_formatted_text(
|
||||
f' {idx}. {stdio_server.name} ({stdio_server.command})'
|
||||
)
|
||||
print_formatted_text('')
|
||||
|
||||
if shttp_count > 0:
|
||||
print_formatted_text('SHTTP Servers:')
|
||||
for idx, shttp_server in enumerate(mcp_config.shttp_servers, 1):
|
||||
print_formatted_text(f' {idx}. {shttp_server.url}')
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
def handle_mcp_errors_command() -> None:
|
||||
"""Display MCP connection errors."""
|
||||
display_mcp_errors()
|
||||
|
||||
|
||||
def get_config_file_path() -> Path:
|
||||
"""Get the path to the config file. By default, we use config.toml in the current working directory. If not found, we use ~/.openhands/config.toml."""
|
||||
# Check if config.toml exists in the current directory
|
||||
current_dir = Path.cwd() / 'config.toml'
|
||||
if current_dir.exists():
|
||||
return current_dir
|
||||
|
||||
# Fallback to the user's home directory
|
||||
return Path.home() / '.openhands' / 'config.toml'
|
||||
|
||||
|
||||
def load_config_file(file_path: Path) -> dict:
|
||||
"""Load the config file, creating it if it doesn't exist."""
|
||||
if file_path.exists():
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
return toml.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return {}
|
||||
|
||||
|
||||
def save_config_file(config_data: dict, file_path: Path) -> None:
|
||||
"""Save the config file."""
|
||||
with open(file_path, 'w') as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
|
||||
def _ensure_mcp_config_structure(config_data: dict) -> None:
|
||||
"""Ensure MCP configuration structure exists in config data."""
|
||||
if 'mcp' not in config_data:
|
||||
config_data['mcp'] = {}
|
||||
|
||||
|
||||
def _add_server_to_config(server_type: str, server_config: dict) -> Path:
|
||||
"""Add a server configuration to the config file."""
|
||||
config_file_path = get_config_file_path()
|
||||
config_data = load_config_file(config_file_path)
|
||||
_ensure_mcp_config_structure(config_data)
|
||||
|
||||
if server_type not in config_data['mcp']:
|
||||
config_data['mcp'][server_type] = []
|
||||
|
||||
config_data['mcp'][server_type].append(server_config)
|
||||
save_config_file(config_data, config_file_path)
|
||||
|
||||
return config_file_path
|
||||
|
||||
|
||||
async def add_mcp_server(config: OpenHandsConfig) -> None:
|
||||
"""Add a new MCP server configuration."""
|
||||
# Choose transport type
|
||||
transport_type = cli_confirm(
|
||||
config,
|
||||
'Select MCP server transport type:',
|
||||
[
|
||||
'SSE (Server-Sent Events)',
|
||||
'Stdio (Standard Input/Output)',
|
||||
'SHTTP (Streamable HTTP)',
|
||||
'Cancel',
|
||||
],
|
||||
)
|
||||
|
||||
if transport_type == 3: # Cancel
|
||||
return
|
||||
|
||||
try:
|
||||
if transport_type == 0: # SSE
|
||||
await add_sse_server(config)
|
||||
elif transport_type == 1: # Stdio
|
||||
await add_stdio_server(config)
|
||||
elif transport_type == 2: # SHTTP
|
||||
await add_shttp_server(config)
|
||||
except Exception as e:
|
||||
print_formatted_text(f'Error adding MCP server: {e}')
|
||||
|
||||
|
||||
async def add_sse_server(config: OpenHandsConfig) -> None:
|
||||
"""Add an SSE MCP server."""
|
||||
print_formatted_text('Adding SSE MCP Server')
|
||||
|
||||
while True: # Retry loop for the entire form
|
||||
# Collect all inputs
|
||||
url = await collect_input(config, '\nEnter server URL:')
|
||||
if url is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
api_key = await collect_input(
|
||||
config, '\nEnter API key (optional, press Enter to skip):'
|
||||
)
|
||||
if api_key is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Convert empty string to None for optional field
|
||||
api_key = api_key if api_key else None
|
||||
|
||||
# Validate all inputs at once
|
||||
try:
|
||||
server = MCPSSEServerConfig(url=url, api_key=api_key)
|
||||
break # Success - exit retry loop
|
||||
|
||||
except ValidationError as e:
|
||||
# Show all errors at once
|
||||
print_formatted_text('❌ Please fix the following errors:')
|
||||
for error in e.errors():
|
||||
field = error['loc'][0] if error['loc'] else 'unknown'
|
||||
print_formatted_text(f' • {field}: {error["msg"]}')
|
||||
|
||||
if cli_confirm(config, '\nTry again?') != 0:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Save to config file
|
||||
server_config = {'url': server.url}
|
||||
if server.api_key:
|
||||
server_config['api_key'] = server.api_key
|
||||
|
||||
config_file_path = _add_server_to_config('sse_servers', server_config)
|
||||
print_formatted_text(f'✓ SSE MCP server added to {config_file_path}: {server.url}')
|
||||
|
||||
# Prompt for restart
|
||||
if await prompt_for_restart(config):
|
||||
restart_cli()
|
||||
|
||||
|
||||
async def add_stdio_server(config: OpenHandsConfig) -> None:
|
||||
"""Add a Stdio MCP server."""
|
||||
print_formatted_text('Adding Stdio MCP Server')
|
||||
|
||||
# Get existing server names to check for duplicates
|
||||
existing_names = [server.name for server in config.mcp.stdio_servers]
|
||||
|
||||
while True: # Retry loop for the entire form
|
||||
# Collect all inputs
|
||||
name = await collect_input(config, '\nEnter server name:')
|
||||
if name is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
command = await collect_input(config, "\nEnter command (e.g., 'uvx', 'npx'):")
|
||||
if command is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
args_input = await collect_input(
|
||||
config,
|
||||
'\nEnter arguments (optional, e.g., "-y server-package arg1"):',
|
||||
)
|
||||
if args_input is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
env_input = await collect_input(
|
||||
config,
|
||||
'\nEnter environment variables (KEY=VALUE format, comma-separated, optional):',
|
||||
)
|
||||
if env_input is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Check for duplicate server names
|
||||
if name in existing_names:
|
||||
print_formatted_text(f"❌ Server name '{name}' already exists.")
|
||||
if cli_confirm(config, '\nTry again?') != 0:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
continue
|
||||
|
||||
# Validate all inputs at once
|
||||
try:
|
||||
server = MCPStdioServerConfig(
|
||||
name=name,
|
||||
command=command,
|
||||
args=args_input, # type: ignore # Will be parsed by Pydantic validator
|
||||
env=env_input, # type: ignore # Will be parsed by Pydantic validator
|
||||
)
|
||||
break # Success - exit retry loop
|
||||
|
||||
except ValidationError as e:
|
||||
# Show all errors at once
|
||||
print_formatted_text('❌ Please fix the following errors:')
|
||||
for error in e.errors():
|
||||
field = error['loc'][0] if error['loc'] else 'unknown'
|
||||
print_formatted_text(f' • {field}: {error["msg"]}')
|
||||
|
||||
if cli_confirm(config, '\nTry again?') != 0:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Save to config file
|
||||
server_config: dict[str, Any] = {
|
||||
'name': server.name,
|
||||
'command': server.command,
|
||||
}
|
||||
if server.args:
|
||||
server_config['args'] = server.args
|
||||
if server.env:
|
||||
server_config['env'] = server.env
|
||||
|
||||
config_file_path = _add_server_to_config('stdio_servers', server_config)
|
||||
print_formatted_text(
|
||||
f'✓ Stdio MCP server added to {config_file_path}: {server.name}'
|
||||
)
|
||||
|
||||
# Prompt for restart
|
||||
if await prompt_for_restart(config):
|
||||
restart_cli()
|
||||
|
||||
|
||||
async def add_shttp_server(config: OpenHandsConfig) -> None:
|
||||
"""Add an SHTTP MCP server."""
|
||||
print_formatted_text('Adding SHTTP MCP Server')
|
||||
|
||||
while True: # Retry loop for the entire form
|
||||
# Collect all inputs
|
||||
url = await collect_input(config, '\nEnter server URL:')
|
||||
if url is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
api_key = await collect_input(
|
||||
config, '\nEnter API key (optional, press Enter to skip):'
|
||||
)
|
||||
if api_key is None:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Convert empty string to None for optional field
|
||||
api_key = api_key if api_key else None
|
||||
|
||||
# Validate all inputs at once
|
||||
try:
|
||||
server = MCPSHTTPServerConfig(url=url, api_key=api_key)
|
||||
break # Success - exit retry loop
|
||||
|
||||
except ValidationError as e:
|
||||
# Show all errors at once
|
||||
print_formatted_text('❌ Please fix the following errors:')
|
||||
for error in e.errors():
|
||||
field = error['loc'][0] if error['loc'] else 'unknown'
|
||||
print_formatted_text(f' • {field}: {error["msg"]}')
|
||||
|
||||
if cli_confirm(config, '\nTry again?') != 0:
|
||||
print_formatted_text('Operation cancelled.')
|
||||
return
|
||||
|
||||
# Save to config file
|
||||
server_config = {'url': server.url}
|
||||
if server.api_key:
|
||||
server_config['api_key'] = server.api_key
|
||||
|
||||
config_file_path = _add_server_to_config('shttp_servers', server_config)
|
||||
print_formatted_text(
|
||||
f'✓ SHTTP MCP server added to {config_file_path}: {server.url}'
|
||||
)
|
||||
|
||||
# Prompt for restart
|
||||
if await prompt_for_restart(config):
|
||||
restart_cli()
|
||||
|
||||
|
||||
async def remove_mcp_server(config: OpenHandsConfig) -> None:
|
||||
"""Remove an MCP server configuration."""
|
||||
mcp_config = config.mcp
|
||||
|
||||
# Collect all servers with their types
|
||||
servers: list[tuple[str, str, object]] = []
|
||||
|
||||
# Add SSE servers
|
||||
for sse_server in mcp_config.sse_servers:
|
||||
servers.append(('SSE', sse_server.url, sse_server))
|
||||
|
||||
# Add Stdio servers
|
||||
for stdio_server in mcp_config.stdio_servers:
|
||||
servers.append(('Stdio', stdio_server.name, stdio_server))
|
||||
|
||||
# Add SHTTP servers
|
||||
for shttp_server in mcp_config.shttp_servers:
|
||||
servers.append(('SHTTP', shttp_server.url, shttp_server))
|
||||
|
||||
if not servers:
|
||||
print_formatted_text('No MCP servers configured to remove.')
|
||||
return
|
||||
|
||||
# Create choices for the user
|
||||
choices = []
|
||||
for server_type, identifier, _ in servers:
|
||||
choices.append(f'{server_type}: {identifier}')
|
||||
choices.append('Cancel')
|
||||
|
||||
# Let user choose which server to remove
|
||||
choice = cli_confirm(config, 'Select MCP server to remove:', choices)
|
||||
|
||||
if choice == len(choices) - 1: # Cancel
|
||||
return
|
||||
|
||||
# Remove the selected server
|
||||
server_type, identifier, _ = servers[choice]
|
||||
|
||||
# Confirm removal
|
||||
confirm = cli_confirm(
|
||||
config,
|
||||
f'Are you sure you want to remove {server_type} server "{identifier}"?',
|
||||
['Yes, remove', 'Cancel'],
|
||||
)
|
||||
|
||||
if confirm == 1: # Cancel
|
||||
return
|
||||
|
||||
# Load config file and remove the server
|
||||
config_file_path = get_config_file_path()
|
||||
config_data = load_config_file(config_file_path)
|
||||
|
||||
_ensure_mcp_config_structure(config_data)
|
||||
|
||||
removed = False
|
||||
|
||||
if server_type == 'SSE' and 'sse_servers' in config_data['mcp']:
|
||||
config_data['mcp']['sse_servers'] = [
|
||||
s for s in config_data['mcp']['sse_servers'] if s.get('url') != identifier
|
||||
]
|
||||
removed = True
|
||||
elif server_type == 'Stdio' and 'stdio_servers' in config_data['mcp']:
|
||||
config_data['mcp']['stdio_servers'] = [
|
||||
s
|
||||
for s in config_data['mcp']['stdio_servers']
|
||||
if s.get('name') != identifier
|
||||
]
|
||||
removed = True
|
||||
elif server_type == 'SHTTP' and 'shttp_servers' in config_data['mcp']:
|
||||
config_data['mcp']['shttp_servers'] = [
|
||||
s for s in config_data['mcp']['shttp_servers'] if s.get('url') != identifier
|
||||
]
|
||||
removed = True
|
||||
|
||||
if removed:
|
||||
save_config_file(config_data, config_file_path)
|
||||
print_formatted_text(
|
||||
f'✓ {server_type} MCP server "{identifier}" removed from {config_file_path}.'
|
||||
)
|
||||
|
||||
# Prompt for restart
|
||||
if await prompt_for_restart(config):
|
||||
restart_cli()
|
||||
else:
|
||||
print_formatted_text(f'Failed to remove {server_type} server "{identifier}".')
|
||||
|
||||
@ -74,6 +74,7 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.io import read_task
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
|
||||
LLMSummarizingCondenserConfig,
|
||||
)
|
||||
@ -298,6 +299,10 @@ async def run_session(
|
||||
|
||||
# Add MCP tools to the agent
|
||||
if agent.config.enable_mcp:
|
||||
# Clear any previous errors and enable collection
|
||||
mcp_error_collector.clear_errors()
|
||||
mcp_error_collector.enable_collection()
|
||||
|
||||
# Add OpenHands' MCP server by default
|
||||
_, openhands_mcp_stdio_servers = (
|
||||
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
|
||||
@ -309,6 +314,9 @@ async def run_session(
|
||||
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory)
|
||||
|
||||
# Disable collection after startup
|
||||
mcp_error_collector.disable_collection()
|
||||
|
||||
# Clear loading animation
|
||||
is_loaded.set()
|
||||
|
||||
@ -319,7 +327,27 @@ async def run_session(
|
||||
if not skip_banner:
|
||||
display_banner(session_id=sid)
|
||||
|
||||
welcome_message = 'What do you want to build?' # from the application
|
||||
welcome_message = ''
|
||||
|
||||
# Display number of MCP servers configured
|
||||
if agent.config.enable_mcp:
|
||||
total_mcp_servers = (
|
||||
len(runtime.config.mcp.stdio_servers)
|
||||
+ len(runtime.config.mcp.sse_servers)
|
||||
+ len(runtime.config.mcp.shttp_servers)
|
||||
)
|
||||
if total_mcp_servers > 0:
|
||||
mcp_line = f'Using {len(runtime.config.mcp.stdio_servers)} stdio MCP servers, {len(runtime.config.mcp.sse_servers)} SSE MCP servers and {len(runtime.config.mcp.shttp_servers)} SHTTP MCP servers.'
|
||||
|
||||
# Check for MCP errors and add indicator to the same line
|
||||
if agent.config.enable_mcp and mcp_error_collector.has_errors():
|
||||
mcp_line += (
|
||||
' ✗ MCP errors detected (type /mcp → select View errors to view)'
|
||||
)
|
||||
|
||||
welcome_message += mcp_line + '\n\n'
|
||||
|
||||
welcome_message += 'What do you want to build?' # from the application
|
||||
initial_message = '' # from the user
|
||||
|
||||
if task_content:
|
||||
@ -488,6 +516,16 @@ async def main_with_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
if not env_log_level:
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
# If `config.toml` does not exist in current directory, use the file under home directory
|
||||
if not os.path.exists(args.config_file):
|
||||
home_config_file = os.path.join(
|
||||
os.path.expanduser('~'), '.openhands', 'config.toml'
|
||||
)
|
||||
logger.info(
|
||||
f'Config file {args.config_file} does not exist, using default config file in home directory: {home_config_file}.'
|
||||
)
|
||||
args.config_file = home_config_file
|
||||
|
||||
# Load config from toml and override with command line arguments
|
||||
config: OpenHandsConfig = setup_config_from_args(args)
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -36,6 +38,7 @@ from openhands.events.action import (
|
||||
ActionConfirmationStatus,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
MCPAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
@ -45,8 +48,10 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
MCPObservation,
|
||||
)
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
|
||||
ENABLE_STREAMING = False # FIXME: this doesn't work
|
||||
|
||||
@ -76,6 +81,7 @@ COMMANDS = {
|
||||
'/new': 'Create a new conversation',
|
||||
'/settings': 'Display and modify current settings',
|
||||
'/resume': 'Resume the agent when paused',
|
||||
'/mcp': 'Manage MCP server configuration and view errors',
|
||||
}
|
||||
|
||||
print_lock = threading.Lock()
|
||||
@ -162,6 +168,7 @@ def display_welcome_message(message: str = '') -> None:
|
||||
print_formatted_text(
|
||||
HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
|
||||
)
|
||||
|
||||
if message:
|
||||
print_formatted_text(
|
||||
HTML(f'{message} <grey>Type /help for help</grey>'),
|
||||
@ -186,6 +193,48 @@ def display_initial_user_prompt(prompt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def display_mcp_errors() -> None:
|
||||
"""Display collected MCP errors."""
|
||||
errors = mcp_error_collector.get_errors()
|
||||
|
||||
if not errors:
|
||||
print_formatted_text(HTML('<ansigreen>✓ No MCP errors detected</ansigreen>\n'))
|
||||
return
|
||||
|
||||
print_formatted_text(
|
||||
HTML(
|
||||
f'<ansired>✗ {len(errors)} MCP error(s) detected during startup:</ansired>\n'
|
||||
)
|
||||
)
|
||||
|
||||
for i, error in enumerate(errors, 1):
|
||||
# Format timestamp
|
||||
timestamp = datetime.datetime.fromtimestamp(error.timestamp).strftime(
|
||||
'%H:%M:%S'
|
||||
)
|
||||
|
||||
# Create error display text
|
||||
error_text = (
|
||||
f'[{timestamp}] {error.server_type.upper()} Server: {error.server_name}\n'
|
||||
)
|
||||
error_text += f'Error: {error.error_message}\n'
|
||||
if error.exception_details:
|
||||
error_text += f'Details: {error.exception_details}'
|
||||
|
||||
container = Frame(
|
||||
TextArea(
|
||||
text=error_text,
|
||||
read_only=True,
|
||||
style='ansired',
|
||||
wrap_lines=True,
|
||||
),
|
||||
title=f'MCP Error #{i}',
|
||||
style='ansired',
|
||||
)
|
||||
print_container(container)
|
||||
print_formatted_text('') # Add spacing between errors
|
||||
|
||||
|
||||
# Prompt output display functions
|
||||
def display_thought_if_new(thought: str) -> None:
|
||||
"""Display a thought only if it hasn't been displayed recently."""
|
||||
@ -215,6 +264,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
|
||||
|
||||
if event.confirmation_state == ActionConfirmationStatus.CONFIRMED:
|
||||
initialize_streaming_output()
|
||||
elif isinstance(event, MCPAction):
|
||||
display_mcp_action(event)
|
||||
elif isinstance(event, Action):
|
||||
# For other actions, display thoughts normally
|
||||
if hasattr(event, 'thought') and event.thought:
|
||||
@ -232,6 +283,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
|
||||
display_file_edit(event)
|
||||
elif isinstance(event, FileReadObservation):
|
||||
display_file_read(event)
|
||||
elif isinstance(event, MCPObservation):
|
||||
display_mcp_observation(event)
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
display_agent_state_change_message(event.agent_state)
|
||||
elif isinstance(event, ErrorObservation):
|
||||
@ -337,6 +390,66 @@ def display_file_read(event: FileReadObservation) -> None:
|
||||
print_container(container)
|
||||
|
||||
|
||||
def display_mcp_action(event: MCPAction) -> None:
|
||||
"""Display an MCP action in the CLI."""
|
||||
# Format the arguments for display
|
||||
args_text = ''
|
||||
if event.arguments:
|
||||
try:
|
||||
args_text = json.dumps(event.arguments, indent=2)
|
||||
except (TypeError, ValueError):
|
||||
args_text = str(event.arguments)
|
||||
|
||||
# Create the display text
|
||||
display_text = f'Tool: {event.name}'
|
||||
if args_text:
|
||||
display_text += f'\n\nArguments:\n{args_text}'
|
||||
|
||||
container = Frame(
|
||||
TextArea(
|
||||
text=display_text,
|
||||
read_only=True,
|
||||
style='ansiblue',
|
||||
wrap_lines=True,
|
||||
),
|
||||
title='MCP Tool Call',
|
||||
style='ansiblue',
|
||||
)
|
||||
print_formatted_text('')
|
||||
print_container(container)
|
||||
|
||||
|
||||
def display_mcp_observation(event: MCPObservation) -> None:
|
||||
"""Display an MCP observation in the CLI."""
|
||||
# Format the content for display
|
||||
content = event.content.strip() if event.content else 'No output'
|
||||
|
||||
# Add tool name and arguments info if available
|
||||
display_text = content
|
||||
if event.name:
|
||||
header = f'Tool: {event.name}'
|
||||
if event.arguments:
|
||||
try:
|
||||
args_text = json.dumps(event.arguments, indent=2)
|
||||
header += f'\nArguments: {args_text}'
|
||||
except (TypeError, ValueError):
|
||||
header += f'\nArguments: {event.arguments}'
|
||||
display_text = f'{header}\n\nResult:\n{content}'
|
||||
|
||||
container = Frame(
|
||||
TextArea(
|
||||
text=display_text,
|
||||
read_only=True,
|
||||
style=COLOR_GREY,
|
||||
wrap_lines=True,
|
||||
),
|
||||
title='MCP Tool Result',
|
||||
style=f'fg:{COLOR_GREY}',
|
||||
)
|
||||
print_formatted_text('')
|
||||
print_container(container)
|
||||
|
||||
|
||||
def initialize_streaming_output():
|
||||
"""Initialize the streaming output TextArea."""
|
||||
if not ENABLE_STREAMING:
|
||||
|
||||
@ -1,8 +1,17 @@
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@ -11,6 +20,27 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
def _validate_mcp_url(url: str) -> str:
|
||||
"""Shared URL validation logic for MCP servers."""
|
||||
if not url.strip():
|
||||
raise ValueError('URL cannot be empty')
|
||||
|
||||
url = url.strip()
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme:
|
||||
raise ValueError('URL must include a scheme (http:// or https://)')
|
||||
if not parsed.netloc:
|
||||
raise ValueError('URL must include a valid domain/host')
|
||||
if parsed.scheme not in ['http', 'https', 'ws', 'wss']:
|
||||
raise ValueError('URL scheme must be http, https, ws, or wss')
|
||||
return url
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f'Invalid URL format: {str(e)}')
|
||||
|
||||
|
||||
class MCPSSEServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server.
|
||||
|
||||
@ -22,6 +52,12 @@ class MCPSSEServerConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
@field_validator('url')
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
"""Validate URL format for MCP servers."""
|
||||
return _validate_mcp_url(v)
|
||||
|
||||
|
||||
class MCPStdioServerConfig(BaseModel):
|
||||
"""Configuration for a MCP server that uses stdio.
|
||||
@ -38,6 +74,100 @@ class MCPStdioServerConfig(BaseModel):
|
||||
args: list[str] = Field(default_factory=list)
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_server_name(cls, v: str) -> str:
|
||||
"""Validate server name for stdio MCP servers."""
|
||||
if not v.strip():
|
||||
raise ValueError('Server name cannot be empty')
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# Check for valid characters (alphanumeric, hyphens, underscores)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
|
||||
raise ValueError(
|
||||
'Server name can only contain letters, numbers, hyphens, and underscores'
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('command')
|
||||
@classmethod
|
||||
def validate_command(cls, v: str) -> str:
|
||||
"""Validate command for stdio MCP servers."""
|
||||
if not v.strip():
|
||||
raise ValueError('Command cannot be empty')
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# Check that command doesn't contain spaces (should be a single executable)
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
'Command should be a single executable without spaces (use arguments field for parameters)'
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('args', mode='before')
|
||||
@classmethod
|
||||
def parse_args(cls, v) -> list[str]:
|
||||
"""Parse arguments from string or return list as-is.
|
||||
|
||||
Supports shell-like argument parsing using shlex.split().
|
||||
Examples:
|
||||
- "-y mcp-remote https://example.com"
|
||||
- '--config "path with spaces" --debug'
|
||||
- "arg1 arg2 arg3"
|
||||
"""
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
return []
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# Use shell-like parsing for natural argument handling
|
||||
try:
|
||||
return shlex.split(v)
|
||||
except ValueError as e:
|
||||
# If shlex parsing fails (e.g., unmatched quotes), provide clear error
|
||||
raise ValueError(
|
||||
f'Invalid argument format: {str(e)}. Use shell-like format, e.g., "arg1 arg2" or \'--config "value with spaces"\''
|
||||
)
|
||||
|
||||
return v or []
|
||||
|
||||
@field_validator('env', mode='before')
|
||||
@classmethod
|
||||
def parse_env(cls, v) -> dict[str, str]:
|
||||
"""Parse environment variables from string or return dict as-is."""
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
return {}
|
||||
|
||||
env = {}
|
||||
for pair in v.split(','):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
if '=' not in pair:
|
||||
raise ValueError(
|
||||
f"Environment variable '{pair}' must be in KEY=VALUE format"
|
||||
)
|
||||
|
||||
key, value = pair.split('=', 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError('Environment variable key cannot be empty')
|
||||
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key):
|
||||
raise ValueError(
|
||||
f"Invalid environment variable name '{key}'. Must start with letter or underscore, contain only alphanumeric characters and underscores"
|
||||
)
|
||||
|
||||
env[key] = value
|
||||
return env
|
||||
return v or {}
|
||||
|
||||
def __eq__(self, other):
|
||||
"""Override equality operator to compare server configurations.
|
||||
|
||||
@ -59,6 +189,12 @@ class MCPSHTTPServerConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
@field_validator('url')
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
"""Validate URL format for MCP servers."""
|
||||
return _validate_mcp_url(v)
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP (Message Control Protocol) settings.
|
||||
|
||||
@ -5,7 +5,7 @@ import platform
|
||||
import sys
|
||||
from ast import literal_eval
|
||||
from types import UnionType
|
||||
from typing import MutableMapping, get_args, get_origin
|
||||
from typing import MutableMapping, get_args, get_origin, get_type_hints
|
||||
from uuid import uuid4
|
||||
|
||||
import toml
|
||||
@ -154,8 +154,22 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
|
||||
core_config = toml_config['core']
|
||||
|
||||
# Process core section if present
|
||||
cfg_type_hints = get_type_hints(cfg.__class__)
|
||||
for key, value in core_config.items():
|
||||
if hasattr(cfg, key):
|
||||
# Get expected type of the attribute
|
||||
expected_type = cfg_type_hints.get(key, None)
|
||||
|
||||
# Check if expected_type is a Union that includes SecretStr and value is str, e.g. search_api_key
|
||||
if expected_type:
|
||||
origin = get_origin(expected_type)
|
||||
args = get_args(expected_type)
|
||||
|
||||
if origin is UnionType and SecretStr in args and isinstance(value, str):
|
||||
value = SecretStr(value)
|
||||
elif expected_type is SecretStr and isinstance(value, str):
|
||||
value = SecretStr(value)
|
||||
|
||||
setattr(cfg, key, value)
|
||||
else:
|
||||
logger.openhands_logger.warning(
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.mcp.tool import MCPClientTool
|
||||
from openhands.mcp.utils import (
|
||||
add_mcp_tools_to_agent,
|
||||
@ -16,4 +17,5 @@ __all__ = [
|
||||
'fetch_mcp_tools_from_config',
|
||||
'call_tool_mcp',
|
||||
'add_mcp_tools_to_agent',
|
||||
'mcp_error_collector',
|
||||
]
|
||||
|
||||
@ -1,13 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp import McpError
|
||||
from mcp.types import CallToolResult
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from openhands.core.config.mcp_config import MCPSHTTPServerConfig, MCPSSEServerConfig
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.mcp.tool import MCPClientTool
|
||||
|
||||
|
||||
@ -90,11 +99,51 @@ class MCPClient(BaseModel):
|
||||
|
||||
await self._initialize_and_list_tools()
|
||||
except McpError as e:
|
||||
logger.error(f'McpError connecting to {server_url}: {e}')
|
||||
error_msg = f'McpError connecting to {server_url}: {e}'
|
||||
logger.error(error_msg)
|
||||
mcp_error_collector.add_error(
|
||||
server_name=server_url,
|
||||
server_type='shttp'
|
||||
if isinstance(server, MCPSHTTPServerConfig)
|
||||
else 'sse',
|
||||
error_message=error_msg,
|
||||
exception_details=str(e),
|
||||
)
|
||||
raise # Re-raise the error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error connecting to {server_url}: {e}')
|
||||
error_msg = f'Error connecting to {server_url}: {e}'
|
||||
logger.error(error_msg)
|
||||
mcp_error_collector.add_error(
|
||||
server_name=server_url,
|
||||
server_type='shttp'
|
||||
if isinstance(server, MCPSHTTPServerConfig)
|
||||
else 'sse',
|
||||
error_message=error_msg,
|
||||
exception_details=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
async def connect_stdio(self, server: MCPStdioServerConfig, timeout: float = 30.0):
|
||||
"""Connect to MCP server using stdio transport"""
|
||||
try:
|
||||
transport = StdioTransport(
|
||||
command=server.command, args=server.args or [], env=server.env
|
||||
)
|
||||
self.client = Client(transport, timeout=timeout)
|
||||
await self._initialize_and_list_tools()
|
||||
except Exception as e:
|
||||
server_name = getattr(
|
||||
server, 'name', f'{server.command} {" ".join(server.args or [])}'
|
||||
)
|
||||
error_msg = f'Failed to connect to stdio server {server_name}: {e}'
|
||||
logger.error(error_msg)
|
||||
mcp_error_collector.add_error(
|
||||
server_name=server_name,
|
||||
server_type='stdio',
|
||||
error_message=error_msg,
|
||||
exception_details=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
async def call_tool(self, tool_name: str, args: dict) -> CallToolResult:
|
||||
|
||||
78
openhands/mcp/error_collector.py
Normal file
78
openhands/mcp/error_collector.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""MCP Error Collector for capturing and storing MCP-related errors during startup."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPError:
|
||||
"""Represents an MCP-related error."""
|
||||
|
||||
timestamp: float
|
||||
server_name: str
|
||||
server_type: str # 'stdio', 'sse', 'shttp'
|
||||
error_message: str
|
||||
exception_details: str | None = None
|
||||
|
||||
|
||||
class MCPErrorCollector:
|
||||
"""Thread-safe collector for MCP errors during startup."""
|
||||
|
||||
def __init__(self):
|
||||
self._errors: list[MCPError] = []
|
||||
self._lock = threading.Lock()
|
||||
self._collection_enabled = True
|
||||
|
||||
def add_error(
|
||||
self,
|
||||
server_name: str,
|
||||
server_type: str,
|
||||
error_message: str,
|
||||
exception_details: str | None = None,
|
||||
) -> None:
|
||||
"""Add an MCP error to the collection."""
|
||||
if not self._collection_enabled:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
error = MCPError(
|
||||
timestamp=time.time(),
|
||||
server_name=server_name,
|
||||
server_type=server_type,
|
||||
error_message=error_message,
|
||||
exception_details=exception_details,
|
||||
)
|
||||
self._errors.append(error)
|
||||
|
||||
def get_errors(self) -> list[MCPError]:
|
||||
"""Get a copy of all collected errors."""
|
||||
with self._lock:
|
||||
return self._errors.copy()
|
||||
|
||||
def has_errors(self) -> bool:
|
||||
"""Check if there are any collected errors."""
|
||||
with self._lock:
|
||||
return len(self._errors) > 0
|
||||
|
||||
def clear_errors(self) -> None:
|
||||
"""Clear all collected errors."""
|
||||
with self._lock:
|
||||
self._errors.clear()
|
||||
|
||||
def disable_collection(self) -> None:
|
||||
"""Disable error collection (useful after startup)."""
|
||||
self._collection_enabled = False
|
||||
|
||||
def enable_collection(self) -> None:
|
||||
"""Enable error collection."""
|
||||
self._collection_enabled = True
|
||||
|
||||
def get_error_count(self) -> int:
|
||||
"""Get the number of collected errors."""
|
||||
with self._lock:
|
||||
return len(self._errors)
|
||||
|
||||
|
||||
# Global instance for collecting MCP errors
|
||||
mcp_error_collector = MCPErrorCollector()
|
||||
@ -11,14 +11,17 @@ from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
|
||||
|
||||
def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]:
|
||||
@ -45,7 +48,14 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
|
||||
mcp_tools = tool.to_param()
|
||||
all_mcp_tools.append(mcp_tools)
|
||||
except Exception as e:
|
||||
logger.error(f'Error in convert_mcp_clients_to_tools: {e}')
|
||||
error_msg = f'Error in convert_mcp_clients_to_tools: {e}'
|
||||
logger.error(error_msg)
|
||||
mcp_error_collector.add_error(
|
||||
server_name='general',
|
||||
server_type='conversion',
|
||||
error_message=error_msg,
|
||||
exception_details=str(e),
|
||||
)
|
||||
return []
|
||||
return all_mcp_tools
|
||||
|
||||
@ -54,6 +64,7 @@ async def create_mcp_clients(
|
||||
sse_servers: list[MCPSSEServerConfig],
|
||||
shttp_servers: list[MCPSHTTPServerConfig],
|
||||
conversation_id: str | None = None,
|
||||
stdio_servers: list[MCPStdioServerConfig] | None = None,
|
||||
) -> list[MCPClient]:
|
||||
import sys
|
||||
|
||||
@ -64,9 +75,13 @@ async def create_mcp_clients(
|
||||
)
|
||||
return []
|
||||
|
||||
servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig] = [
|
||||
if stdio_servers is None:
|
||||
stdio_servers = []
|
||||
|
||||
servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig | MCPStdioServerConfig] = [
|
||||
*sse_servers,
|
||||
*shttp_servers,
|
||||
*stdio_servers,
|
||||
]
|
||||
|
||||
if not servers:
|
||||
@ -75,6 +90,17 @@ async def create_mcp_clients(
|
||||
mcp_clients = []
|
||||
|
||||
for server in servers:
|
||||
if isinstance(server, MCPStdioServerConfig):
|
||||
logger.info(f'Initializing MCP agent for {server} with stdio connection...')
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_stdio(server)
|
||||
mcp_clients.append(client)
|
||||
except Exception as e:
|
||||
# Error is already logged and collected in client.connect_stdio()
|
||||
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
|
||||
continue
|
||||
|
||||
is_shttp = isinstance(server, MCPSHTTPServerConfig)
|
||||
connection_type = 'SHTTP' if is_shttp else 'SSE'
|
||||
logger.info(
|
||||
@ -89,13 +115,14 @@ async def create_mcp_clients(
|
||||
mcp_clients.append(client)
|
||||
|
||||
except Exception as e:
|
||||
# Error is already logged and collected in client.connect_http()
|
||||
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
|
||||
|
||||
return mcp_clients
|
||||
|
||||
|
||||
async def fetch_mcp_tools_from_config(
|
||||
mcp_config: MCPConfig, conversation_id: str | None = None
|
||||
mcp_config: MCPConfig, conversation_id: str | None = None, use_stdio: bool = False
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Retrieves the list of MCP tools from the MCP clients.
|
||||
@ -103,6 +130,7 @@ async def fetch_mcp_tools_from_config(
|
||||
Args:
|
||||
mcp_config: The MCP configuration
|
||||
conversation_id: Optional conversation ID to associate with the MCP clients
|
||||
use_stdio: Whether to use stdio servers for MCP clients, set to True when running from a CLI runtime
|
||||
|
||||
Returns:
|
||||
A list of tool dictionaries. Returns an empty list if no connections could be established.
|
||||
@ -120,7 +148,10 @@ async def fetch_mcp_tools_from_config(
|
||||
logger.debug(f'Creating MCP clients with config: {mcp_config}')
|
||||
# Create clients - this will fetch tools but not maintain active connections
|
||||
mcp_clients = await create_mcp_clients(
|
||||
mcp_config.sse_servers, mcp_config.shttp_servers, conversation_id
|
||||
mcp_config.sse_servers,
|
||||
mcp_config.shttp_servers,
|
||||
conversation_id,
|
||||
mcp_config.stdio_servers if use_stdio else [],
|
||||
)
|
||||
|
||||
if not mcp_clients:
|
||||
@ -131,7 +162,14 @@ async def fetch_mcp_tools_from_config(
|
||||
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error fetching MCP tools: {str(e)}')
|
||||
error_msg = f'Error fetching MCP tools: {str(e)}'
|
||||
logger.error(error_msg)
|
||||
mcp_error_collector.add_error(
|
||||
server_name='general',
|
||||
server_type='fetch',
|
||||
error_message=error_msg,
|
||||
exception_details=str(e),
|
||||
)
|
||||
return []
|
||||
|
||||
logger.debug(f'MCP tools: {mcp_tools}')
|
||||
@ -200,7 +238,9 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
|
||||
)
|
||||
|
||||
|
||||
async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memory'):
|
||||
async def add_mcp_tools_to_agent(
|
||||
agent: 'Agent', runtime: Runtime, memory: 'Memory'
|
||||
) -> MCPConfig:
|
||||
"""
|
||||
Add MCP tools to an agent.
|
||||
"""
|
||||
@ -231,13 +271,18 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo
|
||||
# Check if this stdio server is already in the config
|
||||
if stdio_server not in extra_stdio_servers:
|
||||
extra_stdio_servers.append(stdio_server)
|
||||
logger.info(f'Added microagent stdio server: {stdio_server.name}')
|
||||
logger.warning(
|
||||
f'Added microagent stdio server: {stdio_server.name}'
|
||||
)
|
||||
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = runtime.get_mcp_config(extra_stdio_servers)
|
||||
|
||||
# Fetch the MCP tools
|
||||
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
|
||||
# Only use stdio if run from a CLI runtime
|
||||
mcp_tools = await fetch_mcp_tools_from_config(
|
||||
updated_mcp_config, use_stdio=isinstance(runtime, CLIRuntime)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}'
|
||||
@ -245,3 +290,5 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
return updated_mcp_config
|
||||
|
||||
@ -689,8 +689,69 @@ class CLIRuntime(Runtime):
|
||||
)
|
||||
|
||||
async def call_tool_mcp(self, action: MCPAction) -> Observation:
|
||||
"""Not implemented for CLI runtime."""
|
||||
return ErrorObservation('MCP functionality is not implemented in CLIRuntime')
|
||||
"""Execute an MCP tool action in CLI runtime.
|
||||
|
||||
Args:
|
||||
action: The MCP action to execute
|
||||
|
||||
Returns:
|
||||
Observation: The result of the MCP tool execution
|
||||
"""
|
||||
# Check if we're on Windows - MCP is disabled on Windows
|
||||
if sys.platform == 'win32':
|
||||
self.log('info', 'MCP functionality is disabled on Windows')
|
||||
return ErrorObservation('MCP functionality is not available on Windows')
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
|
||||
from openhands.mcp.utils import create_mcp_clients
|
||||
|
||||
try:
|
||||
# Get the MCP config for this runtime
|
||||
mcp_config = self.get_mcp_config()
|
||||
|
||||
if (
|
||||
not mcp_config.sse_servers
|
||||
and not mcp_config.shttp_servers
|
||||
and not mcp_config.stdio_servers
|
||||
):
|
||||
self.log('warning', 'No MCP servers configured')
|
||||
return ErrorObservation('No MCP servers configured')
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'Creating MCP clients for action {action.name} with servers: '
|
||||
f'SSE={len(mcp_config.sse_servers)}, SHTTP={len(mcp_config.shttp_servers)}, '
|
||||
f'stdio={len(mcp_config.stdio_servers)}',
|
||||
)
|
||||
|
||||
# Create clients for this specific operation
|
||||
mcp_clients = await create_mcp_clients(
|
||||
mcp_config.sse_servers,
|
||||
mcp_config.shttp_servers,
|
||||
self.sid,
|
||||
mcp_config.stdio_servers,
|
||||
)
|
||||
|
||||
if not mcp_clients:
|
||||
self.log('warning', 'No MCP clients could be created')
|
||||
return ErrorObservation(
|
||||
'No MCP clients could be created - check server configurations'
|
||||
)
|
||||
|
||||
# Call the tool and return the result
|
||||
self.log(
|
||||
'debug',
|
||||
f'Executing MCP tool: {action.name} with arguments: {action.arguments}',
|
||||
)
|
||||
result = await call_tool_mcp_handler(mcp_clients, action)
|
||||
self.log('debug', f'MCP tool {action.name} executed successfully')
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f'Error executing MCP tool {action.name}: {str(e)}'
|
||||
self.log('error', error_msg)
|
||||
return ErrorObservation(error_msg)
|
||||
|
||||
@property
|
||||
def workspace_root(self) -> Path:
|
||||
@ -869,8 +930,40 @@ class CLIRuntime(Runtime):
|
||||
def get_mcp_config(
|
||||
self, extra_stdio_servers: list[MCPStdioServerConfig] | None = None
|
||||
) -> MCPConfig:
|
||||
# TODO: Load MCP config from a local file
|
||||
return MCPConfig()
|
||||
"""Get MCP configuration for CLI runtime.
|
||||
|
||||
Args:
|
||||
extra_stdio_servers: Additional stdio servers to include in the config
|
||||
|
||||
Returns:
|
||||
MCPConfig: The MCP configuration with stdio servers and any configured SSE/SHTTP servers
|
||||
"""
|
||||
# Check if we're on Windows - MCP is disabled on Windows
|
||||
if sys.platform == 'win32':
|
||||
self.log('debug', 'MCP is disabled on Windows, returning empty config')
|
||||
return MCPConfig(sse_servers=[], stdio_servers=[], shttp_servers=[])
|
||||
|
||||
# Note: we update the self.config.mcp directly for CLI runtime, which is different from other runtimes.
|
||||
mcp_config = self.config.mcp
|
||||
|
||||
# Add any extra stdio servers
|
||||
if extra_stdio_servers:
|
||||
current_stdio_servers = list(mcp_config.stdio_servers)
|
||||
for extra_server in extra_stdio_servers:
|
||||
# Check if this stdio server is already in the config
|
||||
if extra_server not in current_stdio_servers:
|
||||
current_stdio_servers.append(extra_server)
|
||||
self.log('info', f'Added extra stdio server: {extra_server.name}')
|
||||
mcp_config.stdio_servers = current_stdio_servers
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'CLI MCP config: {len(mcp_config.sse_servers)} SSE servers, '
|
||||
f'{len(mcp_config.stdio_servers)} stdio servers, '
|
||||
f'{len(mcp_config.shttp_servers)} SHTTP servers',
|
||||
)
|
||||
|
||||
return mcp_config
|
||||
|
||||
def subscribe_to_shell_stream(
|
||||
self, callback: Callable[[str], None] | None = None
|
||||
|
||||
@ -3,10 +3,12 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from openhands.cli.commands import (
|
||||
display_mcp_servers,
|
||||
handle_commands,
|
||||
handle_exit_command,
|
||||
handle_help_command,
|
||||
handle_init_command,
|
||||
handle_mcp_command,
|
||||
handle_new_command,
|
||||
handle_resume_command,
|
||||
handle_settings_command,
|
||||
@ -143,6 +145,18 @@ class TestHandleCommands:
|
||||
assert reload_microagents is False
|
||||
assert new_session is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.commands.handle_mcp_command')
|
||||
async def test_handle_mcp_command(self, mock_handle_mcp, mock_dependencies):
|
||||
close_repl, reload_microagents, new_session, _ = await handle_commands(
|
||||
'/mcp', **mock_dependencies
|
||||
)
|
||||
|
||||
mock_handle_mcp.assert_called_once_with(mock_dependencies['config'])
|
||||
assert close_repl is False
|
||||
assert reload_microagents is False
|
||||
assert new_session is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unknown_command(self, mock_dependencies):
|
||||
user_message = 'Hello, this is not a command'
|
||||
@ -219,6 +233,78 @@ class TestHandleHelpCommand:
|
||||
mock_display_help.assert_called_once()
|
||||
|
||||
|
||||
class TestDisplayMcpServers:
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
def test_display_mcp_servers_no_servers(self, mock_print):
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
config.mcp = MCPConfig() # Empty config with no servers
|
||||
|
||||
display_mcp_servers(config)
|
||||
|
||||
mock_print.assert_called_once()
|
||||
call_args = mock_print.call_args[0][0]
|
||||
assert 'No custom MCP servers configured' in call_args
|
||||
assert (
|
||||
'https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers'
|
||||
in call_args
|
||||
)
|
||||
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
def test_display_mcp_servers_with_servers(self, mock_print):
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='https://example.com/sse')],
|
||||
stdio_servers=[MCPStdioServerConfig(name='tavily', command='npx')],
|
||||
shttp_servers=[MCPSHTTPServerConfig(url='http://localhost:3000/mcp')],
|
||||
)
|
||||
|
||||
display_mcp_servers(config)
|
||||
|
||||
# Should be called multiple times for different sections
|
||||
assert mock_print.call_count >= 4
|
||||
|
||||
# Check that the summary is printed
|
||||
first_call = mock_print.call_args_list[0][0][0]
|
||||
assert 'Configured MCP servers:' in first_call
|
||||
assert 'SSE servers: 1' in first_call
|
||||
assert 'Stdio servers: 1' in first_call
|
||||
assert 'SHTTP servers: 1' in first_call
|
||||
assert 'Total: 3' in first_call
|
||||
|
||||
|
||||
class TestHandleMcpCommand:
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.commands.cli_confirm')
|
||||
@patch('openhands.cli.commands.display_mcp_servers')
|
||||
async def test_handle_mcp_command_list_action(self, mock_display, mock_cli_confirm):
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
mock_cli_confirm.return_value = 0 # List action
|
||||
|
||||
await handle_mcp_command(config)
|
||||
|
||||
mock_cli_confirm.assert_called_once_with(
|
||||
config,
|
||||
'MCP Server Configuration',
|
||||
[
|
||||
'List configured servers',
|
||||
'Add new server',
|
||||
'Remove server',
|
||||
'View errors',
|
||||
'Go back',
|
||||
],
|
||||
)
|
||||
mock_display.assert_called_once_with(config)
|
||||
|
||||
|
||||
class TestHandleStatusCommand:
|
||||
@patch('openhands.cli.commands.display_status')
|
||||
def test_status_command(self, mock_display_status):
|
||||
@ -496,3 +582,16 @@ class TestHandleResumeCommand:
|
||||
# Check the return values
|
||||
assert close_repl is True
|
||||
assert new_session_requested is False
|
||||
|
||||
|
||||
class TestMCPErrorHandling:
|
||||
"""Test MCP error handling in commands."""
|
||||
|
||||
@patch('openhands.cli.commands.display_mcp_errors')
|
||||
def test_handle_mcp_errors_command(self, mock_display_errors):
|
||||
"""Test handling MCP errors command."""
|
||||
from openhands.cli.commands import handle_mcp_errors_command
|
||||
|
||||
handle_mcp_errors_command()
|
||||
|
||||
mock_display_errors.assert_called_once()
|
||||
|
||||
106
tests/unit/test_cli_config_management.py
Normal file
106
tests/unit/test_cli_config_management.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Tests for CLI server management functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.cli.commands import (
|
||||
display_mcp_servers,
|
||||
remove_mcp_server,
|
||||
)
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPServerManagement:
|
||||
"""Test MCP server management functions."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.config = MagicMock(spec=OpenHandsConfig)
|
||||
self.config.cli = MagicMock()
|
||||
self.config.cli.vi_mode = False
|
||||
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
def test_display_mcp_servers_no_servers(self, mock_print):
|
||||
"""Test displaying MCP servers when none are configured."""
|
||||
self.config.mcp = MCPConfig() # Empty config
|
||||
|
||||
display_mcp_servers(self.config)
|
||||
|
||||
mock_print.assert_called_once()
|
||||
call_args = mock_print.call_args[0][0]
|
||||
assert 'No custom MCP servers configured' in call_args
|
||||
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
def test_display_mcp_servers_with_servers(self, mock_print):
|
||||
"""Test displaying MCP servers when some are configured."""
|
||||
self.config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
|
||||
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
|
||||
)
|
||||
|
||||
display_mcp_servers(self.config)
|
||||
|
||||
# Should be called multiple times for different sections
|
||||
assert mock_print.call_count >= 2
|
||||
|
||||
# Check that the summary is printed
|
||||
first_call = mock_print.call_args_list[0][0][0]
|
||||
assert 'Configured MCP servers:' in first_call
|
||||
assert 'SSE servers: 1' in first_call
|
||||
assert 'Stdio servers: 1' in first_call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.commands.cli_confirm')
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
async def test_remove_mcp_server_no_servers(self, mock_print, mock_cli_confirm):
|
||||
"""Test removing MCP server when none are configured."""
|
||||
self.config.mcp = MCPConfig() # Empty config
|
||||
|
||||
await remove_mcp_server(self.config)
|
||||
|
||||
mock_print.assert_called_once_with('No MCP servers configured to remove.')
|
||||
mock_cli_confirm.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.commands.cli_confirm')
|
||||
@patch('openhands.cli.commands.load_config_file')
|
||||
@patch('openhands.cli.commands.save_config_file')
|
||||
@patch('openhands.cli.commands.print_formatted_text')
|
||||
async def test_remove_mcp_server_success(
|
||||
self, mock_print, mock_save, mock_load, mock_cli_confirm
|
||||
):
|
||||
"""Test successfully removing an MCP server."""
|
||||
# Set up config with servers
|
||||
self.config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
|
||||
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
|
||||
)
|
||||
|
||||
# Mock user selections
|
||||
mock_cli_confirm.side_effect = [0, 0] # Select first server, confirm removal
|
||||
|
||||
# Mock config file operations
|
||||
mock_load.return_value = {
|
||||
'mcp': {
|
||||
'sse_servers': [{'url': 'http://test.com'}],
|
||||
'stdio_servers': [{'name': 'test-stdio', 'command': 'python'}],
|
||||
}
|
||||
}
|
||||
|
||||
await remove_mcp_server(self.config)
|
||||
|
||||
# Should have been called twice (select server, confirm removal)
|
||||
assert mock_cli_confirm.call_count == 2
|
||||
mock_save.assert_called_once()
|
||||
|
||||
# Check that success message was printed
|
||||
success_calls = [
|
||||
call for call in mock_print.call_args_list if 'removed' in str(call[0][0])
|
||||
]
|
||||
assert len(success_calls) >= 1
|
||||
156
tests/unit/test_cli_runtime_mcp.py
Normal file
156
tests/unit/test_cli_runtime_mcp.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""Tests for CLI Runtime MCP functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
|
||||
|
||||
class TestCLIRuntimeMCP:
|
||||
"""Test MCP functionality in CLI Runtime."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.config = OpenHandsConfig()
|
||||
self.event_stream = MagicMock()
|
||||
self.runtime = CLIRuntime(
|
||||
config=self.config, event_stream=self.event_stream, sid='test-session'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_mcp_no_servers_configured(self):
|
||||
"""Test MCP call with no servers configured."""
|
||||
# Set up empty MCP config
|
||||
self.runtime.config.mcp = MCPConfig()
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = await self.runtime.call_tool_mcp(action)
|
||||
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert 'No MCP servers configured' in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.create_mcp_clients')
|
||||
async def test_call_tool_mcp_no_clients_created(self, mock_create_clients):
|
||||
"""Test MCP call when no clients can be created."""
|
||||
# Set up MCP config with servers
|
||||
self.runtime.config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')]
|
||||
)
|
||||
|
||||
# Mock create_mcp_clients to return empty list
|
||||
mock_create_clients.return_value = []
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = await self.runtime.call_tool_mcp(action)
|
||||
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert 'No MCP clients could be created' in result.content
|
||||
mock_create_clients.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.create_mcp_clients')
|
||||
@patch('openhands.mcp.utils.call_tool_mcp')
|
||||
async def test_call_tool_mcp_success(self, mock_call_tool, mock_create_clients):
|
||||
"""Test successful MCP tool call."""
|
||||
# Set up MCP config with servers
|
||||
self.runtime.config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
|
||||
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
|
||||
)
|
||||
|
||||
# Mock successful client creation
|
||||
mock_client = MagicMock()
|
||||
mock_create_clients.return_value = [mock_client]
|
||||
|
||||
# Mock successful tool call
|
||||
expected_observation = MCPObservation(
|
||||
content='{"result": "success"}',
|
||||
name='test_tool',
|
||||
arguments={'arg1': 'value1'},
|
||||
)
|
||||
mock_call_tool.return_value = expected_observation
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = await self.runtime.call_tool_mcp(action)
|
||||
|
||||
assert result == expected_observation
|
||||
mock_create_clients.assert_called_once_with(
|
||||
self.runtime.config.mcp.sse_servers,
|
||||
self.runtime.config.mcp.shttp_servers,
|
||||
self.runtime.sid,
|
||||
self.runtime.config.mcp.stdio_servers,
|
||||
)
|
||||
mock_call_tool.assert_called_once_with([mock_client], action)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.create_mcp_clients')
|
||||
async def test_call_tool_mcp_exception_handling(self, mock_create_clients):
|
||||
"""Test exception handling in MCP tool call."""
|
||||
# Set up MCP config with servers
|
||||
self.runtime.config.mcp = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')]
|
||||
)
|
||||
|
||||
# Mock create_mcp_clients to raise an exception
|
||||
mock_create_clients.side_effect = Exception('Connection error')
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = await self.runtime.call_tool_mcp(action)
|
||||
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert 'Error executing MCP tool test_tool' in result.content
|
||||
assert 'Connection error' in result.content
|
||||
|
||||
def test_get_mcp_config_basic(self):
|
||||
"""Test basic MCP config retrieval."""
|
||||
# Set up MCP config
|
||||
expected_config = MCPConfig(
|
||||
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
|
||||
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
|
||||
)
|
||||
self.runtime.config.mcp = expected_config
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = self.runtime.get_mcp_config()
|
||||
|
||||
assert result == expected_config
|
||||
|
||||
def test_get_mcp_config_with_extra_stdio_servers(self):
|
||||
"""Test MCP config with extra stdio servers."""
|
||||
# Set up initial MCP config
|
||||
initial_stdio_server = MCPStdioServerConfig(name='initial', command='python')
|
||||
self.runtime.config.mcp = MCPConfig(stdio_servers=[initial_stdio_server])
|
||||
|
||||
# Add extra stdio servers
|
||||
extra_servers = [
|
||||
MCPStdioServerConfig(name='extra1', command='node'),
|
||||
MCPStdioServerConfig(name='extra2', command='java'),
|
||||
]
|
||||
|
||||
with patch('sys.platform', 'linux'):
|
||||
result = self.runtime.get_mcp_config(extra_stdio_servers=extra_servers)
|
||||
|
||||
# Should have all three servers
|
||||
assert len(result.stdio_servers) == 3
|
||||
assert initial_stdio_server in result.stdio_servers
|
||||
assert extra_servers[0] in result.stdio_servers
|
||||
assert extra_servers[1] in result.stdio_servers
|
||||
@ -9,6 +9,9 @@ from openhands.cli.tui import (
|
||||
display_banner,
|
||||
display_command,
|
||||
display_event,
|
||||
display_mcp_action,
|
||||
display_mcp_errors,
|
||||
display_mcp_observation,
|
||||
display_message,
|
||||
display_runtime_initialization_message,
|
||||
display_shutdown_message,
|
||||
@ -24,14 +27,17 @@ from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
CmdRunAction,
|
||||
MCPAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
MCPObservation,
|
||||
)
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.mcp.error_collector import MCPError
|
||||
|
||||
|
||||
class TestDisplayFunctions:
|
||||
@ -72,6 +78,35 @@ class TestDisplayFunctions:
|
||||
args, kwargs = mock_print.call_args_list[0]
|
||||
assert "Let's start building" in str(args[0])
|
||||
|
||||
@patch('openhands.cli.tui.print_formatted_text')
|
||||
def test_display_welcome_message_with_message(self, mock_print):
|
||||
message = 'Test message'
|
||||
display_welcome_message(message)
|
||||
assert mock_print.call_count == 2
|
||||
# Check the first call contains the welcome message
|
||||
args, kwargs = mock_print.call_args_list[0]
|
||||
message_text = str(args[0])
|
||||
assert "Let's start building" in message_text
|
||||
# Check the second call contains the custom message
|
||||
args, kwargs = mock_print.call_args_list[1]
|
||||
message_text = str(args[0])
|
||||
assert 'Test message' in message_text
|
||||
assert 'Type /help for help' in message_text
|
||||
|
||||
@patch('openhands.cli.tui.print_formatted_text')
|
||||
def test_display_welcome_message_without_message(self, mock_print):
|
||||
display_welcome_message()
|
||||
assert mock_print.call_count == 2
|
||||
# Check the first call contains the welcome message
|
||||
args, kwargs = mock_print.call_args_list[0]
|
||||
message_text = str(args[0])
|
||||
assert "Let's start building" in message_text
|
||||
# Check the second call contains the default message
|
||||
args, kwargs = mock_print.call_args_list[1]
|
||||
message_text = str(args[0])
|
||||
assert 'What do you want to build?' in message_text
|
||||
assert 'Type /help for help' in message_text
|
||||
|
||||
@patch('openhands.cli.tui.display_message')
|
||||
def test_display_event_message_action(self, mock_display_message):
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
@ -147,6 +182,71 @@ class TestDisplayFunctions:
|
||||
|
||||
mock_display_message.assert_called_once_with('Thinking about this...')
|
||||
|
||||
@patch('openhands.cli.tui.display_mcp_action')
|
||||
def test_display_event_mcp_action(self, mock_display_mcp_action):
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'})
|
||||
|
||||
display_event(mcp_action, config)
|
||||
|
||||
mock_display_mcp_action.assert_called_once_with(mcp_action)
|
||||
|
||||
@patch('openhands.cli.tui.display_mcp_observation')
|
||||
def test_display_event_mcp_observation(self, mock_display_mcp_observation):
|
||||
config = MagicMock(spec=OpenHandsConfig)
|
||||
mcp_observation = MCPObservation(
|
||||
content='Tool result', name='test_tool', arguments={'param': 'value'}
|
||||
)
|
||||
|
||||
display_event(mcp_observation, config)
|
||||
|
||||
mock_display_mcp_observation.assert_called_once_with(mcp_observation)
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_action(self, mock_print_container):
|
||||
mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'})
|
||||
|
||||
display_mcp_action(mcp_action)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'test_tool' in container.body.text
|
||||
assert 'param' in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_action_no_args(self, mock_print_container):
|
||||
mcp_action = MCPAction(name='test_tool')
|
||||
|
||||
display_mcp_action(mcp_action)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'test_tool' in container.body.text
|
||||
assert 'Arguments' not in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_observation(self, mock_print_container):
|
||||
mcp_observation = MCPObservation(
|
||||
content='Tool result', name='test_tool', arguments={'param': 'value'}
|
||||
)
|
||||
|
||||
display_mcp_observation(mcp_observation)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'test_tool' in container.body.text
|
||||
assert 'Tool result' in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_observation_no_content(self, mock_print_container):
|
||||
mcp_observation = MCPObservation(content='', name='test_tool')
|
||||
|
||||
display_mcp_observation(mcp_observation)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'No output' in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_formatted_text')
|
||||
def test_display_message(self, mock_print):
|
||||
message = 'Test message'
|
||||
@ -307,14 +407,86 @@ class TestReadConfirmationInput:
|
||||
result = await read_confirmation_input(config=cfg)
|
||||
assert result == 'always'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
"""Tests for CLI TUI MCP functionality."""
|
||||
|
||||
|
||||
class TestMCPTUIDisplay:
|
||||
"""Test MCP TUI display functions."""
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_action_with_arguments(self, mock_print_container):
|
||||
"""Test displaying MCP action with arguments."""
|
||||
mcp_action = MCPAction(
|
||||
name='test_tool', arguments={'param1': 'value1', 'param2': 42}
|
||||
)
|
||||
|
||||
display_mcp_action(mcp_action)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'test_tool' in container.body.text
|
||||
assert 'param1' in container.body.text
|
||||
assert 'value1' in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
def test_display_mcp_observation_with_content(self, mock_print_container):
|
||||
"""Test displaying MCP observation with content."""
|
||||
mcp_observation = MCPObservation(
|
||||
content='Tool execution successful',
|
||||
name='test_tool',
|
||||
arguments={'param': 'value'},
|
||||
)
|
||||
|
||||
display_mcp_observation(mcp_observation)
|
||||
|
||||
mock_print_container.assert_called_once()
|
||||
container = mock_print_container.call_args[0][0]
|
||||
assert 'test_tool' in container.body.text
|
||||
assert 'Tool execution successful' in container.body.text
|
||||
|
||||
@patch('openhands.cli.tui.print_formatted_text')
|
||||
@patch('openhands.cli.tui.cli_confirm')
|
||||
async def test_read_confirmation_input_edit(self, mock_confirm, mock_print):
|
||||
mock_confirm.return_value = 3 # user picked third menu item
|
||||
@patch('openhands.cli.tui.mcp_error_collector')
|
||||
def test_display_mcp_errors_no_errors(self, mock_collector, mock_print):
|
||||
"""Test displaying MCP errors when none exist."""
|
||||
mock_collector.get_errors.return_value = []
|
||||
|
||||
cfg = MagicMock() # <- no spec for simplicity
|
||||
cfg.cli = MagicMock(vi_mode=False)
|
||||
display_mcp_errors()
|
||||
|
||||
result = await read_confirmation_input(config=cfg)
|
||||
assert result == 'edit'
|
||||
mock_print.assert_called_once()
|
||||
call_args = mock_print.call_args[0][0]
|
||||
assert 'No MCP errors detected' in str(call_args)
|
||||
|
||||
@patch('openhands.cli.tui.print_container')
|
||||
@patch('openhands.cli.tui.print_formatted_text')
|
||||
@patch('openhands.cli.tui.mcp_error_collector')
|
||||
def test_display_mcp_errors_with_errors(
|
||||
self, mock_collector, mock_print, mock_print_container
|
||||
):
|
||||
"""Test displaying MCP errors when some exist."""
|
||||
# Create mock errors
|
||||
error1 = MCPError(
|
||||
timestamp=1234567890.0,
|
||||
server_name='test-server-1',
|
||||
server_type='stdio',
|
||||
error_message='Connection failed',
|
||||
exception_details='Socket timeout',
|
||||
)
|
||||
error2 = MCPError(
|
||||
timestamp=1234567891.0,
|
||||
server_name='test-server-2',
|
||||
server_type='sse',
|
||||
error_message='Server unreachable',
|
||||
)
|
||||
|
||||
mock_collector.get_errors.return_value = [error1, error2]
|
||||
|
||||
display_mcp_errors()
|
||||
|
||||
# Should print error count header
|
||||
assert mock_print.call_count >= 1
|
||||
header_call = mock_print.call_args_list[0][0][0]
|
||||
assert '2 MCP error(s) detected' in str(header_call)
|
||||
|
||||
# Should print containers for each error
|
||||
assert mock_print_container.call_count == 2
|
||||
|
||||
@ -27,10 +27,9 @@ def test_empty_sse_config():
|
||||
|
||||
def test_invalid_sse_url():
|
||||
"""Test SSE configuration with invalid URL format."""
|
||||
config = MCPConfig(sse_servers=[MCPSSEServerConfig(url='not_a_url')])
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config.validate_servers()
|
||||
assert 'Invalid URL' in str(exc_info.value)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MCPSSEServerConfig(url='not_a_url')
|
||||
assert 'URL must include a scheme' in str(exc_info.value)
|
||||
|
||||
|
||||
def test_duplicate_sse_urls():
|
||||
@ -64,7 +63,7 @@ def test_from_toml_section_invalid_sse():
|
||||
}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
MCPConfig.from_toml_section(data)
|
||||
assert 'Invalid URL' in str(exc_info.value)
|
||||
assert 'URL must include a scheme' in str(exc_info.value)
|
||||
|
||||
|
||||
def test_complex_urls():
|
||||
@ -246,3 +245,28 @@ def test_stdio_server_equality_with_different_env_order():
|
||||
|
||||
# Should not be equal
|
||||
assert server1 != server5
|
||||
|
||||
|
||||
def test_mcp_stdio_server_args_parsing_basic():
|
||||
"""Test MCPStdioServerConfig args parsing with basic shell-like format."""
|
||||
# Test basic space-separated parsing
|
||||
config = MCPStdioServerConfig(
|
||||
name='test-server', command='python', args='arg1 arg2 arg3'
|
||||
)
|
||||
assert config.args == ['arg1', 'arg2', 'arg3']
|
||||
|
||||
# Test single argument
|
||||
config = MCPStdioServerConfig(
|
||||
name='test-server', command='python', args='single-arg'
|
||||
)
|
||||
assert config.args == ['single-arg']
|
||||
|
||||
|
||||
def test_mcp_stdio_server_args_parsing_invalid_quotes():
|
||||
"""Test MCPStdioServerConfig args parsing with invalid quotes."""
|
||||
# Test unmatched quotes
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MCPStdioServerConfig(
|
||||
name='test-server', command='python', args='--config "unmatched quote'
|
||||
)
|
||||
assert 'Invalid argument format' in str(exc_info.value)
|
||||
|
||||
159
tests/unit/test_mcp_error_collector.py
Normal file
159
tests/unit/test_mcp_error_collector.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""Tests for MCP Error Collector functionality."""
|
||||
|
||||
import time
|
||||
|
||||
from openhands.mcp.error_collector import (
|
||||
MCPError,
|
||||
MCPErrorCollector,
|
||||
mcp_error_collector,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPError:
|
||||
"""Test MCPError dataclass."""
|
||||
|
||||
def test_mcp_error_creation(self):
|
||||
"""Test creating an MCP error."""
|
||||
timestamp = time.time()
|
||||
error = MCPError(
|
||||
timestamp=timestamp,
|
||||
server_name='test-server',
|
||||
server_type='stdio',
|
||||
error_message='Connection failed',
|
||||
exception_details='Socket timeout',
|
||||
)
|
||||
|
||||
assert error.timestamp == timestamp
|
||||
assert error.server_name == 'test-server'
|
||||
assert error.server_type == 'stdio'
|
||||
assert error.error_message == 'Connection failed'
|
||||
assert error.exception_details == 'Socket timeout'
|
||||
|
||||
def test_mcp_error_creation_without_exception_details(self):
|
||||
"""Test creating an MCP error without exception details."""
|
||||
timestamp = time.time()
|
||||
error = MCPError(
|
||||
timestamp=timestamp,
|
||||
server_name='test-server',
|
||||
server_type='sse',
|
||||
error_message='Server unreachable',
|
||||
)
|
||||
|
||||
assert error.timestamp == timestamp
|
||||
assert error.server_name == 'test-server'
|
||||
assert error.server_type == 'sse'
|
||||
assert error.error_message == 'Server unreachable'
|
||||
assert error.exception_details is None
|
||||
|
||||
|
||||
class TestMCPErrorCollector:
|
||||
"""Test MCPErrorCollector functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.collector = MCPErrorCollector()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test collector initialization."""
|
||||
assert self.collector._errors == []
|
||||
assert self.collector._collection_enabled is True
|
||||
|
||||
def test_add_error(self):
|
||||
"""Test adding an error to the collector."""
|
||||
self.collector.add_error(
|
||||
server_name='test-server',
|
||||
server_type='stdio',
|
||||
error_message='Connection failed',
|
||||
exception_details='Socket timeout',
|
||||
)
|
||||
|
||||
errors = self.collector.get_errors()
|
||||
assert len(errors) == 1
|
||||
assert errors[0].server_name == 'test-server'
|
||||
assert errors[0].server_type == 'stdio'
|
||||
assert errors[0].error_message == 'Connection failed'
|
||||
assert errors[0].exception_details == 'Socket timeout'
|
||||
assert errors[0].timestamp > 0
|
||||
|
||||
def test_add_multiple_errors(self):
|
||||
"""Test adding multiple errors."""
|
||||
self.collector.add_error('server1', 'stdio', 'Error 1')
|
||||
self.collector.add_error('server2', 'sse', 'Error 2')
|
||||
self.collector.add_error('server3', 'shttp', 'Error 3')
|
||||
|
||||
errors = self.collector.get_errors()
|
||||
assert len(errors) == 3
|
||||
assert errors[0].server_name == 'server1'
|
||||
assert errors[1].server_name == 'server2'
|
||||
assert errors[2].server_name == 'server3'
|
||||
|
||||
def test_has_errors(self):
|
||||
"""Test has_errors method."""
|
||||
assert not self.collector.has_errors()
|
||||
|
||||
self.collector.add_error('server1', 'stdio', 'Error 1')
|
||||
assert self.collector.has_errors()
|
||||
|
||||
self.collector.clear_errors()
|
||||
assert not self.collector.has_errors()
|
||||
|
||||
def test_clear_errors(self):
|
||||
"""Test clearing errors."""
|
||||
self.collector.add_error('server1', 'stdio', 'Error 1')
|
||||
self.collector.add_error('server2', 'sse', 'Error 2')
|
||||
|
||||
assert len(self.collector.get_errors()) == 2
|
||||
|
||||
self.collector.clear_errors()
|
||||
assert len(self.collector.get_errors()) == 0
|
||||
assert not self.collector.has_errors()
|
||||
|
||||
def test_enable_disable_collection(self):
|
||||
"""Test enabling and disabling error collection."""
|
||||
self.collector.add_error('server1', 'stdio', 'Error 1')
|
||||
assert len(self.collector.get_errors()) == 1
|
||||
|
||||
# Disable collection
|
||||
self.collector.disable_collection()
|
||||
|
||||
# Adding error should be ignored
|
||||
self.collector.add_error('server2', 'sse', 'Error 2')
|
||||
assert len(self.collector.get_errors()) == 1 # Still only 1 error
|
||||
|
||||
# Re-enable collection
|
||||
self.collector.enable_collection()
|
||||
|
||||
# Adding error should work again
|
||||
self.collector.add_error('server3', 'shttp', 'Error 3')
|
||||
assert len(self.collector.get_errors()) == 2
|
||||
|
||||
|
||||
class TestGlobalMCPErrorCollector:
|
||||
"""Test the global MCP error collector instance."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear global collector before each test."""
|
||||
mcp_error_collector.clear_errors()
|
||||
mcp_error_collector.enable_collection()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after each test."""
|
||||
mcp_error_collector.clear_errors()
|
||||
mcp_error_collector.enable_collection()
|
||||
|
||||
def test_global_collector_exists(self):
|
||||
"""Test that global collector instance exists."""
|
||||
assert mcp_error_collector is not None
|
||||
assert isinstance(mcp_error_collector, MCPErrorCollector)
|
||||
|
||||
def test_global_collector_functionality(self):
|
||||
"""Test basic functionality of global collector."""
|
||||
assert not mcp_error_collector.has_errors()
|
||||
|
||||
mcp_error_collector.add_error('global-server', 'stdio', 'Global error')
|
||||
assert mcp_error_collector.has_errors()
|
||||
assert mcp_error_collector.get_error_count() == 1
|
||||
|
||||
errors = mcp_error_collector.get_errors()
|
||||
assert len(errors) == 1
|
||||
assert errors[0].server_name == 'global-server'
|
||||
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
# Import the module, not the functions directly to avoid circular imports
|
||||
import openhands.mcp.utils
|
||||
from openhands.core.config.mcp_config import MCPSSEServerConfig
|
||||
from openhands.core.config.mcp_config import MCPSSEServerConfig, MCPStdioServerConfig
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
|
||||
@ -161,3 +161,136 @@ async def test_call_tool_mcp_success():
|
||||
assert isinstance(observation, MCPObservation)
|
||||
assert json.loads(observation.content) == {'result': 'success'}
|
||||
mock_client.call_tool.assert_called_once_with('test_tool', {'arg1': 'value1'})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.MCPClient')
|
||||
async def test_create_mcp_clients_stdio_success(mock_mcp_client):
|
||||
"""Test successful creation of MCP clients with stdio servers."""
|
||||
# Setup mock
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_mcp_client.return_value = mock_client_instance
|
||||
mock_client_instance.connect_stdio = AsyncMock()
|
||||
|
||||
# Test with stdio servers
|
||||
stdio_server_configs = [
|
||||
MCPStdioServerConfig(
|
||||
name='test-server-1',
|
||||
command='python',
|
||||
args=['-m', 'server1'],
|
||||
env={'DEBUG': 'true'},
|
||||
),
|
||||
MCPStdioServerConfig(
|
||||
name='test-server-2',
|
||||
command='/usr/bin/node',
|
||||
args=['server2.js'],
|
||||
env={'NODE_ENV': 'development'},
|
||||
),
|
||||
]
|
||||
|
||||
clients = await openhands.mcp.utils.create_mcp_clients(
|
||||
[], [], stdio_servers=stdio_server_configs
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(clients) == 2
|
||||
assert mock_mcp_client.call_count == 2
|
||||
|
||||
# Check that connect_stdio was called with correct parameters
|
||||
mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[0])
|
||||
mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.MCPClient')
|
||||
async def test_create_mcp_clients_stdio_connection_failure(mock_mcp_client):
|
||||
"""Test handling of stdio connection failures when creating MCP clients."""
|
||||
# Setup mock
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_mcp_client.return_value = mock_client_instance
|
||||
|
||||
# First connection succeeds, second fails
|
||||
mock_client_instance.connect_stdio.side_effect = [
|
||||
None, # Success
|
||||
Exception('Stdio connection failed'), # Failure
|
||||
]
|
||||
|
||||
stdio_server_configs = [
|
||||
MCPStdioServerConfig(name='server1', command='python'),
|
||||
MCPStdioServerConfig(name='server2', command='invalid_command'),
|
||||
]
|
||||
|
||||
clients = await openhands.mcp.utils.create_mcp_clients(
|
||||
[], [], stdio_servers=stdio_server_configs
|
||||
)
|
||||
|
||||
# Verify only one client was successfully created
|
||||
assert len(clients) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.create_mcp_clients')
|
||||
async def test_fetch_mcp_tools_from_config_with_stdio(mock_create_clients):
|
||||
"""Test fetching MCP tools with stdio servers enabled."""
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
|
||||
# Setup mock clients
|
||||
mock_client = MagicMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.to_param.return_value = {'function': {'name': 'stdio_tool'}}
|
||||
mock_client.tools = [mock_tool]
|
||||
mock_create_clients.return_value = [mock_client]
|
||||
|
||||
# Create config with stdio servers
|
||||
mcp_config = MCPConfig(
|
||||
stdio_servers=[MCPStdioServerConfig(name='test-server', command='python')]
|
||||
)
|
||||
|
||||
# Test with use_stdio=True
|
||||
tools = await openhands.mcp.utils.fetch_mcp_tools_from_config(
|
||||
mcp_config, conversation_id='test-conv', use_stdio=True
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(tools) == 1
|
||||
assert tools[0] == {'function': {'name': 'stdio_tool'}}
|
||||
|
||||
# Verify create_mcp_clients was called with stdio servers
|
||||
mock_create_clients.assert_called_once_with(
|
||||
[], [], 'test-conv', mcp_config.stdio_servers
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_mcp_stdio_client():
|
||||
"""Test calling MCP tool on a stdio client."""
|
||||
# Create mock stdio client with the requested tool
|
||||
mock_client = MagicMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = 'stdio_test_tool'
|
||||
mock_client.tools = [mock_tool]
|
||||
|
||||
# Setup response
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
'result': 'stdio_success',
|
||||
'data': 'test_data',
|
||||
}
|
||||
|
||||
# Setup call_tool method
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_response)
|
||||
|
||||
action = MCPAction(name='stdio_test_tool', arguments={'input': 'test_input'})
|
||||
|
||||
# Call the function
|
||||
observation = await openhands.mcp.utils.call_tool_mcp([mock_client], action)
|
||||
|
||||
# Verify
|
||||
assert isinstance(observation, MCPObservation)
|
||||
assert json.loads(observation.content) == {
|
||||
'result': 'stdio_success',
|
||||
'data': 'test_data',
|
||||
}
|
||||
mock_client.call_tool.assert_called_once_with(
|
||||
'stdio_test_tool', {'input': 'test_input'}
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user