From fbd92802398b065c7d232691ce1707765f32a608 Mon Sep 17 00:00:00 2001 From: "Ryan H. Tran" Date: Thu, 24 Jul 2025 00:06:01 +0700 Subject: [PATCH] Add MCP support for CLI (#9519) Co-authored-by: openhands Co-authored-by: Xingyao Wang --- docs/usage/how-to/cli-mode.mdx | 38 +- openhands/cli/commands.py | 513 +++++++++++++++++++++- openhands/cli/main.py | 40 +- openhands/cli/tui.py | 113 +++++ openhands/core/config/mcp_config.py | 138 +++++- openhands/core/config/utils.py | 16 +- openhands/mcp/__init__.py | 2 + openhands/mcp/client.py | 57 ++- openhands/mcp/error_collector.py | 78 ++++ openhands/mcp/utils.py | 63 ++- openhands/runtime/impl/cli/cli_runtime.py | 101 ++++- tests/unit/test_cli_commands.py | 99 +++++ tests/unit/test_cli_config_management.py | 106 +++++ tests/unit/test_cli_runtime_mcp.py | 156 +++++++ tests/unit/test_cli_tui.py | 188 +++++++- tests/unit/test_mcp_config.py | 34 +- tests/unit/test_mcp_error_collector.py | 159 +++++++ tests/unit/test_mcp_utils.py | 135 +++++- 18 files changed, 2001 insertions(+), 35 deletions(-) create mode 100644 openhands/mcp/error_collector.py create mode 100644 tests/unit/test_cli_config_management.py create mode 100644 tests/unit/test_cli_runtime_mcp.py create mode 100644 tests/unit/test_mcp_error_collector.py diff --git a/docs/usage/how-to/cli-mode.mdx b/docs/usage/how-to/cli-mode.mdx index aa9f579de8..3b0dd48e12 100644 --- a/docs/usage/how-to/cli-mode.mdx +++ b/docs/usage/how-to/cli-mode.mdx @@ -153,6 +153,7 @@ You can use the following commands whenever the prompt (`>`) is displayed: | `/new` | Start a new conversation | | `/settings` | View and modify current LLM/agent settings | | `/resume` | Resume the agent if paused | +| `/mcp` | Manage MCP server configuration and view connection errors | #### Settings and Configuration @@ -162,7 +163,7 @@ follow the prompts: - **Basic settings**: Choose a model/provider and enter your API key. - **Advanced settings**: Set custom endpoints, enable or disable confirmation mode, and configure memory condensation. -Settings can also be managed via the `config.toml` file. +Settings can also be managed via the `config.toml` file in the current directory or `~/.openhands/config.toml`. #### Repository Initialization @@ -174,6 +175,41 @@ project details and structure. Use this when onboarding the agent to a new codeb You can pause the agent while it is running by pressing `Ctrl-P`. To continue the conversation after pausing, simply type `/resume` at the prompt. +#### MCP Server Management + +To configure Model Context Protocol (MCP) servers, you can refer to the documentation on [MCP servers](../mcp) and use the `/mcp` command in the CLI. This command provides an interactive interface for managing Model Context Protocol (MCP) servers: + +- **List configured servers**: View all currently configured MCP servers (SSE, Stdio, and SHTTP) +- **Add new server**: Interactively add a new MCP server with guided prompts +- **Remove server**: Remove an existing MCP server from your configuration +- **View errors**: Display any connection errors that occurred during MCP server startup + +This command modifies your `~/.openhands/config.toml` file and will prompt you to restart OpenHands for changes to take effect. + +To enable the [Tavily MCP server](https://github.com/tavily-ai/tavily-mcp) search engine, you can set the `search_api_key` under the `[core]` section in the `~/.openhands/config.toml` file. + +##### Example of the `config.toml` file with MCP server configuration: + +```toml +[core] +search_api_key = "tvly-your-api-key-here" + +[mcp] +stdio_servers = [ + {name="fetch", command="uvx", args=["mcp-server-fetch"]}, +] + +sse_servers = [ + # Basic SSE server with just a URL + "http://example.com:8080/sse", +] + +shttp_servers = [ + # Streamable HTTP server with API key authentication + {url="https://secure-example.com/mcp", api_key="your-api-key"} +] +``` + ## Tips and Troubleshooting - Use `/help` at any time to see the list of available commands. diff --git a/openhands/cli/commands.py b/openhands/cli/commands.py index 2aef2a2f9f..2ce1b6536b 100644 --- a/openhands/cli/commands.py +++ b/openhands/cli/commands.py @@ -1,9 +1,15 @@ import asyncio +import os +import sys from pathlib import Path +from typing import Any -from prompt_toolkit import print_formatted_text +import toml +from prompt_toolkit import HTML, print_formatted_text +from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.shortcuts import clear, print_container from prompt_toolkit.widgets import Frame, TextArea +from pydantic import ValidationError from openhands.cli.settings import ( display_settings, @@ -14,9 +20,12 @@ from openhands.cli.tui import ( COLOR_GREY, UsageMetrics, cli_confirm, + create_prompt_session, display_help, + display_mcp_errors, display_shutdown_message, display_status, + read_prompt_input, ) from openhands.cli.utils import ( add_local_config_trusted_dir, @@ -27,6 +36,11 @@ from openhands.cli.utils import ( from openhands.core.config import ( OpenHandsConfig, ) +from openhands.core.config.mcp_config import ( + MCPSHTTPServerConfig, + MCPSSEServerConfig, + MCPStdioServerConfig, +) from openhands.core.schema import AgentState from openhands.core.schema.exit_reason import ExitReason from openhands.events import EventSource @@ -38,6 +52,72 @@ from openhands.events.stream import EventStream from openhands.storage.settings.file_settings_store import FileSettingsStore +async def collect_input(config: OpenHandsConfig, prompt_text: str) -> str | None: + """Collect user input with cancellation support. + + Args: + config: OpenHands configuration + prompt_text: Text to display to user + + Returns: + str | None: User input string, or None if user cancelled + """ + print_formatted_text(prompt_text, end=' ') + user_input = await read_prompt_input(config, '', multiline=False) + + # Check for cancellation + if user_input.strip().lower() in ['/exit', '/cancel', 'cancel']: + return None + + return user_input.strip() + + +def restart_cli() -> None: + """Restart the CLI by replacing the current process.""" + print_formatted_text('🔄 Restarting OpenHands CLI...') + + # Get the current Python executable and script arguments + python_executable = sys.executable + script_args = sys.argv + + # Use os.execv to replace the current process + # This preserves the original command line arguments + try: + os.execv(python_executable, [python_executable] + script_args) + except Exception as e: + print_formatted_text(f'❌ Failed to restart CLI: {e}') + print_formatted_text( + 'Please restart OpenHands manually for changes to take effect.' + ) + + +async def prompt_for_restart(config: OpenHandsConfig) -> bool: + """Prompt user if they want to restart the CLI and return their choice.""" + print_formatted_text('📝 MCP server configuration updated successfully!') + print_formatted_text('The changes will take effect after restarting OpenHands.') + + prompt_session = create_prompt_session(config) + + while True: + try: + with patch_stdout(): + response = await prompt_session.prompt_async( + HTML( + 'Would you like to restart OpenHands now? (y/n): ' + ) + ) + response = response.strip().lower() if response else '' + + if response in ['y', 'yes']: + return True + elif response in ['n', 'no']: + return False + else: + print_formatted_text('Please enter "y" for yes or "n" for no.') + except (KeyboardInterrupt, EOFError): + return False + + async def handle_commands( command: str, event_stream: EventStream, @@ -79,6 +159,8 @@ async def handle_commands( await handle_settings_command(config, settings_store) elif command == '/resume': close_repl, new_session_requested = await handle_resume_command(event_stream) + elif command == '/mcp': + await handle_mcp_command(config) else: close_repl = True action = MessageAction(content=command) @@ -327,3 +409,432 @@ def check_folder_security_agreement(config: OpenHandsConfig, current_dir: str) - return confirm return True + + +async def handle_mcp_command(config: OpenHandsConfig) -> None: + """Handle MCP command with interactive menu.""" + action = cli_confirm( + config, + 'MCP Server Configuration', + [ + 'List configured servers', + 'Add new server', + 'Remove server', + 'View errors', + 'Go back', + ], + ) + + if action == 0: # List + display_mcp_servers(config) + elif action == 1: # Add + await add_mcp_server(config) + elif action == 2: # Remove + await remove_mcp_server(config) + elif action == 3: # View errors + handle_mcp_errors_command() + # action == 4 is "Go back", do nothing + + +def display_mcp_servers(config: OpenHandsConfig) -> None: + """Display MCP server configuration information.""" + mcp_config = config.mcp + + # Count the different types of servers + sse_count = len(mcp_config.sse_servers) + stdio_count = len(mcp_config.stdio_servers) + shttp_count = len(mcp_config.shttp_servers) + total_count = sse_count + stdio_count + shttp_count + + if total_count == 0: + print_formatted_text( + 'No custom MCP servers configured. See the documentation to learn more:\n' + ' https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers' + ) + else: + print_formatted_text( + f'Configured MCP servers:\n' + f' • SSE servers: {sse_count}\n' + f' • Stdio servers: {stdio_count}\n' + f' • SHTTP servers: {shttp_count}\n' + f' • Total: {total_count}' + ) + + # Show details for each type if they exist + if sse_count > 0: + print_formatted_text('SSE Servers:') + for idx, sse_server in enumerate(mcp_config.sse_servers, 1): + print_formatted_text(f' {idx}. {sse_server.url}') + print_formatted_text('') + + if stdio_count > 0: + print_formatted_text('Stdio Servers:') + for idx, stdio_server in enumerate(mcp_config.stdio_servers, 1): + print_formatted_text( + f' {idx}. {stdio_server.name} ({stdio_server.command})' + ) + print_formatted_text('') + + if shttp_count > 0: + print_formatted_text('SHTTP Servers:') + for idx, shttp_server in enumerate(mcp_config.shttp_servers, 1): + print_formatted_text(f' {idx}. {shttp_server.url}') + print_formatted_text('') + + +def handle_mcp_errors_command() -> None: + """Display MCP connection errors.""" + display_mcp_errors() + + +def get_config_file_path() -> Path: + """Get the path to the config file. By default, we use config.toml in the current working directory. If not found, we use ~/.openhands/config.toml.""" + # Check if config.toml exists in the current directory + current_dir = Path.cwd() / 'config.toml' + if current_dir.exists(): + return current_dir + + # Fallback to the user's home directory + return Path.home() / '.openhands' / 'config.toml' + + +def load_config_file(file_path: Path) -> dict: + """Load the config file, creating it if it doesn't exist.""" + if file_path.exists(): + try: + with open(file_path, 'r') as f: + return toml.load(f) + except Exception: + pass + + # Create directory if it doesn't exist + file_path.parent.mkdir(parents=True, exist_ok=True) + return {} + + +def save_config_file(config_data: dict, file_path: Path) -> None: + """Save the config file.""" + with open(file_path, 'w') as f: + toml.dump(config_data, f) + + +def _ensure_mcp_config_structure(config_data: dict) -> None: + """Ensure MCP configuration structure exists in config data.""" + if 'mcp' not in config_data: + config_data['mcp'] = {} + + +def _add_server_to_config(server_type: str, server_config: dict) -> Path: + """Add a server configuration to the config file.""" + config_file_path = get_config_file_path() + config_data = load_config_file(config_file_path) + _ensure_mcp_config_structure(config_data) + + if server_type not in config_data['mcp']: + config_data['mcp'][server_type] = [] + + config_data['mcp'][server_type].append(server_config) + save_config_file(config_data, config_file_path) + + return config_file_path + + +async def add_mcp_server(config: OpenHandsConfig) -> None: + """Add a new MCP server configuration.""" + # Choose transport type + transport_type = cli_confirm( + config, + 'Select MCP server transport type:', + [ + 'SSE (Server-Sent Events)', + 'Stdio (Standard Input/Output)', + 'SHTTP (Streamable HTTP)', + 'Cancel', + ], + ) + + if transport_type == 3: # Cancel + return + + try: + if transport_type == 0: # SSE + await add_sse_server(config) + elif transport_type == 1: # Stdio + await add_stdio_server(config) + elif transport_type == 2: # SHTTP + await add_shttp_server(config) + except Exception as e: + print_formatted_text(f'Error adding MCP server: {e}') + + +async def add_sse_server(config: OpenHandsConfig) -> None: + """Add an SSE MCP server.""" + print_formatted_text('Adding SSE MCP Server') + + while True: # Retry loop for the entire form + # Collect all inputs + url = await collect_input(config, '\nEnter server URL:') + if url is None: + print_formatted_text('Operation cancelled.') + return + + api_key = await collect_input( + config, '\nEnter API key (optional, press Enter to skip):' + ) + if api_key is None: + print_formatted_text('Operation cancelled.') + return + + # Convert empty string to None for optional field + api_key = api_key if api_key else None + + # Validate all inputs at once + try: + server = MCPSSEServerConfig(url=url, api_key=api_key) + break # Success - exit retry loop + + except ValidationError as e: + # Show all errors at once + print_formatted_text('❌ Please fix the following errors:') + for error in e.errors(): + field = error['loc'][0] if error['loc'] else 'unknown' + print_formatted_text(f' • {field}: {error["msg"]}') + + if cli_confirm(config, '\nTry again?') != 0: + print_formatted_text('Operation cancelled.') + return + + # Save to config file + server_config = {'url': server.url} + if server.api_key: + server_config['api_key'] = server.api_key + + config_file_path = _add_server_to_config('sse_servers', server_config) + print_formatted_text(f'✓ SSE MCP server added to {config_file_path}: {server.url}') + + # Prompt for restart + if await prompt_for_restart(config): + restart_cli() + + +async def add_stdio_server(config: OpenHandsConfig) -> None: + """Add a Stdio MCP server.""" + print_formatted_text('Adding Stdio MCP Server') + + # Get existing server names to check for duplicates + existing_names = [server.name for server in config.mcp.stdio_servers] + + while True: # Retry loop for the entire form + # Collect all inputs + name = await collect_input(config, '\nEnter server name:') + if name is None: + print_formatted_text('Operation cancelled.') + return + + command = await collect_input(config, "\nEnter command (e.g., 'uvx', 'npx'):") + if command is None: + print_formatted_text('Operation cancelled.') + return + + args_input = await collect_input( + config, + '\nEnter arguments (optional, e.g., "-y server-package arg1"):', + ) + if args_input is None: + print_formatted_text('Operation cancelled.') + return + + env_input = await collect_input( + config, + '\nEnter environment variables (KEY=VALUE format, comma-separated, optional):', + ) + if env_input is None: + print_formatted_text('Operation cancelled.') + return + + # Check for duplicate server names + if name in existing_names: + print_formatted_text(f"❌ Server name '{name}' already exists.") + if cli_confirm(config, '\nTry again?') != 0: + print_formatted_text('Operation cancelled.') + return + continue + + # Validate all inputs at once + try: + server = MCPStdioServerConfig( + name=name, + command=command, + args=args_input, # type: ignore # Will be parsed by Pydantic validator + env=env_input, # type: ignore # Will be parsed by Pydantic validator + ) + break # Success - exit retry loop + + except ValidationError as e: + # Show all errors at once + print_formatted_text('❌ Please fix the following errors:') + for error in e.errors(): + field = error['loc'][0] if error['loc'] else 'unknown' + print_formatted_text(f' • {field}: {error["msg"]}') + + if cli_confirm(config, '\nTry again?') != 0: + print_formatted_text('Operation cancelled.') + return + + # Save to config file + server_config: dict[str, Any] = { + 'name': server.name, + 'command': server.command, + } + if server.args: + server_config['args'] = server.args + if server.env: + server_config['env'] = server.env + + config_file_path = _add_server_to_config('stdio_servers', server_config) + print_formatted_text( + f'✓ Stdio MCP server added to {config_file_path}: {server.name}' + ) + + # Prompt for restart + if await prompt_for_restart(config): + restart_cli() + + +async def add_shttp_server(config: OpenHandsConfig) -> None: + """Add an SHTTP MCP server.""" + print_formatted_text('Adding SHTTP MCP Server') + + while True: # Retry loop for the entire form + # Collect all inputs + url = await collect_input(config, '\nEnter server URL:') + if url is None: + print_formatted_text('Operation cancelled.') + return + + api_key = await collect_input( + config, '\nEnter API key (optional, press Enter to skip):' + ) + if api_key is None: + print_formatted_text('Operation cancelled.') + return + + # Convert empty string to None for optional field + api_key = api_key if api_key else None + + # Validate all inputs at once + try: + server = MCPSHTTPServerConfig(url=url, api_key=api_key) + break # Success - exit retry loop + + except ValidationError as e: + # Show all errors at once + print_formatted_text('❌ Please fix the following errors:') + for error in e.errors(): + field = error['loc'][0] if error['loc'] else 'unknown' + print_formatted_text(f' • {field}: {error["msg"]}') + + if cli_confirm(config, '\nTry again?') != 0: + print_formatted_text('Operation cancelled.') + return + + # Save to config file + server_config = {'url': server.url} + if server.api_key: + server_config['api_key'] = server.api_key + + config_file_path = _add_server_to_config('shttp_servers', server_config) + print_formatted_text( + f'✓ SHTTP MCP server added to {config_file_path}: {server.url}' + ) + + # Prompt for restart + if await prompt_for_restart(config): + restart_cli() + + +async def remove_mcp_server(config: OpenHandsConfig) -> None: + """Remove an MCP server configuration.""" + mcp_config = config.mcp + + # Collect all servers with their types + servers: list[tuple[str, str, object]] = [] + + # Add SSE servers + for sse_server in mcp_config.sse_servers: + servers.append(('SSE', sse_server.url, sse_server)) + + # Add Stdio servers + for stdio_server in mcp_config.stdio_servers: + servers.append(('Stdio', stdio_server.name, stdio_server)) + + # Add SHTTP servers + for shttp_server in mcp_config.shttp_servers: + servers.append(('SHTTP', shttp_server.url, shttp_server)) + + if not servers: + print_formatted_text('No MCP servers configured to remove.') + return + + # Create choices for the user + choices = [] + for server_type, identifier, _ in servers: + choices.append(f'{server_type}: {identifier}') + choices.append('Cancel') + + # Let user choose which server to remove + choice = cli_confirm(config, 'Select MCP server to remove:', choices) + + if choice == len(choices) - 1: # Cancel + return + + # Remove the selected server + server_type, identifier, _ = servers[choice] + + # Confirm removal + confirm = cli_confirm( + config, + f'Are you sure you want to remove {server_type} server "{identifier}"?', + ['Yes, remove', 'Cancel'], + ) + + if confirm == 1: # Cancel + return + + # Load config file and remove the server + config_file_path = get_config_file_path() + config_data = load_config_file(config_file_path) + + _ensure_mcp_config_structure(config_data) + + removed = False + + if server_type == 'SSE' and 'sse_servers' in config_data['mcp']: + config_data['mcp']['sse_servers'] = [ + s for s in config_data['mcp']['sse_servers'] if s.get('url') != identifier + ] + removed = True + elif server_type == 'Stdio' and 'stdio_servers' in config_data['mcp']: + config_data['mcp']['stdio_servers'] = [ + s + for s in config_data['mcp']['stdio_servers'] + if s.get('name') != identifier + ] + removed = True + elif server_type == 'SHTTP' and 'shttp_servers' in config_data['mcp']: + config_data['mcp']['shttp_servers'] = [ + s for s in config_data['mcp']['shttp_servers'] if s.get('url') != identifier + ] + removed = True + + if removed: + save_config_file(config_data, config_file_path) + print_formatted_text( + f'✓ {server_type} MCP server "{identifier}" removed from {config_file_path}.' + ) + + # Prompt for restart + if await prompt_for_restart(config): + restart_cli() + else: + print_formatted_text(f'Failed to remove {server_type} server "{identifier}".') diff --git a/openhands/cli/main.py b/openhands/cli/main.py index 1bba8592f1..ae2923f2d7 100644 --- a/openhands/cli/main.py +++ b/openhands/cli/main.py @@ -74,6 +74,7 @@ from openhands.events.observation import ( ) from openhands.io import read_task from openhands.mcp import add_mcp_tools_to_agent +from openhands.mcp.error_collector import mcp_error_collector from openhands.memory.condenser.impl.llm_summarizing_condenser import ( LLMSummarizingCondenserConfig, ) @@ -298,6 +299,10 @@ async def run_session( # Add MCP tools to the agent if agent.config.enable_mcp: + # Clear any previous errors and enable collection + mcp_error_collector.clear_errors() + mcp_error_collector.enable_collection() + # Add OpenHands' MCP server by default _, openhands_mcp_stdio_servers = ( OpenHandsMCPConfigImpl.create_default_mcp_server_config( @@ -309,6 +314,9 @@ async def run_session( await add_mcp_tools_to_agent(agent, runtime, memory) + # Disable collection after startup + mcp_error_collector.disable_collection() + # Clear loading animation is_loaded.set() @@ -319,7 +327,27 @@ async def run_session( if not skip_banner: display_banner(session_id=sid) - welcome_message = 'What do you want to build?' # from the application + welcome_message = '' + + # Display number of MCP servers configured + if agent.config.enable_mcp: + total_mcp_servers = ( + len(runtime.config.mcp.stdio_servers) + + len(runtime.config.mcp.sse_servers) + + len(runtime.config.mcp.shttp_servers) + ) + if total_mcp_servers > 0: + mcp_line = f'Using {len(runtime.config.mcp.stdio_servers)} stdio MCP servers, {len(runtime.config.mcp.sse_servers)} SSE MCP servers and {len(runtime.config.mcp.shttp_servers)} SHTTP MCP servers.' + + # Check for MCP errors and add indicator to the same line + if agent.config.enable_mcp and mcp_error_collector.has_errors(): + mcp_line += ( + ' ✗ MCP errors detected (type /mcp → select View errors to view)' + ) + + welcome_message += mcp_line + '\n\n' + + welcome_message += 'What do you want to build?' # from the application initial_message = '' # from the user if task_content: @@ -488,6 +516,16 @@ async def main_with_loop(loop: asyncio.AbstractEventLoop) -> None: if not env_log_level: logger.setLevel(logging.WARNING) + # If `config.toml` does not exist in current directory, use the file under home directory + if not os.path.exists(args.config_file): + home_config_file = os.path.join( + os.path.expanduser('~'), '.openhands', 'config.toml' + ) + logger.info( + f'Config file {args.config_file} does not exist, using default config file in home directory: {home_config_file}.' + ) + args.config_file = home_config_file + # Load config from toml and override with command line arguments config: OpenHandsConfig = setup_config_from_args(args) diff --git a/openhands/cli/tui.py b/openhands/cli/tui.py index ac98cd2f36..287cff42da 100644 --- a/openhands/cli/tui.py +++ b/openhands/cli/tui.py @@ -4,6 +4,8 @@ import asyncio import contextlib +import datetime +import json import sys import threading import time @@ -36,6 +38,7 @@ from openhands.events.action import ( ActionConfirmationStatus, ChangeAgentStateAction, CmdRunAction, + MCPAction, MessageAction, ) from openhands.events.event import Event @@ -45,8 +48,10 @@ from openhands.events.observation import ( ErrorObservation, FileEditObservation, FileReadObservation, + MCPObservation, ) from openhands.llm.metrics import Metrics +from openhands.mcp.error_collector import mcp_error_collector ENABLE_STREAMING = False # FIXME: this doesn't work @@ -76,6 +81,7 @@ COMMANDS = { '/new': 'Create a new conversation', '/settings': 'Display and modify current settings', '/resume': 'Resume the agent when paused', + '/mcp': 'Manage MCP server configuration and view errors', } print_lock = threading.Lock() @@ -162,6 +168,7 @@ def display_welcome_message(message: str = '') -> None: print_formatted_text( HTML("Let's start building!\n"), style=DEFAULT_STYLE ) + if message: print_formatted_text( HTML(f'{message} Type /help for help'), @@ -186,6 +193,48 @@ def display_initial_user_prompt(prompt: str) -> None: ) +def display_mcp_errors() -> None: + """Display collected MCP errors.""" + errors = mcp_error_collector.get_errors() + + if not errors: + print_formatted_text(HTML('✓ No MCP errors detected\n')) + return + + print_formatted_text( + HTML( + f'✗ {len(errors)} MCP error(s) detected during startup:\n' + ) + ) + + for i, error in enumerate(errors, 1): + # Format timestamp + timestamp = datetime.datetime.fromtimestamp(error.timestamp).strftime( + '%H:%M:%S' + ) + + # Create error display text + error_text = ( + f'[{timestamp}] {error.server_type.upper()} Server: {error.server_name}\n' + ) + error_text += f'Error: {error.error_message}\n' + if error.exception_details: + error_text += f'Details: {error.exception_details}' + + container = Frame( + TextArea( + text=error_text, + read_only=True, + style='ansired', + wrap_lines=True, + ), + title=f'MCP Error #{i}', + style='ansired', + ) + print_container(container) + print_formatted_text('') # Add spacing between errors + + # Prompt output display functions def display_thought_if_new(thought: str) -> None: """Display a thought only if it hasn't been displayed recently.""" @@ -215,6 +264,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None: if event.confirmation_state == ActionConfirmationStatus.CONFIRMED: initialize_streaming_output() + elif isinstance(event, MCPAction): + display_mcp_action(event) elif isinstance(event, Action): # For other actions, display thoughts normally if hasattr(event, 'thought') and event.thought: @@ -232,6 +283,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None: display_file_edit(event) elif isinstance(event, FileReadObservation): display_file_read(event) + elif isinstance(event, MCPObservation): + display_mcp_observation(event) elif isinstance(event, AgentStateChangedObservation): display_agent_state_change_message(event.agent_state) elif isinstance(event, ErrorObservation): @@ -337,6 +390,66 @@ def display_file_read(event: FileReadObservation) -> None: print_container(container) +def display_mcp_action(event: MCPAction) -> None: + """Display an MCP action in the CLI.""" + # Format the arguments for display + args_text = '' + if event.arguments: + try: + args_text = json.dumps(event.arguments, indent=2) + except (TypeError, ValueError): + args_text = str(event.arguments) + + # Create the display text + display_text = f'Tool: {event.name}' + if args_text: + display_text += f'\n\nArguments:\n{args_text}' + + container = Frame( + TextArea( + text=display_text, + read_only=True, + style='ansiblue', + wrap_lines=True, + ), + title='MCP Tool Call', + style='ansiblue', + ) + print_formatted_text('') + print_container(container) + + +def display_mcp_observation(event: MCPObservation) -> None: + """Display an MCP observation in the CLI.""" + # Format the content for display + content = event.content.strip() if event.content else 'No output' + + # Add tool name and arguments info if available + display_text = content + if event.name: + header = f'Tool: {event.name}' + if event.arguments: + try: + args_text = json.dumps(event.arguments, indent=2) + header += f'\nArguments: {args_text}' + except (TypeError, ValueError): + header += f'\nArguments: {event.arguments}' + display_text = f'{header}\n\nResult:\n{content}' + + container = Frame( + TextArea( + text=display_text, + read_only=True, + style=COLOR_GREY, + wrap_lines=True, + ), + title='MCP Tool Result', + style=f'fg:{COLOR_GREY}', + ) + print_formatted_text('') + print_container(container) + + def initialize_streaming_output(): """Initialize the streaming output TextArea.""" if not ENABLE_STREAMING: diff --git a/openhands/core/config/mcp_config.py b/openhands/core/config/mcp_config.py index b952b6fe00..3a8110dcc7 100644 --- a/openhands/core/config/mcp_config.py +++ b/openhands/core/config/mcp_config.py @@ -1,8 +1,17 @@ import os +import re +import shlex from typing import TYPE_CHECKING from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) if TYPE_CHECKING: from openhands.core.config.openhands_config import OpenHandsConfig @@ -11,6 +20,27 @@ from openhands.core.logger import openhands_logger as logger from openhands.utils.import_utils import get_impl +def _validate_mcp_url(url: str) -> str: + """Shared URL validation logic for MCP servers.""" + if not url.strip(): + raise ValueError('URL cannot be empty') + + url = url.strip() + try: + parsed = urlparse(url) + if not parsed.scheme: + raise ValueError('URL must include a scheme (http:// or https://)') + if not parsed.netloc: + raise ValueError('URL must include a valid domain/host') + if parsed.scheme not in ['http', 'https', 'ws', 'wss']: + raise ValueError('URL scheme must be http, https, ws, or wss') + return url + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f'Invalid URL format: {str(e)}') + + class MCPSSEServerConfig(BaseModel): """Configuration for a single MCP server. @@ -22,6 +52,12 @@ class MCPSSEServerConfig(BaseModel): url: str api_key: str | None = None + @field_validator('url') + @classmethod + def validate_url(cls, v: str) -> str: + """Validate URL format for MCP servers.""" + return _validate_mcp_url(v) + class MCPStdioServerConfig(BaseModel): """Configuration for a MCP server that uses stdio. @@ -38,6 +74,100 @@ class MCPStdioServerConfig(BaseModel): args: list[str] = Field(default_factory=list) env: dict[str, str] = Field(default_factory=dict) + @field_validator('name') + @classmethod + def validate_server_name(cls, v: str) -> str: + """Validate server name for stdio MCP servers.""" + if not v.strip(): + raise ValueError('Server name cannot be empty') + + v = v.strip() + + # Check for valid characters (alphanumeric, hyphens, underscores) + if not re.match(r'^[a-zA-Z0-9_-]+$', v): + raise ValueError( + 'Server name can only contain letters, numbers, hyphens, and underscores' + ) + + return v + + @field_validator('command') + @classmethod + def validate_command(cls, v: str) -> str: + """Validate command for stdio MCP servers.""" + if not v.strip(): + raise ValueError('Command cannot be empty') + + v = v.strip() + + # Check that command doesn't contain spaces (should be a single executable) + if ' ' in v: + raise ValueError( + 'Command should be a single executable without spaces (use arguments field for parameters)' + ) + + return v + + @field_validator('args', mode='before') + @classmethod + def parse_args(cls, v) -> list[str]: + """Parse arguments from string or return list as-is. + + Supports shell-like argument parsing using shlex.split(). + Examples: + - "-y mcp-remote https://example.com" + - '--config "path with spaces" --debug' + - "arg1 arg2 arg3" + """ + if isinstance(v, str): + if not v.strip(): + return [] + + v = v.strip() + + # Use shell-like parsing for natural argument handling + try: + return shlex.split(v) + except ValueError as e: + # If shlex parsing fails (e.g., unmatched quotes), provide clear error + raise ValueError( + f'Invalid argument format: {str(e)}. Use shell-like format, e.g., "arg1 arg2" or \'--config "value with spaces"\'' + ) + + return v or [] + + @field_validator('env', mode='before') + @classmethod + def parse_env(cls, v) -> dict[str, str]: + """Parse environment variables from string or return dict as-is.""" + if isinstance(v, str): + if not v.strip(): + return {} + + env = {} + for pair in v.split(','): + pair = pair.strip() + if not pair: + continue + + if '=' not in pair: + raise ValueError( + f"Environment variable '{pair}' must be in KEY=VALUE format" + ) + + key, value = pair.split('=', 1) + key = key.strip() + if not key: + raise ValueError('Environment variable key cannot be empty') + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key): + raise ValueError( + f"Invalid environment variable name '{key}'. Must start with letter or underscore, contain only alphanumeric characters and underscores" + ) + + env[key] = value + return env + return v or {} + def __eq__(self, other): """Override equality operator to compare server configurations. @@ -59,6 +189,12 @@ class MCPSHTTPServerConfig(BaseModel): url: str api_key: str | None = None + @field_validator('url') + @classmethod + def validate_url(cls, v: str) -> str: + """Validate URL format for MCP servers.""" + return _validate_mcp_url(v) + class MCPConfig(BaseModel): """Configuration for MCP (Message Control Protocol) settings. diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py index c526088809..1af26ce443 100644 --- a/openhands/core/config/utils.py +++ b/openhands/core/config/utils.py @@ -5,7 +5,7 @@ import platform import sys from ast import literal_eval from types import UnionType -from typing import MutableMapping, get_args, get_origin +from typing import MutableMapping, get_args, get_origin, get_type_hints from uuid import uuid4 import toml @@ -154,8 +154,22 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None core_config = toml_config['core'] # Process core section if present + cfg_type_hints = get_type_hints(cfg.__class__) for key, value in core_config.items(): if hasattr(cfg, key): + # Get expected type of the attribute + expected_type = cfg_type_hints.get(key, None) + + # Check if expected_type is a Union that includes SecretStr and value is str, e.g. search_api_key + if expected_type: + origin = get_origin(expected_type) + args = get_args(expected_type) + + if origin is UnionType and SecretStr in args and isinstance(value, str): + value = SecretStr(value) + elif expected_type is SecretStr and isinstance(value, str): + value = SecretStr(value) + setattr(cfg, key, value) else: logger.openhands_logger.warning( diff --git a/openhands/mcp/__init__.py b/openhands/mcp/__init__.py index e345266166..a7ba3ec3bd 100644 --- a/openhands/mcp/__init__.py +++ b/openhands/mcp/__init__.py @@ -1,4 +1,5 @@ from openhands.mcp.client import MCPClient +from openhands.mcp.error_collector import mcp_error_collector from openhands.mcp.tool import MCPClientTool from openhands.mcp.utils import ( add_mcp_tools_to_agent, @@ -16,4 +17,5 @@ __all__ = [ 'fetch_mcp_tools_from_config', 'call_tool_mcp', 'add_mcp_tools_to_agent', + 'mcp_error_collector', ] diff --git a/openhands/mcp/client.py b/openhands/mcp/client.py index 8b90fe98a1..9bf8fb7005 100644 --- a/openhands/mcp/client.py +++ b/openhands/mcp/client.py @@ -1,13 +1,22 @@ from typing import Optional from fastmcp import Client -from fastmcp.client.transports import SSETransport, StreamableHttpTransport +from fastmcp.client.transports import ( + SSETransport, + StdioTransport, + StreamableHttpTransport, +) from mcp import McpError from mcp.types import CallToolResult from pydantic import BaseModel, ConfigDict, Field -from openhands.core.config.mcp_config import MCPSHTTPServerConfig, MCPSSEServerConfig +from openhands.core.config.mcp_config import ( + MCPSHTTPServerConfig, + MCPSSEServerConfig, + MCPStdioServerConfig, +) from openhands.core.logger import openhands_logger as logger +from openhands.mcp.error_collector import mcp_error_collector from openhands.mcp.tool import MCPClientTool @@ -90,11 +99,51 @@ class MCPClient(BaseModel): await self._initialize_and_list_tools() except McpError as e: - logger.error(f'McpError connecting to {server_url}: {e}') + error_msg = f'McpError connecting to {server_url}: {e}' + logger.error(error_msg) + mcp_error_collector.add_error( + server_name=server_url, + server_type='shttp' + if isinstance(server, MCPSHTTPServerConfig) + else 'sse', + error_message=error_msg, + exception_details=str(e), + ) raise # Re-raise the error except Exception as e: - logger.error(f'Error connecting to {server_url}: {e}') + error_msg = f'Error connecting to {server_url}: {e}' + logger.error(error_msg) + mcp_error_collector.add_error( + server_name=server_url, + server_type='shttp' + if isinstance(server, MCPSHTTPServerConfig) + else 'sse', + error_message=error_msg, + exception_details=str(e), + ) + raise + + async def connect_stdio(self, server: MCPStdioServerConfig, timeout: float = 30.0): + """Connect to MCP server using stdio transport""" + try: + transport = StdioTransport( + command=server.command, args=server.args or [], env=server.env + ) + self.client = Client(transport, timeout=timeout) + await self._initialize_and_list_tools() + except Exception as e: + server_name = getattr( + server, 'name', f'{server.command} {" ".join(server.args or [])}' + ) + error_msg = f'Failed to connect to stdio server {server_name}: {e}' + logger.error(error_msg) + mcp_error_collector.add_error( + server_name=server_name, + server_type='stdio', + error_message=error_msg, + exception_details=str(e), + ) raise async def call_tool(self, tool_name: str, args: dict) -> CallToolResult: diff --git a/openhands/mcp/error_collector.py b/openhands/mcp/error_collector.py new file mode 100644 index 0000000000..69fabd8e24 --- /dev/null +++ b/openhands/mcp/error_collector.py @@ -0,0 +1,78 @@ +"""MCP Error Collector for capturing and storing MCP-related errors during startup.""" + +import threading +import time +from dataclasses import dataclass + + +@dataclass +class MCPError: + """Represents an MCP-related error.""" + + timestamp: float + server_name: str + server_type: str # 'stdio', 'sse', 'shttp' + error_message: str + exception_details: str | None = None + + +class MCPErrorCollector: + """Thread-safe collector for MCP errors during startup.""" + + def __init__(self): + self._errors: list[MCPError] = [] + self._lock = threading.Lock() + self._collection_enabled = True + + def add_error( + self, + server_name: str, + server_type: str, + error_message: str, + exception_details: str | None = None, + ) -> None: + """Add an MCP error to the collection.""" + if not self._collection_enabled: + return + + with self._lock: + error = MCPError( + timestamp=time.time(), + server_name=server_name, + server_type=server_type, + error_message=error_message, + exception_details=exception_details, + ) + self._errors.append(error) + + def get_errors(self) -> list[MCPError]: + """Get a copy of all collected errors.""" + with self._lock: + return self._errors.copy() + + def has_errors(self) -> bool: + """Check if there are any collected errors.""" + with self._lock: + return len(self._errors) > 0 + + def clear_errors(self) -> None: + """Clear all collected errors.""" + with self._lock: + self._errors.clear() + + def disable_collection(self) -> None: + """Disable error collection (useful after startup).""" + self._collection_enabled = False + + def enable_collection(self) -> None: + """Enable error collection.""" + self._collection_enabled = True + + def get_error_count(self) -> int: + """Get the number of collected errors.""" + with self._lock: + return len(self._errors) + + +# Global instance for collecting MCP errors +mcp_error_collector = MCPErrorCollector() diff --git a/openhands/mcp/utils.py b/openhands/mcp/utils.py index 06ff980b10..391ccfe8f1 100644 --- a/openhands/mcp/utils.py +++ b/openhands/mcp/utils.py @@ -11,14 +11,17 @@ from openhands.core.config.mcp_config import ( MCPConfig, MCPSHTTPServerConfig, MCPSSEServerConfig, + MCPStdioServerConfig, ) from openhands.core.logger import openhands_logger as logger from openhands.events.action.mcp import MCPAction from openhands.events.observation.mcp import MCPObservation from openhands.events.observation.observation import Observation from openhands.mcp.client import MCPClient +from openhands.mcp.error_collector import mcp_error_collector from openhands.memory.memory import Memory from openhands.runtime.base import Runtime +from openhands.runtime.impl.cli.cli_runtime import CLIRuntime def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]: @@ -45,7 +48,14 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di mcp_tools = tool.to_param() all_mcp_tools.append(mcp_tools) except Exception as e: - logger.error(f'Error in convert_mcp_clients_to_tools: {e}') + error_msg = f'Error in convert_mcp_clients_to_tools: {e}' + logger.error(error_msg) + mcp_error_collector.add_error( + server_name='general', + server_type='conversion', + error_message=error_msg, + exception_details=str(e), + ) return [] return all_mcp_tools @@ -54,6 +64,7 @@ async def create_mcp_clients( sse_servers: list[MCPSSEServerConfig], shttp_servers: list[MCPSHTTPServerConfig], conversation_id: str | None = None, + stdio_servers: list[MCPStdioServerConfig] | None = None, ) -> list[MCPClient]: import sys @@ -64,9 +75,13 @@ async def create_mcp_clients( ) return [] - servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig] = [ + if stdio_servers is None: + stdio_servers = [] + + servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig | MCPStdioServerConfig] = [ *sse_servers, *shttp_servers, + *stdio_servers, ] if not servers: @@ -75,6 +90,17 @@ async def create_mcp_clients( mcp_clients = [] for server in servers: + if isinstance(server, MCPStdioServerConfig): + logger.info(f'Initializing MCP agent for {server} with stdio connection...') + client = MCPClient() + try: + await client.connect_stdio(server) + mcp_clients.append(client) + except Exception as e: + # Error is already logged and collected in client.connect_stdio() + logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True) + continue + is_shttp = isinstance(server, MCPSHTTPServerConfig) connection_type = 'SHTTP' if is_shttp else 'SSE' logger.info( @@ -89,13 +115,14 @@ async def create_mcp_clients( mcp_clients.append(client) except Exception as e: + # Error is already logged and collected in client.connect_http() logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True) return mcp_clients async def fetch_mcp_tools_from_config( - mcp_config: MCPConfig, conversation_id: str | None = None + mcp_config: MCPConfig, conversation_id: str | None = None, use_stdio: bool = False ) -> list[dict]: """ Retrieves the list of MCP tools from the MCP clients. @@ -103,6 +130,7 @@ async def fetch_mcp_tools_from_config( Args: mcp_config: The MCP configuration conversation_id: Optional conversation ID to associate with the MCP clients + use_stdio: Whether to use stdio servers for MCP clients, set to True when running from a CLI runtime Returns: A list of tool dictionaries. Returns an empty list if no connections could be established. @@ -120,7 +148,10 @@ async def fetch_mcp_tools_from_config( logger.debug(f'Creating MCP clients with config: {mcp_config}') # Create clients - this will fetch tools but not maintain active connections mcp_clients = await create_mcp_clients( - mcp_config.sse_servers, mcp_config.shttp_servers, conversation_id + mcp_config.sse_servers, + mcp_config.shttp_servers, + conversation_id, + mcp_config.stdio_servers if use_stdio else [], ) if not mcp_clients: @@ -131,7 +162,14 @@ async def fetch_mcp_tools_from_config( mcp_tools = convert_mcp_clients_to_tools(mcp_clients) except Exception as e: - logger.error(f'Error fetching MCP tools: {str(e)}') + error_msg = f'Error fetching MCP tools: {str(e)}' + logger.error(error_msg) + mcp_error_collector.add_error( + server_name='general', + server_type='fetch', + error_message=error_msg, + exception_details=str(e), + ) return [] logger.debug(f'MCP tools: {mcp_tools}') @@ -200,7 +238,9 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse ) -async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memory'): +async def add_mcp_tools_to_agent( + agent: 'Agent', runtime: Runtime, memory: 'Memory' +) -> MCPConfig: """ Add MCP tools to an agent. """ @@ -231,13 +271,18 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo # Check if this stdio server is already in the config if stdio_server not in extra_stdio_servers: extra_stdio_servers.append(stdio_server) - logger.info(f'Added microagent stdio server: {stdio_server.name}') + logger.warning( + f'Added microagent stdio server: {stdio_server.name}' + ) # Add the runtime as another MCP server updated_mcp_config = runtime.get_mcp_config(extra_stdio_servers) # Fetch the MCP tools - mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config) + # Only use stdio if run from a CLI runtime + mcp_tools = await fetch_mcp_tools_from_config( + updated_mcp_config, use_stdio=isinstance(runtime, CLIRuntime) + ) logger.info( f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}' @@ -245,3 +290,5 @@ async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memo # Set the MCP tools on the agent agent.set_mcp_tools(mcp_tools) + + return updated_mcp_config diff --git a/openhands/runtime/impl/cli/cli_runtime.py b/openhands/runtime/impl/cli/cli_runtime.py index c8f661b84b..44ac62ac05 100644 --- a/openhands/runtime/impl/cli/cli_runtime.py +++ b/openhands/runtime/impl/cli/cli_runtime.py @@ -689,8 +689,69 @@ class CLIRuntime(Runtime): ) async def call_tool_mcp(self, action: MCPAction) -> Observation: - """Not implemented for CLI runtime.""" - return ErrorObservation('MCP functionality is not implemented in CLIRuntime') + """Execute an MCP tool action in CLI runtime. + + Args: + action: The MCP action to execute + + Returns: + Observation: The result of the MCP tool execution + """ + # Check if we're on Windows - MCP is disabled on Windows + if sys.platform == 'win32': + self.log('info', 'MCP functionality is disabled on Windows') + return ErrorObservation('MCP functionality is not available on Windows') + + # Import here to avoid circular imports + from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler + from openhands.mcp.utils import create_mcp_clients + + try: + # Get the MCP config for this runtime + mcp_config = self.get_mcp_config() + + if ( + not mcp_config.sse_servers + and not mcp_config.shttp_servers + and not mcp_config.stdio_servers + ): + self.log('warning', 'No MCP servers configured') + return ErrorObservation('No MCP servers configured') + + self.log( + 'debug', + f'Creating MCP clients for action {action.name} with servers: ' + f'SSE={len(mcp_config.sse_servers)}, SHTTP={len(mcp_config.shttp_servers)}, ' + f'stdio={len(mcp_config.stdio_servers)}', + ) + + # Create clients for this specific operation + mcp_clients = await create_mcp_clients( + mcp_config.sse_servers, + mcp_config.shttp_servers, + self.sid, + mcp_config.stdio_servers, + ) + + if not mcp_clients: + self.log('warning', 'No MCP clients could be created') + return ErrorObservation( + 'No MCP clients could be created - check server configurations' + ) + + # Call the tool and return the result + self.log( + 'debug', + f'Executing MCP tool: {action.name} with arguments: {action.arguments}', + ) + result = await call_tool_mcp_handler(mcp_clients, action) + self.log('debug', f'MCP tool {action.name} executed successfully') + return result + + except Exception as e: + error_msg = f'Error executing MCP tool {action.name}: {str(e)}' + self.log('error', error_msg) + return ErrorObservation(error_msg) @property def workspace_root(self) -> Path: @@ -869,8 +930,40 @@ class CLIRuntime(Runtime): def get_mcp_config( self, extra_stdio_servers: list[MCPStdioServerConfig] | None = None ) -> MCPConfig: - # TODO: Load MCP config from a local file - return MCPConfig() + """Get MCP configuration for CLI runtime. + + Args: + extra_stdio_servers: Additional stdio servers to include in the config + + Returns: + MCPConfig: The MCP configuration with stdio servers and any configured SSE/SHTTP servers + """ + # Check if we're on Windows - MCP is disabled on Windows + if sys.platform == 'win32': + self.log('debug', 'MCP is disabled on Windows, returning empty config') + return MCPConfig(sse_servers=[], stdio_servers=[], shttp_servers=[]) + + # Note: we update the self.config.mcp directly for CLI runtime, which is different from other runtimes. + mcp_config = self.config.mcp + + # Add any extra stdio servers + if extra_stdio_servers: + current_stdio_servers = list(mcp_config.stdio_servers) + for extra_server in extra_stdio_servers: + # Check if this stdio server is already in the config + if extra_server not in current_stdio_servers: + current_stdio_servers.append(extra_server) + self.log('info', f'Added extra stdio server: {extra_server.name}') + mcp_config.stdio_servers = current_stdio_servers + + self.log( + 'debug', + f'CLI MCP config: {len(mcp_config.sse_servers)} SSE servers, ' + f'{len(mcp_config.stdio_servers)} stdio servers, ' + f'{len(mcp_config.shttp_servers)} SHTTP servers', + ) + + return mcp_config def subscribe_to_shell_stream( self, callback: Callable[[str], None] | None = None diff --git a/tests/unit/test_cli_commands.py b/tests/unit/test_cli_commands.py index d5b7f0e0cf..db9fcaf445 100644 --- a/tests/unit/test_cli_commands.py +++ b/tests/unit/test_cli_commands.py @@ -3,10 +3,12 @@ from unittest.mock import MagicMock, patch import pytest from openhands.cli.commands import ( + display_mcp_servers, handle_commands, handle_exit_command, handle_help_command, handle_init_command, + handle_mcp_command, handle_new_command, handle_resume_command, handle_settings_command, @@ -143,6 +145,18 @@ class TestHandleCommands: assert reload_microagents is False assert new_session is False + @pytest.mark.asyncio + @patch('openhands.cli.commands.handle_mcp_command') + async def test_handle_mcp_command(self, mock_handle_mcp, mock_dependencies): + close_repl, reload_microagents, new_session, _ = await handle_commands( + '/mcp', **mock_dependencies + ) + + mock_handle_mcp.assert_called_once_with(mock_dependencies['config']) + assert close_repl is False + assert reload_microagents is False + assert new_session is False + @pytest.mark.asyncio async def test_handle_unknown_command(self, mock_dependencies): user_message = 'Hello, this is not a command' @@ -219,6 +233,78 @@ class TestHandleHelpCommand: mock_display_help.assert_called_once() +class TestDisplayMcpServers: + @patch('openhands.cli.commands.print_formatted_text') + def test_display_mcp_servers_no_servers(self, mock_print): + from openhands.core.config.mcp_config import MCPConfig + + config = MagicMock(spec=OpenHandsConfig) + config.mcp = MCPConfig() # Empty config with no servers + + display_mcp_servers(config) + + mock_print.assert_called_once() + call_args = mock_print.call_args[0][0] + assert 'No custom MCP servers configured' in call_args + assert ( + 'https://docs.all-hands.dev/usage/how-to/cli-mode#using-mcp-servers' + in call_args + ) + + @patch('openhands.cli.commands.print_formatted_text') + def test_display_mcp_servers_with_servers(self, mock_print): + from openhands.core.config.mcp_config import ( + MCPConfig, + MCPSHTTPServerConfig, + MCPSSEServerConfig, + MCPStdioServerConfig, + ) + + config = MagicMock(spec=OpenHandsConfig) + config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='https://example.com/sse')], + stdio_servers=[MCPStdioServerConfig(name='tavily', command='npx')], + shttp_servers=[MCPSHTTPServerConfig(url='http://localhost:3000/mcp')], + ) + + display_mcp_servers(config) + + # Should be called multiple times for different sections + assert mock_print.call_count >= 4 + + # Check that the summary is printed + first_call = mock_print.call_args_list[0][0][0] + assert 'Configured MCP servers:' in first_call + assert 'SSE servers: 1' in first_call + assert 'Stdio servers: 1' in first_call + assert 'SHTTP servers: 1' in first_call + assert 'Total: 3' in first_call + + +class TestHandleMcpCommand: + @pytest.mark.asyncio + @patch('openhands.cli.commands.cli_confirm') + @patch('openhands.cli.commands.display_mcp_servers') + async def test_handle_mcp_command_list_action(self, mock_display, mock_cli_confirm): + config = MagicMock(spec=OpenHandsConfig) + mock_cli_confirm.return_value = 0 # List action + + await handle_mcp_command(config) + + mock_cli_confirm.assert_called_once_with( + config, + 'MCP Server Configuration', + [ + 'List configured servers', + 'Add new server', + 'Remove server', + 'View errors', + 'Go back', + ], + ) + mock_display.assert_called_once_with(config) + + class TestHandleStatusCommand: @patch('openhands.cli.commands.display_status') def test_status_command(self, mock_display_status): @@ -496,3 +582,16 @@ class TestHandleResumeCommand: # Check the return values assert close_repl is True assert new_session_requested is False + + +class TestMCPErrorHandling: + """Test MCP error handling in commands.""" + + @patch('openhands.cli.commands.display_mcp_errors') + def test_handle_mcp_errors_command(self, mock_display_errors): + """Test handling MCP errors command.""" + from openhands.cli.commands import handle_mcp_errors_command + + handle_mcp_errors_command() + + mock_display_errors.assert_called_once() diff --git a/tests/unit/test_cli_config_management.py b/tests/unit/test_cli_config_management.py new file mode 100644 index 0000000000..c43f760659 --- /dev/null +++ b/tests/unit/test_cli_config_management.py @@ -0,0 +1,106 @@ +"""Tests for CLI server management functionality.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openhands.cli.commands import ( + display_mcp_servers, + remove_mcp_server, +) +from openhands.core.config import OpenHandsConfig +from openhands.core.config.mcp_config import ( + MCPConfig, + MCPSSEServerConfig, + MCPStdioServerConfig, +) + + +class TestMCPServerManagement: + """Test MCP server management functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.config = MagicMock(spec=OpenHandsConfig) + self.config.cli = MagicMock() + self.config.cli.vi_mode = False + + @patch('openhands.cli.commands.print_formatted_text') + def test_display_mcp_servers_no_servers(self, mock_print): + """Test displaying MCP servers when none are configured.""" + self.config.mcp = MCPConfig() # Empty config + + display_mcp_servers(self.config) + + mock_print.assert_called_once() + call_args = mock_print.call_args[0][0] + assert 'No custom MCP servers configured' in call_args + + @patch('openhands.cli.commands.print_formatted_text') + def test_display_mcp_servers_with_servers(self, mock_print): + """Test displaying MCP servers when some are configured.""" + self.config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')], + stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')], + ) + + display_mcp_servers(self.config) + + # Should be called multiple times for different sections + assert mock_print.call_count >= 2 + + # Check that the summary is printed + first_call = mock_print.call_args_list[0][0][0] + assert 'Configured MCP servers:' in first_call + assert 'SSE servers: 1' in first_call + assert 'Stdio servers: 1' in first_call + + @pytest.mark.asyncio + @patch('openhands.cli.commands.cli_confirm') + @patch('openhands.cli.commands.print_formatted_text') + async def test_remove_mcp_server_no_servers(self, mock_print, mock_cli_confirm): + """Test removing MCP server when none are configured.""" + self.config.mcp = MCPConfig() # Empty config + + await remove_mcp_server(self.config) + + mock_print.assert_called_once_with('No MCP servers configured to remove.') + mock_cli_confirm.assert_not_called() + + @pytest.mark.asyncio + @patch('openhands.cli.commands.cli_confirm') + @patch('openhands.cli.commands.load_config_file') + @patch('openhands.cli.commands.save_config_file') + @patch('openhands.cli.commands.print_formatted_text') + async def test_remove_mcp_server_success( + self, mock_print, mock_save, mock_load, mock_cli_confirm + ): + """Test successfully removing an MCP server.""" + # Set up config with servers + self.config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')], + stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')], + ) + + # Mock user selections + mock_cli_confirm.side_effect = [0, 0] # Select first server, confirm removal + + # Mock config file operations + mock_load.return_value = { + 'mcp': { + 'sse_servers': [{'url': 'http://test.com'}], + 'stdio_servers': [{'name': 'test-stdio', 'command': 'python'}], + } + } + + await remove_mcp_server(self.config) + + # Should have been called twice (select server, confirm removal) + assert mock_cli_confirm.call_count == 2 + mock_save.assert_called_once() + + # Check that success message was printed + success_calls = [ + call for call in mock_print.call_args_list if 'removed' in str(call[0][0]) + ] + assert len(success_calls) >= 1 diff --git a/tests/unit/test_cli_runtime_mcp.py b/tests/unit/test_cli_runtime_mcp.py new file mode 100644 index 0000000000..1ec246d7eb --- /dev/null +++ b/tests/unit/test_cli_runtime_mcp.py @@ -0,0 +1,156 @@ +"""Tests for CLI Runtime MCP functionality.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openhands.core.config import OpenHandsConfig +from openhands.core.config.mcp_config import ( + MCPConfig, + MCPSSEServerConfig, + MCPStdioServerConfig, +) +from openhands.events.action.mcp import MCPAction +from openhands.events.observation import ErrorObservation +from openhands.events.observation.mcp import MCPObservation +from openhands.runtime.impl.cli.cli_runtime import CLIRuntime + + +class TestCLIRuntimeMCP: + """Test MCP functionality in CLI Runtime.""" + + def setup_method(self): + """Set up test fixtures.""" + self.config = OpenHandsConfig() + self.event_stream = MagicMock() + self.runtime = CLIRuntime( + config=self.config, event_stream=self.event_stream, sid='test-session' + ) + + @pytest.mark.asyncio + async def test_call_tool_mcp_no_servers_configured(self): + """Test MCP call with no servers configured.""" + # Set up empty MCP config + self.runtime.config.mcp = MCPConfig() + + action = MCPAction(name='test_tool', arguments={'arg1': 'value1'}) + + with patch('sys.platform', 'linux'): + result = await self.runtime.call_tool_mcp(action) + + assert isinstance(result, ErrorObservation) + assert 'No MCP servers configured' in result.content + + @pytest.mark.asyncio + @patch('openhands.mcp.utils.create_mcp_clients') + async def test_call_tool_mcp_no_clients_created(self, mock_create_clients): + """Test MCP call when no clients can be created.""" + # Set up MCP config with servers + self.runtime.config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')] + ) + + # Mock create_mcp_clients to return empty list + mock_create_clients.return_value = [] + + action = MCPAction(name='test_tool', arguments={'arg1': 'value1'}) + + with patch('sys.platform', 'linux'): + result = await self.runtime.call_tool_mcp(action) + + assert isinstance(result, ErrorObservation) + assert 'No MCP clients could be created' in result.content + mock_create_clients.assert_called_once() + + @pytest.mark.asyncio + @patch('openhands.mcp.utils.create_mcp_clients') + @patch('openhands.mcp.utils.call_tool_mcp') + async def test_call_tool_mcp_success(self, mock_call_tool, mock_create_clients): + """Test successful MCP tool call.""" + # Set up MCP config with servers + self.runtime.config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')], + stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')], + ) + + # Mock successful client creation + mock_client = MagicMock() + mock_create_clients.return_value = [mock_client] + + # Mock successful tool call + expected_observation = MCPObservation( + content='{"result": "success"}', + name='test_tool', + arguments={'arg1': 'value1'}, + ) + mock_call_tool.return_value = expected_observation + + action = MCPAction(name='test_tool', arguments={'arg1': 'value1'}) + + with patch('sys.platform', 'linux'): + result = await self.runtime.call_tool_mcp(action) + + assert result == expected_observation + mock_create_clients.assert_called_once_with( + self.runtime.config.mcp.sse_servers, + self.runtime.config.mcp.shttp_servers, + self.runtime.sid, + self.runtime.config.mcp.stdio_servers, + ) + mock_call_tool.assert_called_once_with([mock_client], action) + + @pytest.mark.asyncio + @patch('openhands.mcp.utils.create_mcp_clients') + async def test_call_tool_mcp_exception_handling(self, mock_create_clients): + """Test exception handling in MCP tool call.""" + # Set up MCP config with servers + self.runtime.config.mcp = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')] + ) + + # Mock create_mcp_clients to raise an exception + mock_create_clients.side_effect = Exception('Connection error') + + action = MCPAction(name='test_tool', arguments={'arg1': 'value1'}) + + with patch('sys.platform', 'linux'): + result = await self.runtime.call_tool_mcp(action) + + assert isinstance(result, ErrorObservation) + assert 'Error executing MCP tool test_tool' in result.content + assert 'Connection error' in result.content + + def test_get_mcp_config_basic(self): + """Test basic MCP config retrieval.""" + # Set up MCP config + expected_config = MCPConfig( + sse_servers=[MCPSSEServerConfig(url='http://test.com')], + stdio_servers=[MCPStdioServerConfig(name='test-stdio', command='python')], + ) + self.runtime.config.mcp = expected_config + + with patch('sys.platform', 'linux'): + result = self.runtime.get_mcp_config() + + assert result == expected_config + + def test_get_mcp_config_with_extra_stdio_servers(self): + """Test MCP config with extra stdio servers.""" + # Set up initial MCP config + initial_stdio_server = MCPStdioServerConfig(name='initial', command='python') + self.runtime.config.mcp = MCPConfig(stdio_servers=[initial_stdio_server]) + + # Add extra stdio servers + extra_servers = [ + MCPStdioServerConfig(name='extra1', command='node'), + MCPStdioServerConfig(name='extra2', command='java'), + ] + + with patch('sys.platform', 'linux'): + result = self.runtime.get_mcp_config(extra_stdio_servers=extra_servers) + + # Should have all three servers + assert len(result.stdio_servers) == 3 + assert initial_stdio_server in result.stdio_servers + assert extra_servers[0] in result.stdio_servers + assert extra_servers[1] in result.stdio_servers diff --git a/tests/unit/test_cli_tui.py b/tests/unit/test_cli_tui.py index 53f38b15f7..9349a6f41b 100644 --- a/tests/unit/test_cli_tui.py +++ b/tests/unit/test_cli_tui.py @@ -9,6 +9,9 @@ from openhands.cli.tui import ( display_banner, display_command, display_event, + display_mcp_action, + display_mcp_errors, + display_mcp_observation, display_message, display_runtime_initialization_message, display_shutdown_message, @@ -24,14 +27,17 @@ from openhands.events.action import ( Action, ActionConfirmationStatus, CmdRunAction, + MCPAction, MessageAction, ) from openhands.events.observation import ( CmdOutputObservation, FileEditObservation, FileReadObservation, + MCPObservation, ) from openhands.llm.metrics import Metrics +from openhands.mcp.error_collector import MCPError class TestDisplayFunctions: @@ -72,6 +78,35 @@ class TestDisplayFunctions: args, kwargs = mock_print.call_args_list[0] assert "Let's start building" in str(args[0]) + @patch('openhands.cli.tui.print_formatted_text') + def test_display_welcome_message_with_message(self, mock_print): + message = 'Test message' + display_welcome_message(message) + assert mock_print.call_count == 2 + # Check the first call contains the welcome message + args, kwargs = mock_print.call_args_list[0] + message_text = str(args[0]) + assert "Let's start building" in message_text + # Check the second call contains the custom message + args, kwargs = mock_print.call_args_list[1] + message_text = str(args[0]) + assert 'Test message' in message_text + assert 'Type /help for help' in message_text + + @patch('openhands.cli.tui.print_formatted_text') + def test_display_welcome_message_without_message(self, mock_print): + display_welcome_message() + assert mock_print.call_count == 2 + # Check the first call contains the welcome message + args, kwargs = mock_print.call_args_list[0] + message_text = str(args[0]) + assert "Let's start building" in message_text + # Check the second call contains the default message + args, kwargs = mock_print.call_args_list[1] + message_text = str(args[0]) + assert 'What do you want to build?' in message_text + assert 'Type /help for help' in message_text + @patch('openhands.cli.tui.display_message') def test_display_event_message_action(self, mock_display_message): config = MagicMock(spec=OpenHandsConfig) @@ -147,6 +182,71 @@ class TestDisplayFunctions: mock_display_message.assert_called_once_with('Thinking about this...') + @patch('openhands.cli.tui.display_mcp_action') + def test_display_event_mcp_action(self, mock_display_mcp_action): + config = MagicMock(spec=OpenHandsConfig) + mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'}) + + display_event(mcp_action, config) + + mock_display_mcp_action.assert_called_once_with(mcp_action) + + @patch('openhands.cli.tui.display_mcp_observation') + def test_display_event_mcp_observation(self, mock_display_mcp_observation): + config = MagicMock(spec=OpenHandsConfig) + mcp_observation = MCPObservation( + content='Tool result', name='test_tool', arguments={'param': 'value'} + ) + + display_event(mcp_observation, config) + + mock_display_mcp_observation.assert_called_once_with(mcp_observation) + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_action(self, mock_print_container): + mcp_action = MCPAction(name='test_tool', arguments={'param': 'value'}) + + display_mcp_action(mcp_action) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'test_tool' in container.body.text + assert 'param' in container.body.text + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_action_no_args(self, mock_print_container): + mcp_action = MCPAction(name='test_tool') + + display_mcp_action(mcp_action) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'test_tool' in container.body.text + assert 'Arguments' not in container.body.text + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_observation(self, mock_print_container): + mcp_observation = MCPObservation( + content='Tool result', name='test_tool', arguments={'param': 'value'} + ) + + display_mcp_observation(mcp_observation) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'test_tool' in container.body.text + assert 'Tool result' in container.body.text + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_observation_no_content(self, mock_print_container): + mcp_observation = MCPObservation(content='', name='test_tool') + + display_mcp_observation(mcp_observation) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'No output' in container.body.text + @patch('openhands.cli.tui.print_formatted_text') def test_display_message(self, mock_print): message = 'Test message' @@ -307,14 +407,86 @@ class TestReadConfirmationInput: result = await read_confirmation_input(config=cfg) assert result == 'always' - @pytest.mark.asyncio + +"""Tests for CLI TUI MCP functionality.""" + + +class TestMCPTUIDisplay: + """Test MCP TUI display functions.""" + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_action_with_arguments(self, mock_print_container): + """Test displaying MCP action with arguments.""" + mcp_action = MCPAction( + name='test_tool', arguments={'param1': 'value1', 'param2': 42} + ) + + display_mcp_action(mcp_action) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'test_tool' in container.body.text + assert 'param1' in container.body.text + assert 'value1' in container.body.text + + @patch('openhands.cli.tui.print_container') + def test_display_mcp_observation_with_content(self, mock_print_container): + """Test displaying MCP observation with content.""" + mcp_observation = MCPObservation( + content='Tool execution successful', + name='test_tool', + arguments={'param': 'value'}, + ) + + display_mcp_observation(mcp_observation) + + mock_print_container.assert_called_once() + container = mock_print_container.call_args[0][0] + assert 'test_tool' in container.body.text + assert 'Tool execution successful' in container.body.text + @patch('openhands.cli.tui.print_formatted_text') - @patch('openhands.cli.tui.cli_confirm') - async def test_read_confirmation_input_edit(self, mock_confirm, mock_print): - mock_confirm.return_value = 3 # user picked third menu item + @patch('openhands.cli.tui.mcp_error_collector') + def test_display_mcp_errors_no_errors(self, mock_collector, mock_print): + """Test displaying MCP errors when none exist.""" + mock_collector.get_errors.return_value = [] - cfg = MagicMock() # <- no spec for simplicity - cfg.cli = MagicMock(vi_mode=False) + display_mcp_errors() - result = await read_confirmation_input(config=cfg) - assert result == 'edit' + mock_print.assert_called_once() + call_args = mock_print.call_args[0][0] + assert 'No MCP errors detected' in str(call_args) + + @patch('openhands.cli.tui.print_container') + @patch('openhands.cli.tui.print_formatted_text') + @patch('openhands.cli.tui.mcp_error_collector') + def test_display_mcp_errors_with_errors( + self, mock_collector, mock_print, mock_print_container + ): + """Test displaying MCP errors when some exist.""" + # Create mock errors + error1 = MCPError( + timestamp=1234567890.0, + server_name='test-server-1', + server_type='stdio', + error_message='Connection failed', + exception_details='Socket timeout', + ) + error2 = MCPError( + timestamp=1234567891.0, + server_name='test-server-2', + server_type='sse', + error_message='Server unreachable', + ) + + mock_collector.get_errors.return_value = [error1, error2] + + display_mcp_errors() + + # Should print error count header + assert mock_print.call_count >= 1 + header_call = mock_print.call_args_list[0][0][0] + assert '2 MCP error(s) detected' in str(header_call) + + # Should print containers for each error + assert mock_print_container.call_count == 2 diff --git a/tests/unit/test_mcp_config.py b/tests/unit/test_mcp_config.py index 25ae65e333..f0497e42fb 100644 --- a/tests/unit/test_mcp_config.py +++ b/tests/unit/test_mcp_config.py @@ -27,10 +27,9 @@ def test_empty_sse_config(): def test_invalid_sse_url(): """Test SSE configuration with invalid URL format.""" - config = MCPConfig(sse_servers=[MCPSSEServerConfig(url='not_a_url')]) - with pytest.raises(ValueError) as exc_info: - config.validate_servers() - assert 'Invalid URL' in str(exc_info.value) + with pytest.raises(ValidationError) as exc_info: + MCPSSEServerConfig(url='not_a_url') + assert 'URL must include a scheme' in str(exc_info.value) def test_duplicate_sse_urls(): @@ -64,7 +63,7 @@ def test_from_toml_section_invalid_sse(): } with pytest.raises(ValueError) as exc_info: MCPConfig.from_toml_section(data) - assert 'Invalid URL' in str(exc_info.value) + assert 'URL must include a scheme' in str(exc_info.value) def test_complex_urls(): @@ -246,3 +245,28 @@ def test_stdio_server_equality_with_different_env_order(): # Should not be equal assert server1 != server5 + + +def test_mcp_stdio_server_args_parsing_basic(): + """Test MCPStdioServerConfig args parsing with basic shell-like format.""" + # Test basic space-separated parsing + config = MCPStdioServerConfig( + name='test-server', command='python', args='arg1 arg2 arg3' + ) + assert config.args == ['arg1', 'arg2', 'arg3'] + + # Test single argument + config = MCPStdioServerConfig( + name='test-server', command='python', args='single-arg' + ) + assert config.args == ['single-arg'] + + +def test_mcp_stdio_server_args_parsing_invalid_quotes(): + """Test MCPStdioServerConfig args parsing with invalid quotes.""" + # Test unmatched quotes + with pytest.raises(ValidationError) as exc_info: + MCPStdioServerConfig( + name='test-server', command='python', args='--config "unmatched quote' + ) + assert 'Invalid argument format' in str(exc_info.value) diff --git a/tests/unit/test_mcp_error_collector.py b/tests/unit/test_mcp_error_collector.py new file mode 100644 index 0000000000..f3291bd730 --- /dev/null +++ b/tests/unit/test_mcp_error_collector.py @@ -0,0 +1,159 @@ +"""Tests for MCP Error Collector functionality.""" + +import time + +from openhands.mcp.error_collector import ( + MCPError, + MCPErrorCollector, + mcp_error_collector, +) + + +class TestMCPError: + """Test MCPError dataclass.""" + + def test_mcp_error_creation(self): + """Test creating an MCP error.""" + timestamp = time.time() + error = MCPError( + timestamp=timestamp, + server_name='test-server', + server_type='stdio', + error_message='Connection failed', + exception_details='Socket timeout', + ) + + assert error.timestamp == timestamp + assert error.server_name == 'test-server' + assert error.server_type == 'stdio' + assert error.error_message == 'Connection failed' + assert error.exception_details == 'Socket timeout' + + def test_mcp_error_creation_without_exception_details(self): + """Test creating an MCP error without exception details.""" + timestamp = time.time() + error = MCPError( + timestamp=timestamp, + server_name='test-server', + server_type='sse', + error_message='Server unreachable', + ) + + assert error.timestamp == timestamp + assert error.server_name == 'test-server' + assert error.server_type == 'sse' + assert error.error_message == 'Server unreachable' + assert error.exception_details is None + + +class TestMCPErrorCollector: + """Test MCPErrorCollector functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.collector = MCPErrorCollector() + + def test_initialization(self): + """Test collector initialization.""" + assert self.collector._errors == [] + assert self.collector._collection_enabled is True + + def test_add_error(self): + """Test adding an error to the collector.""" + self.collector.add_error( + server_name='test-server', + server_type='stdio', + error_message='Connection failed', + exception_details='Socket timeout', + ) + + errors = self.collector.get_errors() + assert len(errors) == 1 + assert errors[0].server_name == 'test-server' + assert errors[0].server_type == 'stdio' + assert errors[0].error_message == 'Connection failed' + assert errors[0].exception_details == 'Socket timeout' + assert errors[0].timestamp > 0 + + def test_add_multiple_errors(self): + """Test adding multiple errors.""" + self.collector.add_error('server1', 'stdio', 'Error 1') + self.collector.add_error('server2', 'sse', 'Error 2') + self.collector.add_error('server3', 'shttp', 'Error 3') + + errors = self.collector.get_errors() + assert len(errors) == 3 + assert errors[0].server_name == 'server1' + assert errors[1].server_name == 'server2' + assert errors[2].server_name == 'server3' + + def test_has_errors(self): + """Test has_errors method.""" + assert not self.collector.has_errors() + + self.collector.add_error('server1', 'stdio', 'Error 1') + assert self.collector.has_errors() + + self.collector.clear_errors() + assert not self.collector.has_errors() + + def test_clear_errors(self): + """Test clearing errors.""" + self.collector.add_error('server1', 'stdio', 'Error 1') + self.collector.add_error('server2', 'sse', 'Error 2') + + assert len(self.collector.get_errors()) == 2 + + self.collector.clear_errors() + assert len(self.collector.get_errors()) == 0 + assert not self.collector.has_errors() + + def test_enable_disable_collection(self): + """Test enabling and disabling error collection.""" + self.collector.add_error('server1', 'stdio', 'Error 1') + assert len(self.collector.get_errors()) == 1 + + # Disable collection + self.collector.disable_collection() + + # Adding error should be ignored + self.collector.add_error('server2', 'sse', 'Error 2') + assert len(self.collector.get_errors()) == 1 # Still only 1 error + + # Re-enable collection + self.collector.enable_collection() + + # Adding error should work again + self.collector.add_error('server3', 'shttp', 'Error 3') + assert len(self.collector.get_errors()) == 2 + + +class TestGlobalMCPErrorCollector: + """Test the global MCP error collector instance.""" + + def setup_method(self): + """Clear global collector before each test.""" + mcp_error_collector.clear_errors() + mcp_error_collector.enable_collection() + + def teardown_method(self): + """Clean up after each test.""" + mcp_error_collector.clear_errors() + mcp_error_collector.enable_collection() + + def test_global_collector_exists(self): + """Test that global collector instance exists.""" + assert mcp_error_collector is not None + assert isinstance(mcp_error_collector, MCPErrorCollector) + + def test_global_collector_functionality(self): + """Test basic functionality of global collector.""" + assert not mcp_error_collector.has_errors() + + mcp_error_collector.add_error('global-server', 'stdio', 'Global error') + assert mcp_error_collector.has_errors() + assert mcp_error_collector.get_error_count() == 1 + + errors = mcp_error_collector.get_errors() + assert len(errors) == 1 + assert errors[0].server_name == 'global-server' diff --git a/tests/unit/test_mcp_utils.py b/tests/unit/test_mcp_utils.py index 240bfb27f5..64b346e085 100644 --- a/tests/unit/test_mcp_utils.py +++ b/tests/unit/test_mcp_utils.py @@ -5,7 +5,7 @@ import pytest # Import the module, not the functions directly to avoid circular imports import openhands.mcp.utils -from openhands.core.config.mcp_config import MCPSSEServerConfig +from openhands.core.config.mcp_config import MCPSSEServerConfig, MCPStdioServerConfig from openhands.events.action.mcp import MCPAction from openhands.events.observation.mcp import MCPObservation @@ -161,3 +161,136 @@ async def test_call_tool_mcp_success(): assert isinstance(observation, MCPObservation) assert json.loads(observation.content) == {'result': 'success'} mock_client.call_tool.assert_called_once_with('test_tool', {'arg1': 'value1'}) + + +@pytest.mark.asyncio +@patch('openhands.mcp.utils.MCPClient') +async def test_create_mcp_clients_stdio_success(mock_mcp_client): + """Test successful creation of MCP clients with stdio servers.""" + # Setup mock + mock_client_instance = AsyncMock() + mock_mcp_client.return_value = mock_client_instance + mock_client_instance.connect_stdio = AsyncMock() + + # Test with stdio servers + stdio_server_configs = [ + MCPStdioServerConfig( + name='test-server-1', + command='python', + args=['-m', 'server1'], + env={'DEBUG': 'true'}, + ), + MCPStdioServerConfig( + name='test-server-2', + command='/usr/bin/node', + args=['server2.js'], + env={'NODE_ENV': 'development'}, + ), + ] + + clients = await openhands.mcp.utils.create_mcp_clients( + [], [], stdio_servers=stdio_server_configs + ) + + # Verify + assert len(clients) == 2 + assert mock_mcp_client.call_count == 2 + + # Check that connect_stdio was called with correct parameters + mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[0]) + mock_client_instance.connect_stdio.assert_any_call(stdio_server_configs[1]) + + +@pytest.mark.asyncio +@patch('openhands.mcp.utils.MCPClient') +async def test_create_mcp_clients_stdio_connection_failure(mock_mcp_client): + """Test handling of stdio connection failures when creating MCP clients.""" + # Setup mock + mock_client_instance = AsyncMock() + mock_mcp_client.return_value = mock_client_instance + + # First connection succeeds, second fails + mock_client_instance.connect_stdio.side_effect = [ + None, # Success + Exception('Stdio connection failed'), # Failure + ] + + stdio_server_configs = [ + MCPStdioServerConfig(name='server1', command='python'), + MCPStdioServerConfig(name='server2', command='invalid_command'), + ] + + clients = await openhands.mcp.utils.create_mcp_clients( + [], [], stdio_servers=stdio_server_configs + ) + + # Verify only one client was successfully created + assert len(clients) == 1 + + +@pytest.mark.asyncio +@patch('openhands.mcp.utils.create_mcp_clients') +async def test_fetch_mcp_tools_from_config_with_stdio(mock_create_clients): + """Test fetching MCP tools with stdio servers enabled.""" + from openhands.core.config.mcp_config import MCPConfig + + # Setup mock clients + mock_client = MagicMock() + mock_tool = MagicMock() + mock_tool.to_param.return_value = {'function': {'name': 'stdio_tool'}} + mock_client.tools = [mock_tool] + mock_create_clients.return_value = [mock_client] + + # Create config with stdio servers + mcp_config = MCPConfig( + stdio_servers=[MCPStdioServerConfig(name='test-server', command='python')] + ) + + # Test with use_stdio=True + tools = await openhands.mcp.utils.fetch_mcp_tools_from_config( + mcp_config, conversation_id='test-conv', use_stdio=True + ) + + # Verify + assert len(tools) == 1 + assert tools[0] == {'function': {'name': 'stdio_tool'}} + + # Verify create_mcp_clients was called with stdio servers + mock_create_clients.assert_called_once_with( + [], [], 'test-conv', mcp_config.stdio_servers + ) + + +@pytest.mark.asyncio +async def test_call_tool_mcp_stdio_client(): + """Test calling MCP tool on a stdio client.""" + # Create mock stdio client with the requested tool + mock_client = MagicMock() + mock_tool = MagicMock() + mock_tool.name = 'stdio_test_tool' + mock_client.tools = [mock_tool] + + # Setup response + mock_response = MagicMock() + mock_response.model_dump.return_value = { + 'result': 'stdio_success', + 'data': 'test_data', + } + + # Setup call_tool method + mock_client.call_tool = AsyncMock(return_value=mock_response) + + action = MCPAction(name='stdio_test_tool', arguments={'input': 'test_input'}) + + # Call the function + observation = await openhands.mcp.utils.call_tool_mcp([mock_client], action) + + # Verify + assert isinstance(observation, MCPObservation) + assert json.loads(observation.content) == { + 'result': 'stdio_success', + 'data': 'test_data', + } + mock_client.call_tool.assert_called_once_with( + 'stdio_test_tool', {'input': 'test_input'} + )