Add MCP support for CLI (#9519)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
Ryan H. Tran 2025-07-24 00:06:01 +07:00 committed by GitHub
parent 45ac6b839c
commit fbd9280239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2001 additions and 35 deletions

View File

@ -153,6 +153,7 @@ You can use the following commands whenever the prompt (`>`) is displayed:
| `/new` | Start a new conversation |
| `/settings` | View and modify current LLM/agent settings |
| `/resume` | Resume the agent if paused |
| `/mcp` | Manage MCP server configuration and view connection errors |
#### Settings and Configuration
@ -162,7 +163,7 @@ follow the prompts:
- **Basic settings**: Choose a model/provider and enter your API key.
- **Advanced settings**: Set custom endpoints, enable or disable confirmation mode, and configure memory condensation.
Settings can also be managed via the `config.toml` file.
Settings can also be managed via the `config.toml` file in the current directory or `~/.openhands/config.toml`.
#### Repository Initialization
@ -174,6 +175,41 @@ project details and structure. Use this when onboarding the agent to a new codeb
You can pause the agent while it is running by pressing `Ctrl-P`. To continue the conversation after pausing, simply
type `/resume` at the prompt.
#### MCP Server Management
To configure Model Context Protocol (MCP) servers, you can refer to the documentation on [MCP servers](../mcp) and use the `/mcp` command in the CLI. This command provides an interactive interface for managing Model Context Protocol (MCP) servers:
- **List configured servers**: View all currently configured MCP servers (SSE, Stdio, and SHTTP)
- **Add new server**: Interactively add a new MCP server with guided prompts
- **Remove server**: Remove an existing MCP server from your configuration
- **View errors**: Display any connection errors that occurred during MCP server startup
This command modifies your `~/.openhands/config.toml` file and will prompt you to restart OpenHands for changes to take effect.
To enable the [Tavily MCP server](https://github.com/tavily-ai/tavily-mcp) search engine, you can set the `search_api_key` under the `[core]` section in the `~/.openhands/config.toml` file.
##### Example of the `config.toml` file with MCP server configuration:
```toml
[core]
search_api_key = "tvly-your-api-key-here"
[mcp]
stdio_servers = [
{name="fetch", command="uvx", args=["mcp-server-fetch"]},
]
sse_servers = [
# Basic SSE server with just a URL
"http://example.com:8080/sse",
]
shttp_servers = [
# Streamable HTTP server with API key authentication
{url="https://secure-example.com/mcp", api_key="your-api-key"}
]
```
## Tips and Troubleshooting
- Use `/help` at any time to see the list of available commands.

View File

@ -1,9 +1,15 @@
import asyncio
import os
import sys
from pathlib import Path
from typing import Any
from prompt_toolkit import print_formatted_text
import toml
from prompt_toolkit import HTML, print_formatted_text
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.shortcuts import clear, print_container
from prompt_toolkit.widgets import Frame, TextArea
from pydantic import ValidationError
from openhands.cli.settings import (
display_settings,
@ -14,9 +20,12 @@ from openhands.cli.tui import (
COLOR_GREY,
UsageMetrics,
cli_confirm,
create_prompt_session,
display_help,
display_mcp_errors,
display_shutdown_message,
display_status,
read_prompt_input,
)
from openhands.cli.utils import (
add_local_config_trusted_dir,
@ -27,6 +36,11 @@ from openhands.cli.utils import (
from openhands.core.config import (
OpenHandsConfig,
)
from openhands.core.config.mcp_config import (
MCPSHTTPServerConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.core.schema import AgentState
from openhands.core.schema.exit_reason import ExitReason
from openhands.events import EventSource
@ -38,6 +52,72 @@ from openhands.events.stream import EventStream
from openhands.storage.settings.file_settings_store import FileSettingsStore
async def collect_input(config: OpenHandsConfig, prompt_text: str) -> str | None:
"""Collect user input with cancellation support.
Args:
config: OpenHands configuration
prompt_text: Text to display to user
Returns:
str | None: User input string, or None if user cancelled
"""
print_formatted_text(prompt_text, end=' ')
user_input = await read_prompt_input(config, '', multiline=False)
# Check for cancellation
if user_input.strip().lower() in ['/exit', '/cancel', 'cancel']:
return None
return user_input.strip()
def restart_cli() -> None:
"""Restart the CLI by replacing the current process."""
print_formatted_text('🔄 Restarting OpenHands CLI...')
# Get the current Python executable and script arguments
python_executable = sys.executable
script_args = sys.argv
# Use os.execv to replace the current process
# This preserves the original command line arguments
try:
os.execv(python_executable, [python_executable] + script_args)
except Exception as e:
print_formatted_text(f'❌ Failed to restart CLI: {e}')
print_formatted_text(
'Please restart OpenHands manually for changes to take effect.'
)
async def prompt_for_restart(config: OpenHandsConfig) -> bool:
"""Prompt user if they want to restart the CLI and return their choice."""
print_formatted_text('📝 MCP server configuration updated successfully!')
print_formatted_text('The changes will take effect after restarting OpenHands.')
prompt_session = create_prompt_session(config)
while True:
try:
with patch_stdout():
response = await prompt_session.prompt_async(
HTML(
'<gold>Would you like to restart OpenHands now? (y/n): </gold>'
)
)
response = response.strip().lower() if response else ''
if response in ['y', 'yes']:
return True
elif response in ['n', 'no']:
return False
else:
print_formatted_text('Please enter "y" for yes or "n" for no.')
except (KeyboardInterrupt, EOFError):
return False
async def handle_commands(
command: str,
event_stream: EventStream,
@ -79,6 +159,8 @@ async def handle_commands(
await handle_settings_command(config, settings_store)
elif command == '/resume':
close_repl, new_session_requested = await handle_resume_command(event_stream)
elif command == '/mcp':
await handle_mcp_command(config)
else:
close_repl = True
action = MessageAction(content=command)
@ -327,3 +409,432 @@ def check_folder_security_agreement(config: OpenHandsConfig, current_dir: str) -
return confirm
return True
async def handle_mcp_command(config: OpenHandsConfig) -> None:
"""Handle MCP command with interactive menu."""
action = cli_confirm(
config,
'MCP Server Configuration',
[
'List configured servers',
'Add new server',
'Remove server',
'View errors',
'Go back',
],
)
if action == 0: # List
display_mcp_servers(config)
elif action == 1: # Add
await add_mcp_server(config)
elif action == 2: # Remove
await remove_mcp_server(config)
elif action == 3: # View errors
handle_mcp_errors_command()
# action == 4 is "Go back", do nothing
def display_mcp_servers(config: OpenHandsConfig) -> None:
"""Display MCP server configuration information."""
mcp_config = config.mcp
# Count the different types of servers
sse_count = len(mcp_config.sse_servers)
stdio_count = len(mcp_config.stdio_servers)
shttp_count = len(mcp_config.shttp_servers)
total_count = sse_count + stdio_count + shttp_count
if total_count == 0:
print_formatted_text(
'No custom MCP servers configured. See the documentation to learn more:\n'
' https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers'
)
else:
print_formatted_text(
f'Configured MCP servers:\n'
f' • SSE servers: {sse_count}\n'
f' • Stdio servers: {stdio_count}\n'
f' • SHTTP servers: {shttp_count}\n'
f' • Total: {total_count}'
)
# Show details for each type if they exist
if sse_count > 0:
print_formatted_text('SSE Servers:')
for idx, sse_server in enumerate(mcp_config.sse_servers, 1):
print_formatted_text(f' {idx}. {sse_server.url}')
print_formatted_text('')
if stdio_count > 0:
print_formatted_text('Stdio Servers:')
for idx, stdio_server in enumerate(mcp_config.stdio_servers, 1):
print_formatted_text(
f' {idx}. {stdio_server.name} ({stdio_server.command})'
)
print_formatted_text('')
if shttp_count > 0:
print_formatted_text('SHTTP Servers:')
for idx, shttp_server in enumerate(mcp_config.shttp_servers, 1):
print_formatted_text(f' {idx}. {shttp_server.url}')
print_formatted_text('')
def handle_mcp_errors_command() -> None:
"""Display MCP connection errors."""
display_mcp_errors()
def get_config_file_path() -> Path:
"""Get the path to the config file. By default, we use config.toml in the current working directory. If not found, we use ~/.openhands/config.toml."""
# Check if config.toml exists in the current directory
current_dir = Path.cwd() / 'config.toml'
if current_dir.exists():
return current_dir
# Fallback to the user's home directory
return Path.home() / '.openhands' / 'config.toml'
def load_config_file(file_path: Path) -> dict:
"""Load the config file, creating it if it doesn't exist."""
if file_path.exists():
try:
with open(file_path, 'r') as f:
return toml.load(f)
except Exception:
pass
# Create directory if it doesn't exist
file_path.parent.mkdir(parents=True, exist_ok=True)
return {}
def save_config_file(config_data: dict, file_path: Path) -> None:
"""Save the config file."""
with open(file_path, 'w') as f:
toml.dump(config_data, f)
def _ensure_mcp_config_structure(config_data: dict) -> None:
"""Ensure MCP configuration structure exists in config data."""
if 'mcp' not in config_data:
config_data['mcp'] = {}
def _add_server_to_config(server_type: str, server_config: dict) -> Path:
"""Add a server configuration to the config file."""
config_file_path = get_config_file_path()
config_data = load_config_file(config_file_path)
_ensure_mcp_config_structure(config_data)
if server_type not in config_data['mcp']:
config_data['mcp'][server_type] = []
config_data['mcp'][server_type].append(server_config)
save_config_file(config_data, config_file_path)
return config_file_path
async def add_mcp_server(config: OpenHandsConfig) -> None:
"""Add a new MCP server configuration."""
# Choose transport type
transport_type = cli_confirm(
config,
'Select MCP server transport type:',
[
'SSE (Server-Sent Events)',
'Stdio (Standard Input/Output)',
'SHTTP (Streamable HTTP)',
'Cancel',
],
)
if transport_type == 3: # Cancel
return
try:
if transport_type == 0: # SSE
await add_sse_server(config)
elif transport_type == 1: # Stdio
await add_stdio_server(config)
elif transport_type == 2: # SHTTP
await add_shttp_server(config)
except Exception as e:
print_formatted_text(f'Error adding MCP server: {e}')
async def add_sse_server(config: OpenHandsConfig) -> None:
"""Add an SSE MCP server."""
print_formatted_text('Adding SSE MCP Server')
while True: # Retry loop for the entire form
# Collect all inputs
url = await collect_input(config, '\nEnter server URL:')
if url is None:
print_formatted_text('Operation cancelled.')
return
api_key = await collect_input(
config, '\nEnter API key (optional, press Enter to skip):'
)
if api_key is None:
print_formatted_text('Operation cancelled.')
return
# Convert empty string to None for optional field
api_key = api_key if api_key else None
# Validate all inputs at once
try:
server = MCPSSEServerConfig(url=url, api_key=api_key)
break # Success - exit retry loop
except ValidationError as e:
# Show all errors at once
print_formatted_text('❌ Please fix the following errors:')
for error in e.errors():
field = error['loc'][0] if error['loc'] else 'unknown'
print_formatted_text(f'{field}: {error["msg"]}')
if cli_confirm(config, '\nTry again?') != 0:
print_formatted_text('Operation cancelled.')
return
# Save to config file
server_config = {'url': server.url}
if server.api_key:
server_config['api_key'] = server.api_key
config_file_path = _add_server_to_config('sse_servers', server_config)
print_formatted_text(f'✓ SSE MCP server added to {config_file_path}: {server.url}')
# Prompt for restart
if await prompt_for_restart(config):
restart_cli()
async def add_stdio_server(config: OpenHandsConfig) -> None:
"""Add a Stdio MCP server."""
print_formatted_text('Adding Stdio MCP Server')
# Get existing server names to check for duplicates
existing_names = [server.name for server in config.mcp.stdio_servers]
while True: # Retry loop for the entire form
# Collect all inputs
name = await collect_input(config, '\nEnter server name:')
if name is None:
print_formatted_text('Operation cancelled.')
return
command = await collect_input(config, "\nEnter command (e.g., 'uvx', 'npx'):")
if command is None:
print_formatted_text('Operation cancelled.')
return
args_input = await collect_input(
config,
'\nEnter arguments (optional, e.g., "-y server-package arg1"):',
)
if args_input is None:
print_formatted_text('Operation cancelled.')
return
env_input = await collect_input(
config,
'\nEnter environment variables (KEY=VALUE format, comma-separated, optional):',
)
if env_input is None:
print_formatted_text('Operation cancelled.')
return
# Check for duplicate server names
if name in existing_names:
print_formatted_text(f"❌ Server name '{name}' already exists.")
if cli_confirm(config, '\nTry again?') != 0:
print_formatted_text('Operation cancelled.')
return
continue
# Validate all inputs at once
try:
server = MCPStdioServerConfig(
name=name,
command=command,
args=args_input, # type: ignore # Will be parsed by Pydantic validator
env=env_input, # type: ignore # Will be parsed by Pydantic validator
)
break # Success - exit retry loop
except ValidationError as e:
# Show all errors at once
print_formatted_text('❌ Please fix the following errors:')
for error in e.errors():
field = error['loc'][0] if error['loc'] else 'unknown'
print_formatted_text(f'{field}: {error["msg"]}')
if cli_confirm(config, '\nTry again?') != 0:
print_formatted_text('Operation cancelled.')
return
# Save to config file
server_config: dict[str, Any] = {
'name': server.name,
'command': server.command,
}
if server.args:
server_config['args'] = server.args
if server.env:
server_config['env'] = server.env
config_file_path = _add_server_to_config('stdio_servers', server_config)
print_formatted_text(
f'✓ Stdio MCP server added to {config_file_path}: {server.name}'
)
# Prompt for restart
if await prompt_for_restart(config):
restart_cli()
async def add_shttp_server(config: OpenHandsConfig) -> None:
"""Add an SHTTP MCP server."""
print_formatted_text('Adding SHTTP MCP Server')
while True: # Retry loop for the entire form
# Collect all inputs
url = await collect_input(config, '\nEnter server URL:')
if url is None:
print_formatted_text('Operation cancelled.')
return
api_key = await collect_input(
config, '\nEnter API key (optional, press Enter to skip):'
)
if api_key is None:
print_formatted_text('Operation cancelled.')
return
# Convert empty string to None for optional field
api_key = api_key if api_key else None
# Validate all inputs at once
try:
server = MCPSHTTPServerConfig(url=url, api_key=api_key)
break # Success - exit retry loop
except ValidationError as e:
# Show all errors at once
print_formatted_text('❌ Please fix the following errors:')
for error in e.errors():
field = error['loc'][0] if error['loc'] else 'unknown'
print_formatted_text(f'{field}: {error["msg"]}')
if cli_confirm(config, '\nTry again?') != 0:
print_formatted_text('Operation cancelled.')
return
# Save to config file
server_config = {'url': server.url}
if server.api_key:
server_config['api_key'] = server.api_key
config_file_path = _add_server_to_config('shttp_servers', server_config)
print_formatted_text(
f'✓ SHTTP MCP server added to {config_file_path}: {server.url}'
)
# Prompt for restart
if await prompt_for_restart(config):
restart_cli()
async def remove_mcp_server(config: OpenHandsConfig) -> None:
"""Remove an MCP server configuration."""
mcp_config = config.mcp
# Collect all servers with their types
servers: list[tuple[str, str, object]] = []
# Add SSE servers
for sse_server in mcp_config.sse_servers:
servers.append(('SSE', sse_server.url, sse_server))
# Add Stdio servers
for stdio_server in mcp_config.stdio_servers:
servers.append(('Stdio', stdio_server.name, stdio_server))
# Add SHTTP servers
for shttp_server in mcp_config.shttp_servers:
servers.append(('SHTTP', shttp_server.url, shttp_server))
if not servers:
print_formatted_text('No MCP servers configured to remove.')
return
# Create choices for the user
choices = []
for server_type, identifier, _ in servers:
choices.append(f'{server_type}: {identifier}')
choices.append('Cancel')
# Let user choose which server to remove
choice = cli_confirm(config, 'Select MCP server to remove:', choices)
if choice == len(choices) - 1: # Cancel
return
# Remove the selected server
server_type, identifier, _ = servers[choice]
# Confirm removal
confirm = cli_confirm(
config,
f'Are you sure you want to remove {server_type} server "{identifier}"?',
['Yes, remove', 'Cancel'],
)
if confirm == 1: # Cancel
return
# Load config file and remove the server
config_file_path = get_config_file_path()
config_data = load_config_file(config_file_path)
_ensure_mcp_config_structure(config_data)
removed = False
if server_type == 'SSE' and 'sse_servers' in config_data['mcp']:
config_data['mcp']['sse_servers'] = [
s for s in config_data['mcp']['sse_servers'] if s.get('url') != identifier
]
removed = True
elif server_type == 'Stdio' and 'stdio_servers' in config_data['mcp']:
config_data['mcp']['stdio_servers'] = [
s
for s in config_data['mcp']['stdio_servers']
if s.get('name') != identifier
]
removed = True
elif server_type == 'SHTTP' and 'shttp_servers' in config_data['mcp']:
config_data['mcp']['shttp_servers'] = [
s for s in config_data['mcp']['shttp_servers'] if s.get('url') != identifier
]
removed = True
if removed:
save_config_file(config_data, config_file_path)
print_formatted_text(
f'{server_type} MCP server "{identifier}" removed from {config_file_path}.'
)
# Prompt for restart
if await prompt_for_restart(config):
restart_cli()
else:
print_formatted_text(f'Failed to remove {server_type} server "{identifier}".')

View File

@ -74,6 +74,7 @@ from openhands.events.observation import (
)
from openhands.io import read_task
from openhands.mcp import add_mcp_tools_to_agent
from openhands.mcp.error_collector import mcp_error_collector
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
LLMSummarizingCondenserConfig,
)
@ -298,6 +299,10 @@ async def run_session(
# Add MCP tools to the agent
if agent.config.enable_mcp:
# Clear any previous errors and enable collection
mcp_error_collector.clear_errors()
mcp_error_collector.enable_collection()
# Add OpenHands' MCP server by default
_, openhands_mcp_stdio_servers = (
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
@ -309,6 +314,9 @@ async def run_session(
await add_mcp_tools_to_agent(agent, runtime, memory)
# Disable collection after startup
mcp_error_collector.disable_collection()
# Clear loading animation
is_loaded.set()
@ -319,7 +327,27 @@ async def run_session(
if not skip_banner:
display_banner(session_id=sid)
welcome_message = 'What do you want to build?' # from the application
welcome_message = ''
# Display number of MCP servers configured
if agent.config.enable_mcp:
total_mcp_servers = (
len(runtime.config.mcp.stdio_servers)
+ len(runtime.config.mcp.sse_servers)
+ len(runtime.config.mcp.shttp_servers)
)
if total_mcp_servers > 0:
mcp_line = f'Using {len(runtime.config.mcp.stdio_servers)} stdio MCP servers, {len(runtime.config.mcp.sse_servers)} SSE MCP servers and {len(runtime.config.mcp.shttp_servers)} SHTTP MCP servers.'
# Check for MCP errors and add indicator to the same line
if agent.config.enable_mcp and mcp_error_collector.has_errors():
mcp_line += (
' ✗ MCP errors detected (type /mcp → select View errors to view)'
)
welcome_message += mcp_line + '\n\n'
welcome_message += 'What do you want to build?' # from the application
initial_message = '' # from the user
if task_content:
@ -488,6 +516,16 @@ async def main_with_loop(loop: asyncio.AbstractEventLoop) -> None:
if not env_log_level:
logger.setLevel(logging.WARNING)
# If `config.toml` does not exist in current directory, use the file under home directory
if not os.path.exists(args.config_file):
home_config_file = os.path.join(
os.path.expanduser('~'), '.openhands', 'config.toml'
)
logger.info(
f'Config file {args.config_file} does not exist, using default config file in home directory: {home_config_file}.'
)
args.config_file = home_config_file
# Load config from toml and override with command line arguments
config: OpenHandsConfig = setup_config_from_args(args)

View File

@ -4,6 +4,8 @@
import asyncio
import contextlib
import datetime
import json
import sys
import threading
import time
@ -36,6 +38,7 @@ from openhands.events.action import (
ActionConfirmationStatus,
ChangeAgentStateAction,
CmdRunAction,
MCPAction,
MessageAction,
)
from openhands.events.event import Event
@ -45,8 +48,10 @@ from openhands.events.observation import (
ErrorObservation,
FileEditObservation,
FileReadObservation,
MCPObservation,
)
from openhands.llm.metrics import Metrics
from openhands.mcp.error_collector import mcp_error_collector
ENABLE_STREAMING = False # FIXME: this doesn't work
@ -76,6 +81,7 @@ COMMANDS = {
'/new': 'Create a new conversation',
'/settings': 'Display and modify current settings',
'/resume': 'Resume the agent when paused',
'/mcp': 'Manage MCP server configuration and view errors',
}
print_lock = threading.Lock()
@ -162,6 +168,7 @@ def display_welcome_message(message: str = '') -> None:
print_formatted_text(
HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
)
if message:
print_formatted_text(
HTML(f'{message} <grey>Type /help for help</grey>'),
@ -186,6 +193,48 @@ def display_initial_user_prompt(prompt: str) -> None:
)
def display_mcp_errors() -> None:
"""Display collected MCP errors."""
errors = mcp_error_collector.get_errors()
if not errors:
print_formatted_text(HTML('<ansigreen>✓ No MCP errors detected</ansigreen>\n'))
return
print_formatted_text(
HTML(
f'<ansired>✗ {len(errors)} MCP error(s) detected during startup:</ansired>\n'
)
)
for i, error in enumerate(errors, 1):
# Format timestamp
timestamp = datetime.datetime.fromtimestamp(error.timestamp).strftime(
'%H:%M:%S'
)
# Create error display text
error_text = (
f'[{timestamp}] {error.server_type.upper()} Server: {error.server_name}\n'
)
error_text += f'Error: {error.error_message}\n'
if error.exception_details:
error_text += f'Details: {error.exception_details}'
container = Frame(
TextArea(
text=error_text,
read_only=True,
style='ansired',
wrap_lines=True,
),
title=f'MCP Error #{i}',
style='ansired',
)
print_container(container)
print_formatted_text('') # Add spacing between errors
# Prompt output display functions
def display_thought_if_new(thought: str) -> None:
"""Display a thought only if it hasn't been displayed recently."""
@ -215,6 +264,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
if event.confirmation_state == ActionConfirmationStatus.CONFIRMED:
initialize_streaming_output()
elif isinstance(event, MCPAction):
display_mcp_action(event)
elif isinstance(event, Action):
# For other actions, display thoughts normally
if hasattr(event, 'thought') and event.thought:
@ -232,6 +283,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
display_file_edit(event)
elif isinstance(event, FileReadObservation):
display_file_read(event)
elif isinstance(event, MCPObservation):
display_mcp_observation(event)
elif isinstance(event, AgentStateChangedObservation):
display_agent_state_change_message(event.agent_state)
elif isinstance(event, ErrorObservation):
@ -337,6 +390,66 @@ def display_file_read(event: FileReadObservation) -> None:
print_container(container)
def display_mcp_action(event: MCPAction) -> None:
"""Display an MCP action in the CLI."""
# Format the arguments for display
args_text = ''
if event.arguments:
try:
args_text = json.dumps(event.arguments, indent=2)
except (TypeError, ValueError):
args_text = str(event.arguments)
# Create the display text
display_text = f'Tool: {event.name}'
if args_text:
display_text += f'\n\nArguments:\n{args_text}'
container = Frame(
TextArea(
text=display_text,
read_only=True,
style='ansiblue',
wrap_lines=True,
),
title='MCP Tool Call',
style='ansiblue',
)
print_formatted_text('')
print_container(container)
def display_mcp_observation(event: MCPObservation) -> None:
"""Display an MCP observation in the CLI."""
# Format the content for display
content = event.content.strip() if event.content else 'No output'
# Add tool name and arguments info if available
display_text = content
if event.name:
header = f'Tool: {event.name}'
if event.arguments:
try:
args_text = json.dumps(event.arguments, indent=2)
header += f'\nArguments: {args_text}'
except (TypeError, ValueError):
header += f'\nArguments: {event.arguments}'
display_text = f'{header}\n\nResult:\n{content}'
container = Frame(
TextArea(
text=display_text,
read_only=True,
style=COLOR_GREY,
wrap_lines=True,
),
title='MCP Tool Result',
style=f'fg:{COLOR_GREY}',
)
print_formatted_text('')
print_container(container)
def initialize_streaming_output():
"""Initialize the streaming output TextArea."""
if not ENABLE_STREAMING:

View File

@ -1,8 +1,17 @@
import os
import re
import shlex
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationError,
field_validator,
model_validator,
)
if TYPE_CHECKING:
from openhands.core.config.openhands_config import OpenHandsConfig
@ -11,6 +20,27 @@ from openhands.core.logger import openhands_logger as logger
from openhands.utils.import_utils import get_impl
def _validate_mcp_url(url: str) -> str:
"""Shared URL validation logic for MCP servers."""
if not url.strip():
raise ValueError('URL cannot be empty')
url = url.strip()
try:
parsed = urlparse(url)
if not parsed.scheme:
raise ValueError('URL must include a scheme (http:// or https://)')
if not parsed.netloc:
raise ValueError('URL must include a valid domain/host')
if parsed.scheme not in ['http', 'https', 'ws', 'wss']:
raise ValueError('URL scheme must be http, https, ws, or wss')
return url
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f'Invalid URL format: {str(e)}')
class MCPSSEServerConfig(BaseModel):
"""Configuration for a single MCP server.
@ -22,6 +52,12 @@ class MCPSSEServerConfig(BaseModel):
url: str
api_key: str | None = None
@field_validator('url')
@classmethod
def validate_url(cls, v: str) -> str:
"""Validate URL format for MCP servers."""
return _validate_mcp_url(v)
class MCPStdioServerConfig(BaseModel):
"""Configuration for a MCP server that uses stdio.
@ -38,6 +74,100 @@ class MCPStdioServerConfig(BaseModel):
args: list[str] = Field(default_factory=list)
env: dict[str, str] = Field(default_factory=dict)
@field_validator('name')
@classmethod
def validate_server_name(cls, v: str) -> str:
"""Validate server name for stdio MCP servers."""
if not v.strip():
raise ValueError('Server name cannot be empty')
v = v.strip()
# Check for valid characters (alphanumeric, hyphens, underscores)
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError(
'Server name can only contain letters, numbers, hyphens, and underscores'
)
return v
@field_validator('command')
@classmethod
def validate_command(cls, v: str) -> str:
"""Validate command for stdio MCP servers."""
if not v.strip():
raise ValueError('Command cannot be empty')
v = v.strip()
# Check that command doesn't contain spaces (should be a single executable)
if ' ' in v:
raise ValueError(
'Command should be a single executable without spaces (use arguments field for parameters)'
)
return v
@field_validator('args', mode='before')
@classmethod
def parse_args(cls, v) -> list[str]:
"""Parse arguments from string or return list as-is.
Supports shell-like argument parsing using shlex.split().
Examples:
- "-y mcp-remote https://example.com"
- '--config "path with spaces" --debug'
- "arg1 arg2 arg3"
"""
if isinstance(v, str):
if not v.strip():
return []
v = v.strip()
# Use shell-like parsing for natural argument handling
try:
return shlex.split(v)
except ValueError as e:
# If shlex parsing fails (e.g., unmatched quotes), provide clear error
raise ValueError(
f'Invalid argument format: {str(e)}. Use shell-like format, e.g., "arg1 arg2" or \'--config "value with spaces"\''
)
return v or []
@field_validator('env', mode='before')
@classmethod
def parse_env(cls, v) -> dict[str, str]:
"""Parse environment variables from string or return dict as-is."""
if isinstance(v, str):
if not v.strip():
return {}
env = {}
for pair in v.split(','):
pair = pair.strip()
if not pair:
continue
if '=' not in pair:
raise ValueError(
f"Environment variable '{pair}' must be in KEY=VALUE format"
)
key, value = pair.split('=', 1)
key = key.strip()
if not key:
raise ValueError('Environment variable key cannot be empty')
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key):
raise ValueError(
f"Invalid environment variable name '{key}'. Must start with letter or underscore, contain only alphanumeric characters and underscores"
)
env[key] = value
return env
return v or {}
def __eq__(self, other):
"""Override equality operator to compare server configurations.
@ -59,6 +189,12 @@ class MCPSHTTPServerConfig(BaseModel):
url: str
api_key: str | None = None
@field_validator('url')
@classmethod
def validate_url(cls, v: str) -> str:
"""Validate URL format for MCP servers."""
return _validate_mcp_url(v)
class MCPConfig(BaseModel):
"""Configuration for MCP (Message Control Protocol) settings.

View File

@ -5,7 +5,7 @@ import platform
import sys
from ast import literal_eval
from types import UnionType
from typing import MutableMapping, get_args, get_origin
from typing import MutableMapping, get_args, get_origin, get_type_hints
from uuid import uuid4
import toml
@ -154,8 +154,22 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
core_config = toml_config['core']
# Process core section if present
cfg_type_hints = get_type_hints(cfg.__class__)
for key, value in core_config.items():
if hasattr(cfg, key):
# Get expected type of the attribute
expected_type = cfg_type_hints.get(key, None)
# Check if expected_type is a Union that includes SecretStr and value is str, e.g. search_api_key
if expected_type:
origin = get_origin(expected_type)
args = get_args(expected_type)
if origin is UnionType and SecretStr in args and isinstance(value, str):
value = SecretStr(value)
elif expected_type is SecretStr and isinstance(value, str):
value = SecretStr(value)
setattr(cfg, key, value)
else:
logger.openhands_logger.warning(

View File

@ -1,4 +1,5 @@
from openhands.mcp.client import MCPClient
from openhands.mcp.error_collector import mcp_error_collector
from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import (
add_mcp_tools_to_agent,
@ -16,4 +17,5 @@ __all__ = [
'fetch_mcp_tools_from_config',
'call_tool_mcp',
'add_mcp_tools_to_agent',
'mcp_error_collector',
]

View File

@ -1,13 +1,22 @@
from typing import Optional
from fastmcp import Client
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp import McpError
from mcp.types import CallToolResult
from pydantic import BaseModel, ConfigDict, Field
from openhands.core.config.mcp_config import MCPSHTTPServerConfig, MCPSSEServerConfig
from openhands.core.config.mcp_config import (
MCPSHTTPServerConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.core.logger import openhands_logger as logger
from openhands.mcp.error_collector import mcp_error_collector
from openhands.mcp.tool import MCPClientTool
@ -90,11 +99,51 @@ class MCPClient(BaseModel):
await self._initialize_and_list_tools()
except McpError as e:
logger.error(f'McpError connecting to {server_url}: {e}')
error_msg = f'McpError connecting to {server_url}: {e}'
logger.error(error_msg)
mcp_error_collector.add_error(
server_name=server_url,
server_type='shttp'
if isinstance(server, MCPSHTTPServerConfig)
else 'sse',
error_message=error_msg,
exception_details=str(e),
)
raise # Re-raise the error
except Exception as e:
logger.error(f'Error connecting to {server_url}: {e}')
error_msg = f'Error connecting to {server_url}: {e}'
logger.error(error_msg)
mcp_error_collector.add_error(
server_name=server_url,
server_type='shttp'
if isinstance(server, MCPSHTTPServerConfig)
else 'sse',
error_message=error_msg,
exception_details=str(e),
)
raise
async def connect_stdio(self, server: MCPStdioServerConfig, timeout: float = 30.0):
"""Connect to MCP server using stdio transport"""
try:
transport = StdioTransport(
command=server.command, args=server.args or [], env=server.env
)
self.client = Client(transport, timeout=timeout)
await self._initialize_and_list_tools()
except Exception as e:
server_name = getattr(
server, 'name', f'{server.command} {" ".join(server.args or [])}'
)
error_msg = f'Failed to connect to stdio server {server_name}: {e}'
logger.error(error_msg)
mcp_error_collector.add_error(
server_name=server_name,
server_type='stdio',
error_message=error_msg,
exception_details=str(e),
)
raise
async def call_tool(self, tool_name: str, args: dict) -> CallToolResult:

View File

@ -0,0 +1,78 @@
"""MCP Error Collector for capturing and storing MCP-related errors during startup."""
import threading
import time
from dataclasses import dataclass
@dataclass
class MCPError:
"""Represents an MCP-related error."""
timestamp: float
server_name: str
server_type: str # 'stdio', 'sse', 'shttp'
error_message: str
exception_details: str | None = None
class MCPErrorCollector:
"""Thread-safe collector for MCP errors during startup."""
def __init__(self):
self._errors: list[MCPError] = []
self._lock = threading.Lock()
self._collection_enabled = True
def add_error(
self,
server_name: str,
server_type: str,
error_message: str,
exception_details: str | None = None,
) -> None:
"""Add an MCP error to the collection."""
if not self._collection_enabled:
return
with self._lock:
error = MCPError(
timestamp=time.time(),
server_name=server_name,
server_type=server_type,
error_message=error_message,
exception_details=exception_details,
)
self._errors.append(error)
def get_errors(self) -> list[MCPError]:
"""Get a copy of all collected errors."""
with self._lock:
return self._errors.copy()
def has_errors(self) -> bool:
"""Check if there are any collected errors."""
with self._lock:
return len(self._errors) > 0
def clear_errors(self) -> None:
"""Clear all collected errors."""
with self._lock:
self._errors.clear()
def disable_collection(self) -> None:
"""Disable error collection (useful after startup)."""
self._collection_enabled = False
def enable_collection(self) -> None:
"""Enable error collection."""
self._collection_enabled = True
def get_error_count(self) -> int:
"""Get the number of collected errors."""
with self._lock:
return len(self._errors)
# Global instance for collecting MCP errors
mcp_error_collector = MCPErrorCollector()

View File

@ -11,14 +11,17 @@ from openhands.core.config.mcp_config import (
MCPConfig,
MCPSHTTPServerConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.mcp import MCPAction
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.mcp.client import MCPClient
from openhands.mcp.error_collector import mcp_error_collector
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]:
@ -45,7 +48,14 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
mcp_tools = tool.to_param()
all_mcp_tools.append(mcp_tools)
except Exception as e:
logger.error(f'Error in convert_mcp_clients_to_tools: {e}')
error_msg = f'Error in convert_mcp_clients_to_tools: {e}'
logger.error(error_msg)
mcp_error_collector.add_error(
server_name='general',
server_type='conversion',
error_message=error_msg,
exception_details=str(e),
)
return []
return all_mcp_tools
@ -54,6 +64,7 @@ async def create_mcp_clients(
sse_servers: list[MCPSSEServerConfig],
shttp_servers: list[MCPSHTTPServerConfig],
conversation_id: str | None = None,
stdio_servers: list[MCPStdioServerConfig] | None = None,
) -> list[MCPClient]:
import sys
@ -64,9 +75,13 @@ async def create_mcp_clients(
)
return []
servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig] = [
if stdio_servers is None:
stdio_servers = []
servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig | MCPStdioServerConfig] = [
*sse_servers,
*shttp_servers,
*stdio_servers,
]
if not servers:
@ -75,6 +90,17 @@ async def create_mcp_clients(
mcp_clients = []
for server in servers:
if isinstance(server, MCPStdioServerConfig):
logger.info(f'Initializing MCP agent for {server} with stdio connection...')
client = MCPClient()
try:
await client.connect_stdio(server)
mcp_clients.append(client)
except Exception as e:
# Error is already logged and collected in client.connect_stdio()
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
continue
is_shttp = isinstance(server, MCPSHTTPServerConfig)
connection_type = 'SHTTP' if is_shttp else 'SSE'
logger.info(
@ -89,13 +115,14 @@ async def create_mcp_clients(
mcp_clients.append(client)
except Exception as e:
# Error is already logged and collected in client.connect_http()
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
return mcp_clients
async def fetch_mcp_tools_from_config(
mcp_config: MCPConfig, conversation_id: str | None = None
mcp_config: MCPConfig, conversation_id: str | None = None, use_stdio: bool = False
) -> list[dict]:
"""
Retrieves the list of MCP tools from the MCP clients.
@ -103,6 +130,7 @@ async def fetch_mcp_tools_from_config(
Args:
mcp_config: The MCP configuration
conversation_id: Optional conversation ID to associate with the MCP clients
use_stdio: Whether to use stdio servers for MCP clients, set to True when running from a CLI runtime
Returns:
A list of tool dictionaries. Returns an empty list if no connections could be established.
@ -120,7 +148,10 @@ async def fetch_mcp_tools_from_config(
logger.debug(f'Creating MCP clients with config: {mcp_config}')
# Create clients - this will fetch tools but not maintain active connections
mcp_clients = await create_mcp_clients(
mcp_config.sse_servers, mcp_config.shttp_servers, conversation_id
mcp_config.sse_servers,
mcp_config.shttp_servers,
conversation_id,
mcp_config.stdio_servers if use_stdio else [],
)
if not mcp_clients:
@ -131,7 +162,14 @@ async def fetch_mcp_tools_from_config(
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
except Exception as e:
logger.error(f'Error fetching MCP tools: {str(e)}')
error_msg = f'Error fetching MCP tools: {str(e)}'
logger.error(error_msg)
mcp_error_collector.add_error(
server_name='general',
server_type='fetch',
error_message=error_msg,
exception_details=str(e),
)
return []
logger.debug(f'MCP tools: {mcp_tools}')
@ -200,7 +238,9 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
)
async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memory'):
async def add_mcp_tools_to_agent(
agent: 'Agent', runtime: Runtime, memory: 'Memory'
) -> MCPConfig:
"""
Add MCP tools to an agent.
"""
@ -231,13 +271,18 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo
# Check if this stdio server is already in the config
if stdio_server not in extra_stdio_servers:
extra_stdio_servers.append(stdio_server)
logger.info(f'Added microagent stdio server: {stdio_server.name}')
logger.warning(
f'Added microagent stdio server: {stdio_server.name}'
)
# Add the runtime as another MCP server
updated_mcp_config = runtime.get_mcp_config(extra_stdio_servers)
# Fetch the MCP tools
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
# Only use stdio if run from a CLI runtime
mcp_tools = await fetch_mcp_tools_from_config(
updated_mcp_config, use_stdio=isinstance(runtime, CLIRuntime)
)
logger.info(
f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}'
@ -245,3 +290,5 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo
# Set the MCP tools on the agent
agent.set_mcp_tools(mcp_tools)
return updated_mcp_config

View File

@ -689,8 +689,69 @@ class CLIRuntime(Runtime):
)
async def call_tool_mcp(self, action: MCPAction) -> Observation:
"""Not implemented for CLI runtime."""
return ErrorObservation('MCP functionality is not implemented in CLIRuntime')
"""Execute an MCP tool action in CLI runtime.
Args:
action: The MCP action to execute
Returns:
Observation: The result of the MCP tool execution
"""
# Check if we're on Windows - MCP is disabled on Windows
if sys.platform == 'win32':
self.log('info', 'MCP functionality is disabled on Windows')
return ErrorObservation('MCP functionality is not available on Windows')
# Import here to avoid circular imports
from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
from openhands.mcp.utils import create_mcp_clients
try:
# Get the MCP config for this runtime
mcp_config = self.get_mcp_config()
if (
not mcp_config.sse_servers
and not mcp_config.shttp_servers
and not mcp_config.stdio_servers
):
self.log('warning', 'No MCP servers configured')
return ErrorObservation('No MCP servers configured')
self.log(
'debug',
f'Creating MCP clients for action {action.name} with servers: '
f'SSE={len(mcp_config.sse_servers)}, SHTTP={len(mcp_config.shttp_servers)}, '
f'stdio={len(mcp_config.stdio_servers)}',
)
# Create clients for this specific operation
mcp_clients = await create_mcp_clients(
mcp_config.sse_servers,
mcp_config.shttp_servers,
self.sid,
mcp_config.stdio_servers,
)
if not mcp_clients:
self.log('warning', 'No MCP clients could be created')
return ErrorObservation(
'No MCP clients could be created - check server configurations'
)
# Call the tool and return the result
self.log(
'debug',
f'Executing MCP tool: {action.name} with arguments: {action.arguments}',
)
result = await call_tool_mcp_handler(mcp_clients, action)
self.log('debug', f'MCP tool {action.name} executed successfully')
return result
except Exception as e:
error_msg = f'Error executing MCP tool {action.name}: {str(e)}'
self.log('error', error_msg)
return ErrorObservation(error_msg)
@property
def workspace_root(self) -> Path:
@ -869,8 +930,40 @@ class CLIRuntime(Runtime):
def get_mcp_config(
self, extra_stdio_servers: list[MCPStdioServerConfig] | None = None
) -> MCPConfig:
# TODO: Load MCP config from a local file
return MCPConfig()
"""Get MCP configuration for CLI runtime.
Args:
extra_stdio_servers: Additional stdio servers to include in the config
Returns:
MCPConfig: The MCP configuration with stdio servers and any configured SSE/SHTTP servers
"""
# Check if we're on Windows - MCP is disabled on Windows
if sys.platform == 'win32':
self.log('debug', 'MCP is disabled on Windows, returning empty config')
return MCPConfig(sse_servers=[], stdio_servers=[], shttp_servers=[])
# Note: we update the self.config.mcp directly for CLI runtime, which is different from other runtimes.
mcp_config = self.config.mcp
# Add any extra stdio servers
if extra_stdio_servers:
current_stdio_servers = list(mcp_config.stdio_servers)
for extra_server in extra_stdio_servers:
# Check if this stdio server is already in the config
if extra_server not in current_stdio_servers:
current_stdio_servers.append(extra_server)
self.log('info', f'Added extra stdio server: {extra_server.name}')
mcp_config.stdio_servers = current_stdio_servers
self.log(
'debug',
f'CLI MCP config: {len(mcp_config.sse_servers)} SSE servers, '
f'{len(mcp_config.stdio_servers)} stdio servers, '
f'{len(mcp_config.shttp_servers)} SHTTP servers',
)
return mcp_config
def subscribe_to_shell_stream(
self, callback: Callable[[str], None] | None = None

View File

@ -3,10 +3,12 @@ from unittest.mock import MagicMock, patch
import pytest
from openhands.cli.commands import (
display_mcp_servers,
handle_commands,
handle_exit_command,
handle_help_command,
handle_init_command,
handle_mcp_command,
handle_new_command,
handle_resume_command,
handle_settings_command,
@ -143,6 +145,18 @@ class TestHandleCommands:
assert reload_microagents is False
assert new_session is False
@pytest.mark.asyncio
@patch('openhands.cli.commands.handle_mcp_command')
async def test_handle_mcp_command(self, mock_handle_mcp, mock_dependencies):
close_repl, reload_microagents, new_session, _ = await handle_commands(
'/mcp', **mock_dependencies
)
mock_handle_mcp.assert_called_once_with(mock_dependencies['config'])
assert close_repl is False
assert reload_microagents is False
assert new_session is False
@pytest.mark.asyncio
async def test_handle_unknown_command(self, mock_dependencies):
user_message = 'Hello, this is not a command'
@ -219,6 +233,78 @@ class TestHandleHelpCommand:
mock_display_help.assert_called_once()
class TestDisplayMcpServers:
@patch('openhands.cli.commands.print_formatted_text')
def test_display_mcp_servers_no_servers(self, mock_print):
from openhands.core.config.mcp_config import MCPConfig
config = MagicMock(spec=OpenHandsConfig)
config.mcp = MCPConfig() # Empty config with no servers
display_mcp_servers(config)
mock_print.assert_called_once()
call_args = mock_print.call_args[0][0]
assert 'No custom MCP servers configured' in call_args
assert (
'https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers'
in call_args
)
@patch('openhands.cli.commands.print_formatted_text')
def test_display_mcp_servers_with_servers(self, mock_print):
from openhands.core.config.mcp_config import (
MCPConfig,
MCPSHTTPServerConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
config = MagicMock(spec=OpenHandsConfig)
config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='https://example.com/sse')],
stdio_servers=[MCPStdioServerConfig(name='tavily', command='npx')],
shttp_servers=[MCPSHTTPServerConfig(url='http://localhost:3000/mcp')],
)
display_mcp_servers(config)
# Should be called multiple times for different sections
assert mock_print.call_count >= 4
# Check that the summary is printed
first_call = mock_print.call_args_list[0][0][0]
assert 'Configured MCP servers:' in first_call
assert 'SSE servers: 1' in first_call
assert 'Stdio servers: 1' in first_call
assert 'SHTTP servers: 1' in first_call
assert 'Total: 3' in first_call
class TestHandleMcpCommand:
@pytest.mark.asyncio
@patch('openhands.cli.commands.cli_confirm')
@patch('openhands.cli.commands.display_mcp_servers')
async def test_handle_mcp_command_list_action(self, mock_display, mock_cli_confirm):
config = MagicMock(spec=OpenHandsConfig)
mock_cli_confirm.return_value = 0 # List action
await handle_mcp_command(config)
mock_cli_confirm.assert_called_once_with(
config,
'MCP Server Configuration',
[
'List configured servers',
'Add new server',
'Remove server',
'View errors',
'Go back',
],
)
mock_display.assert_called_once_with(config)
class TestHandleStatusCommand:
@patch('openhands.cli.commands.display_status')
def test_status_command(self, mock_display_status):
@ -496,3 +582,16 @@ class TestHandleResumeCommand:
# Check the return values
assert close_repl is True
assert new_session_requested is False
class TestMCPErrorHandling:
"""Test MCP error handling in commands."""
@patch('openhands.cli.commands.display_mcp_errors')
def test_handle_mcp_errors_command(self, mock_display_errors):
"""Test handling MCP errors command."""
from openhands.cli.commands import handle_mcp_errors_command
handle_mcp_errors_command()
mock_display_errors.assert_called_once()

View File

@ -0,0 +1,106 @@
"""Tests for CLI server management functionality."""
from unittest.mock import MagicMock, patch
import pytest
from openhands.cli.commands import (
display_mcp_servers,
remove_mcp_server,
)
from openhands.core.config import OpenHandsConfig
from openhands.core.config.mcp_config import (
MCPConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
class TestMCPServerManagement:
"""Test MCP server management functions."""
def setup_method(self):
"""Set up test fixtures."""
self.config = MagicMock(spec=OpenHandsConfig)
self.config.cli = MagicMock()
self.config.cli.vi_mode = False
@patch('openhands.cli.commands.print_formatted_text')
def test_display_mcp_servers_no_servers(self, mock_print):
"""Test displaying MCP servers when none are configured."""
self.config.mcp = MCPConfig() # Empty config
display_mcp_servers(self.config)
mock_print.assert_called_once()
call_args = mock_print.call_args[0][0]
assert 'No custom MCP servers configured' in call_args
@patch('openhands.cli.commands.print_formatted_text')
def test_display_mcp_servers_with_servers(self, mock_print):
"""Test displaying MCP servers when some are configured."""
self.config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
)
display_mcp_servers(self.config)
# Should be called multiple times for different sections
assert mock_print.call_count >= 2
# Check that the summary is printed
first_call = mock_print.call_args_list[0][0][0]
assert 'Configured MCP servers:' in first_call
assert 'SSE servers: 1' in first_call
assert 'Stdio servers: 1' in first_call
@pytest.mark.asyncio
@patch('openhands.cli.commands.cli_confirm')
@patch('openhands.cli.commands.print_formatted_text')
async def test_remove_mcp_server_no_servers(self, mock_print, mock_cli_confirm):
"""Test removing MCP server when none are configured."""
self.config.mcp = MCPConfig() # Empty config
await remove_mcp_server(self.config)
mock_print.assert_called_once_with('No MCP servers configured to remove.')
mock_cli_confirm.assert_not_called()
@pytest.mark.asyncio
@patch('openhands.cli.commands.cli_confirm')
@patch('openhands.cli.commands.load_config_file')
@patch('openhands.cli.commands.save_config_file')
@patch('openhands.cli.commands.print_formatted_text')
async def test_remove_mcp_server_success(
self, mock_print, mock_save, mock_load, mock_cli_confirm
):
"""Test successfully removing an MCP server."""
# Set up config with servers
self.config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
)
# Mock user selections
mock_cli_confirm.side_effect = [0, 0] # Select first server, confirm removal
# Mock config file operations
mock_load.return_value = {
'mcp': {
'sse_servers': [{'url': 'http://test.com'}],
'stdio_servers': [{'name': 'test-stdio', 'command': 'python'}],
}
}
await remove_mcp_server(self.config)
# Should have been called twice (select server, confirm removal)
assert mock_cli_confirm.call_count == 2
mock_save.assert_called_once()
# Check that success message was printed
success_calls = [
call for call in mock_print.call_args_list if 'removed' in str(call[0][0])
]
assert len(success_calls) >= 1

View File

@ -0,0 +1,156 @@
"""Tests for CLI Runtime MCP functionality."""
from unittest.mock import MagicMock, patch
import pytest
from openhands.core.config import OpenHandsConfig
from openhands.core.config.mcp_config import (
MCPConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.events.action.mcp import MCPAction
from openhands.events.observation import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
class TestCLIRuntimeMCP:
"""Test MCP functionality in CLI Runtime."""
def setup_method(self):
"""Set up test fixtures."""
self.config = OpenHandsConfig()
self.event_stream = MagicMock()
self.runtime = CLIRuntime(
config=self.config, event_stream=self.event_stream, sid='test-session'
)
@pytest.mark.asyncio
async def test_call_tool_mcp_no_servers_configured(self):
"""Test MCP call with no servers configured."""
# Set up empty MCP config
self.runtime.config.mcp = MCPConfig()
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
with patch('sys.platform', 'linux'):
result = await self.runtime.call_tool_mcp(action)
assert isinstance(result, ErrorObservation)
assert 'No MCP servers configured' in result.content
@pytest.mark.asyncio
@patch('openhands.mcp.utils.create_mcp_clients')
async def test_call_tool_mcp_no_clients_created(self, mock_create_clients):
"""Test MCP call when no clients can be created."""
# Set up MCP config with servers
self.runtime.config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')]
)
# Mock create_mcp_clients to return empty list
mock_create_clients.return_value = []
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
with patch('sys.platform', 'linux'):
result = await self.runtime.call_tool_mcp(action)
assert isinstance(result, ErrorObservation)
assert 'No MCP clients could be created' in result.content
mock_create_clients.assert_called_once()
@pytest.mark.asyncio
@patch('openhands.mcp.utils.create_mcp_clients')
@patch('openhands.mcp.utils.call_tool_mcp')
async def test_call_tool_mcp_success(self, mock_call_tool, mock_create_clients):
"""Test successful MCP tool call."""
# Set up MCP config with servers
self.runtime.config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
)
# Mock successful client creation
mock_client = MagicMock()
mock_create_clients.return_value = [mock_client]
# Mock successful tool call
expected_observation = MCPObservation(
content='{"result": "success"}',
name='test_tool',
arguments={'arg1': 'value1'},
)
mock_call_tool.return_value = expected_observation
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
with patch('sys.platform', 'linux'):
result = await self.runtime.call_tool_mcp(action)
assert result == expected_observation
mock_create_clients.assert_called_once_with(
self.runtime.config.mcp.sse_servers,
self.runtime.config.mcp.shttp_servers,
self.runtime.sid,
self.runtime.config.mcp.stdio_servers,
)
mock_call_tool.assert_called_once_with([mock_client], action)
@pytest.mark.asyncio
@patch('openhands.mcp.utils.create_mcp_clients')
async def test_call_tool_mcp_exception_handling(self, mock_create_clients):
"""Test exception handling in MCP tool call."""
# Set up MCP config with servers
self.runtime.config.mcp = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')]
)
# Mock create_mcp_clients to raise an exception
mock_create_clients.side_effect = Exception('Connection error')
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
with patch('sys.platform', 'linux'):
result = await self.runtime.call_tool_mcp(action)
assert isinstance(result, ErrorObservation)
assert 'Error executing MCP tool test_tool' in result.content
assert 'Connection error' in result.content
def test_get_mcp_config_basic(self):
"""Test basic MCP config retrieval."""
# Set up MCP config
expected_config = MCPConfig(
sse_servers=[MCPSSEServerConfig(url='http://test.com')],
stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')],
)
self.runtime.config.mcp = expected_config
with patch('sys.platform', 'linux'):
result = self.runtime.get_mcp_config()
assert result == expected_config
def test_get_mcp_config_with_extra_stdio_servers(self):
"""Test MCP config with extra stdio servers."""
# Set up initial MCP config
initial_stdio_server = MCPStdioServerConfig(name='initial', command='python')
self.runtime.config.mcp = MCPConfig(stdio_servers=[initial_stdio_server])
# Add extra stdio servers
extra_servers = [
MCPStdioServerConfig(name='extra1', command='node'),
MCPStdioServerConfig(name='extra2', command='java'),
]
with patch('sys.platform', 'linux'):
result = self.runtime.get_mcp_config(extra_stdio_servers=extra_servers)
# Should have all three servers
assert len(result.stdio_servers) == 3
assert initial_stdio_server in result.stdio_servers
assert extra_servers[0] in result.stdio_servers
assert extra_servers[1] in result.stdio_servers

View File

@ -9,6 +9,9 @@ from openhands.cli.tui import (
display_banner,
display_command,
display_event,
display_mcp_action,
display_mcp_errors,
display_mcp_observation,
display_message,
display_runtime_initialization_message,
display_shutdown_message,
@ -24,14 +27,17 @@ from openhands.events.action import (
Action,
ActionConfirmationStatus,
CmdRunAction,
MCPAction,
MessageAction,
)
from openhands.events.observation import (
CmdOutputObservation,
FileEditObservation,
FileReadObservation,
MCPObservation,
)
from openhands.llm.metrics import Metrics
from openhands.mcp.error_collector import MCPError
class TestDisplayFunctions:
@ -72,6 +78,35 @@ class TestDisplayFunctions:
args, kwargs = mock_print.call_args_list[0]
assert "Let's start building" in str(args[0])
@patch('openhands.cli.tui.print_formatted_text')
def test_display_welcome_message_with_message(self, mock_print):
message = 'Test message'
display_welcome_message(message)
assert mock_print.call_count == 2
# Check the first call contains the welcome message
args, kwargs = mock_print.call_args_list[0]
message_text = str(args[0])
assert "Let's start building" in message_text
# Check the second call contains the custom message
args, kwargs = mock_print.call_args_list[1]
message_text = str(args[0])
assert 'Test message' in message_text
assert 'Type /help for help' in message_text
@patch('openhands.cli.tui.print_formatted_text')
def test_display_welcome_message_without_message(self, mock_print):
display_welcome_message()
assert mock_print.call_count == 2
# Check the first call contains the welcome message
args, kwargs = mock_print.call_args_list[0]
message_text = str(args[0])
assert "Let's start building" in message_text
# Check the second call contains the default message
args, kwargs = mock_print.call_args_list[1]
message_text = str(args[0])
assert 'What do you want to build?' in message_text
assert 'Type /help for help' in message_text
@patch('openhands.cli.tui.display_message')
def test_display_event_message_action(self, mock_display_message):
config = MagicMock(spec=OpenHandsConfig)
@ -147,6 +182,71 @@ class TestDisplayFunctions:
mock_display_message.assert_called_once_with('Thinking about this...')
@patch('openhands.cli.tui.display_mcp_action')
def test_display_event_mcp_action(self, mock_display_mcp_action):
config = MagicMock(spec=OpenHandsConfig)
mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'})
display_event(mcp_action, config)
mock_display_mcp_action.assert_called_once_with(mcp_action)
@patch('openhands.cli.tui.display_mcp_observation')
def test_display_event_mcp_observation(self, mock_display_mcp_observation):
config = MagicMock(spec=OpenHandsConfig)
mcp_observation = MCPObservation(
content='Tool result', name='test_tool', arguments={'param': 'value'}
)
display_event(mcp_observation, config)
mock_display_mcp_observation.assert_called_once_with(mcp_observation)
@patch('openhands.cli.tui.print_container')
def test_display_mcp_action(self, mock_print_container):
mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'})
display_mcp_action(mcp_action)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'test_tool' in container.body.text
assert 'param' in container.body.text
@patch('openhands.cli.tui.print_container')
def test_display_mcp_action_no_args(self, mock_print_container):
mcp_action = MCPAction(name='test_tool')
display_mcp_action(mcp_action)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'test_tool' in container.body.text
assert 'Arguments' not in container.body.text
@patch('openhands.cli.tui.print_container')
def test_display_mcp_observation(self, mock_print_container):
mcp_observation = MCPObservation(
content='Tool result', name='test_tool', arguments={'param': 'value'}
)
display_mcp_observation(mcp_observation)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'test_tool' in container.body.text
assert 'Tool result' in container.body.text
@patch('openhands.cli.tui.print_container')
def test_display_mcp_observation_no_content(self, mock_print_container):
mcp_observation = MCPObservation(content='', name='test_tool')
display_mcp_observation(mcp_observation)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'No output' in container.body.text
@patch('openhands.cli.tui.print_formatted_text')
def test_display_message(self, mock_print):
message = 'Test message'
@ -307,14 +407,86 @@ class TestReadConfirmationInput:
result = await read_confirmation_input(config=cfg)
assert result == 'always'
@pytest.mark.asyncio
"""Tests for CLI TUI MCP functionality."""
class TestMCPTUIDisplay:
"""Test MCP TUI display functions."""
@patch('openhands.cli.tui.print_container')
def test_display_mcp_action_with_arguments(self, mock_print_container):
"""Test displaying MCP action with arguments."""
mcp_action = MCPAction(
name='test_tool', arguments={'param1': 'value1', 'param2': 42}
)
display_mcp_action(mcp_action)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'test_tool' in container.body.text
assert 'param1' in container.body.text
assert 'value1' in container.body.text
@patch('openhands.cli.tui.print_container')
def test_display_mcp_observation_with_content(self, mock_print_container):
"""Test displaying MCP observation with content."""
mcp_observation = MCPObservation(
content='Tool execution successful',
name='test_tool',
arguments={'param': 'value'},
)
display_mcp_observation(mcp_observation)
mock_print_container.assert_called_once()
container = mock_print_container.call_args[0][0]
assert 'test_tool' in container.body.text
assert 'Tool execution successful' in container.body.text
@patch('openhands.cli.tui.print_formatted_text')
@patch('openhands.cli.tui.cli_confirm')
async def test_read_confirmation_input_edit(self, mock_confirm, mock_print):
mock_confirm.return_value = 3 # user picked third menu item
@patch('openhands.cli.tui.mcp_error_collector')
def test_display_mcp_errors_no_errors(self, mock_collector, mock_print):
"""Test displaying MCP errors when none exist."""
mock_collector.get_errors.return_value = []
cfg = MagicMock() # <- no spec for simplicity
cfg.cli = MagicMock(vi_mode=False)
display_mcp_errors()
result = await read_confirmation_input(config=cfg)
assert result == 'edit'
mock_print.assert_called_once()
call_args = mock_print.call_args[0][0]
assert 'No MCP errors detected' in str(call_args)
@patch('openhands.cli.tui.print_container')
@patch('openhands.cli.tui.print_formatted_text')
@patch('openhands.cli.tui.mcp_error_collector')
def test_display_mcp_errors_with_errors(
self, mock_collector, mock_print, mock_print_container
):
"""Test displaying MCP errors when some exist."""
# Create mock errors
error1 = MCPError(
timestamp=1234567890.0,
server_name='test-server-1',
server_type='stdio',
error_message='Connection failed',
exception_details='Socket timeout',
)
error2 = MCPError(
timestamp=1234567891.0,
server_name='test-server-2',
server_type='sse',
error_message='Server unreachable',
)
mock_collector.get_errors.return_value = [error1, error2]
display_mcp_errors()
# Should print error count header
assert mock_print.call_count >= 1
header_call = mock_print.call_args_list[0][0][0]
assert '2 MCP error(s) detected' in str(header_call)
# Should print containers for each error
assert mock_print_container.call_count == 2

View File

@ -27,10 +27,9 @@ def test_empty_sse_config():
def test_invalid_sse_url():
"""Test SSE configuration with invalid URL format."""
config = MCPConfig(sse_servers=[MCPSSEServerConfig(url='not_a_url')])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Invalid URL' in str(exc_info.value)
with pytest.raises(ValidationError) as exc_info:
MCPSSEServerConfig(url='not_a_url')
assert 'URL must include a scheme' in str(exc_info.value)
def test_duplicate_sse_urls():
@ -64,7 +63,7 @@ def test_from_toml_section_invalid_sse():
}
with pytest.raises(ValueError) as exc_info:
MCPConfig.from_toml_section(data)
assert 'Invalid URL' in str(exc_info.value)
assert 'URL must include a scheme' in str(exc_info.value)
def test_complex_urls():
@ -246,3 +245,28 @@ def test_stdio_server_equality_with_different_env_order():
# Should not be equal
assert server1 != server5
def test_mcp_stdio_server_args_parsing_basic():
"""Test MCPStdioServerConfig args parsing with basic shell-like format."""
# Test basic space-separated parsing
config = MCPStdioServerConfig(
name='test-server', command='python', args='arg1 arg2 arg3'
)
assert config.args == ['arg1', 'arg2', 'arg3']
# Test single argument
config = MCPStdioServerConfig(
name='test-server', command='python', args='single-arg'
)
assert config.args == ['single-arg']
def test_mcp_stdio_server_args_parsing_invalid_quotes():
"""Test MCPStdioServerConfig args parsing with invalid quotes."""
# Test unmatched quotes
with pytest.raises(ValidationError) as exc_info:
MCPStdioServerConfig(
name='test-server', command='python', args='--config "unmatched quote'
)
assert 'Invalid argument format' in str(exc_info.value)

View File

@ -0,0 +1,159 @@
"""Tests for MCP Error Collector functionality."""
import time
from openhands.mcp.error_collector import (
MCPError,
MCPErrorCollector,
mcp_error_collector,
)
class TestMCPError:
"""Test MCPError dataclass."""
def test_mcp_error_creation(self):
"""Test creating an MCP error."""
timestamp = time.time()
error = MCPError(
timestamp=timestamp,
server_name='test-server',
server_type='stdio',
error_message='Connection failed',
exception_details='Socket timeout',
)
assert error.timestamp == timestamp
assert error.server_name == 'test-server'
assert error.server_type == 'stdio'
assert error.error_message == 'Connection failed'
assert error.exception_details == 'Socket timeout'
def test_mcp_error_creation_without_exception_details(self):
"""Test creating an MCP error without exception details."""
timestamp = time.time()
error = MCPError(
timestamp=timestamp,
server_name='test-server',
server_type='sse',
error_message='Server unreachable',
)
assert error.timestamp == timestamp
assert error.server_name == 'test-server'
assert error.server_type == 'sse'
assert error.error_message == 'Server unreachable'
assert error.exception_details is None
class TestMCPErrorCollector:
"""Test MCPErrorCollector functionality."""
def setup_method(self):
"""Set up test fixtures."""
self.collector = MCPErrorCollector()
def test_initialization(self):
"""Test collector initialization."""
assert self.collector._errors == []
assert self.collector._collection_enabled is True
def test_add_error(self):
"""Test adding an error to the collector."""
self.collector.add_error(
server_name='test-server',
server_type='stdio',
error_message='Connection failed',
exception_details='Socket timeout',
)
errors = self.collector.get_errors()
assert len(errors) == 1
assert errors[0].server_name == 'test-server'
assert errors[0].server_type == 'stdio'
assert errors[0].error_message == 'Connection failed'
assert errors[0].exception_details == 'Socket timeout'
assert errors[0].timestamp > 0
def test_add_multiple_errors(self):
"""Test adding multiple errors."""
self.collector.add_error('server1', 'stdio', 'Error 1')
self.collector.add_error('server2', 'sse', 'Error 2')
self.collector.add_error('server3', 'shttp', 'Error 3')
errors = self.collector.get_errors()
assert len(errors) == 3
assert errors[0].server_name == 'server1'
assert errors[1].server_name == 'server2'
assert errors[2].server_name == 'server3'
def test_has_errors(self):
"""Test has_errors method."""
assert not self.collector.has_errors()
self.collector.add_error('server1', 'stdio', 'Error 1')
assert self.collector.has_errors()
self.collector.clear_errors()
assert not self.collector.has_errors()
def test_clear_errors(self):
"""Test clearing errors."""
self.collector.add_error('server1', 'stdio', 'Error 1')
self.collector.add_error('server2', 'sse', 'Error 2')
assert len(self.collector.get_errors()) == 2
self.collector.clear_errors()
assert len(self.collector.get_errors()) == 0
assert not self.collector.has_errors()
def test_enable_disable_collection(self):
"""Test enabling and disabling error collection."""
self.collector.add_error('server1', 'stdio', 'Error 1')
assert len(self.collector.get_errors()) == 1
# Disable collection
self.collector.disable_collection()
# Adding error should be ignored
self.collector.add_error('server2', 'sse', 'Error 2')
assert len(self.collector.get_errors()) == 1 # Still only 1 error
# Re-enable collection
self.collector.enable_collection()
# Adding error should work again
self.collector.add_error('server3', 'shttp', 'Error 3')
assert len(self.collector.get_errors()) == 2
class TestGlobalMCPErrorCollector:
"""Test the global MCP error collector instance."""
def setup_method(self):
"""Clear global collector before each test."""
mcp_error_collector.clear_errors()
mcp_error_collector.enable_collection()
def teardown_method(self):
"""Clean up after each test."""
mcp_error_collector.clear_errors()
mcp_error_collector.enable_collection()
def test_global_collector_exists(self):
"""Test that global collector instance exists."""
assert mcp_error_collector is not None
assert isinstance(mcp_error_collector, MCPErrorCollector)
def test_global_collector_functionality(self):
"""Test basic functionality of global collector."""
assert not mcp_error_collector.has_errors()
mcp_error_collector.add_error('global-server', 'stdio', 'Global error')
assert mcp_error_collector.has_errors()
assert mcp_error_collector.get_error_count() == 1
errors = mcp_error_collector.get_errors()
assert len(errors) == 1
assert errors[0].server_name == 'global-server'

View File

@ -5,7 +5,7 @@ import pytest
# Import the module, not the functions directly to avoid circular imports
import openhands.mcp.utils
from openhands.core.config.mcp_config import MCPSSEServerConfig
from openhands.core.config.mcp_config import MCPSSEServerConfig, MCPStdioServerConfig
from openhands.events.action.mcp import MCPAction
from openhands.events.observation.mcp import MCPObservation
@ -161,3 +161,136 @@ async def test_call_tool_mcp_success():
assert isinstance(observation, MCPObservation)
assert json.loads(observation.content) == {'result': 'success'}
mock_client.call_tool.assert_called_once_with('test_tool', {'arg1': 'value1'})
@pytest.mark.asyncio
@patch('openhands.mcp.utils.MCPClient')
async def test_create_mcp_clients_stdio_success(mock_mcp_client):
"""Test successful creation of MCP clients with stdio servers."""
# Setup mock
mock_client_instance = AsyncMock()
mock_mcp_client.return_value = mock_client_instance
mock_client_instance.connect_stdio = AsyncMock()
# Test with stdio servers
stdio_server_configs = [
MCPStdioServerConfig(
name='test-server-1',
command='python',
args=['-m', 'server1'],
env={'DEBUG': 'true'},
),
MCPStdioServerConfig(
name='test-server-2',
command='/usr/bin/node',
args=['server2.js'],
env={'NODE_ENV': 'development'},
),
]
clients = await openhands.mcp.utils.create_mcp_clients(
[], [], stdio_servers=stdio_server_configs
)
# Verify
assert len(clients) == 2
assert mock_mcp_client.call_count == 2
# Check that connect_stdio was called with correct parameters
mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[0])
mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[1])
@pytest.mark.asyncio
@patch('openhands.mcp.utils.MCPClient')
async def test_create_mcp_clients_stdio_connection_failure(mock_mcp_client):
"""Test handling of stdio connection failures when creating MCP clients."""
# Setup mock
mock_client_instance = AsyncMock()
mock_mcp_client.return_value = mock_client_instance
# First connection succeeds, second fails
mock_client_instance.connect_stdio.side_effect = [
None, # Success
Exception('Stdio connection failed'), # Failure
]
stdio_server_configs = [
MCPStdioServerConfig(name='server1', command='python'),
MCPStdioServerConfig(name='server2', command='invalid_command'),
]
clients = await openhands.mcp.utils.create_mcp_clients(
[], [], stdio_servers=stdio_server_configs
)
# Verify only one client was successfully created
assert len(clients) == 1
@pytest.mark.asyncio
@patch('openhands.mcp.utils.create_mcp_clients')
async def test_fetch_mcp_tools_from_config_with_stdio(mock_create_clients):
"""Test fetching MCP tools with stdio servers enabled."""
from openhands.core.config.mcp_config import MCPConfig
# Setup mock clients
mock_client = MagicMock()
mock_tool = MagicMock()
mock_tool.to_param.return_value = {'function': {'name': 'stdio_tool'}}
mock_client.tools = [mock_tool]
mock_create_clients.return_value = [mock_client]
# Create config with stdio servers
mcp_config = MCPConfig(
stdio_servers=[MCPStdioServerConfig(name='test-server', command='python')]
)
# Test with use_stdio=True
tools = await openhands.mcp.utils.fetch_mcp_tools_from_config(
mcp_config, conversation_id='test-conv', use_stdio=True
)
# Verify
assert len(tools) == 1
assert tools[0] == {'function': {'name': 'stdio_tool'}}
# Verify create_mcp_clients was called with stdio servers
mock_create_clients.assert_called_once_with(
[], [], 'test-conv', mcp_config.stdio_servers
)
@pytest.mark.asyncio
async def test_call_tool_mcp_stdio_client():
"""Test calling MCP tool on a stdio client."""
# Create mock stdio client with the requested tool
mock_client = MagicMock()
mock_tool = MagicMock()
mock_tool.name = 'stdio_test_tool'
mock_client.tools = [mock_tool]
# Setup response
mock_response = MagicMock()
mock_response.model_dump.return_value = {
'result': 'stdio_success',
'data': 'test_data',
}
# Setup call_tool method
mock_client.call_tool = AsyncMock(return_value=mock_response)
action = MCPAction(name='stdio_test_tool', arguments={'input': 'test_input'})
# Call the function
observation = await openhands.mcp.utils.call_tool_mcp([mock_client], action)
# Verify
assert isinstance(observation, MCPObservation)
assert json.loads(observation.content) == {
'result': 'stdio_success',
'data': 'test_data',
}
mock_client.call_tool.assert_called_once_with(
'stdio_test_tool', {'input': 'test_input'}
)