feat(MCP): MCP refactor, support stdio, and running MCP server in runtime (#7911)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Xingyao Wang
2025-05-02 09:43:19 +08:00
committed by GitHub
parent 0fc86b4063
commit 6032d2620d
37 changed files with 1906 additions and 208 deletions

96
docs/modules/usage/mcp.md Normal file
View File

@@ -0,0 +1,96 @@
# Model Context Protocol (MCP)
:::note
This page outlines how to configure and use the Model Context Protocol (MCP) in OpenHands, allowing you to extend the agent's capabilities with custom tools.
:::
## Overview
Model Context Protocol (MCP) is a mechanism that allows OpenHands to communicate with external tool servers. These servers can provide additional functionality to the agent, such as specialized data processing, external API access, or custom tools. MCP is based on the open standard defined at [modelcontextprotocol.io](https://modelcontextprotocol.io).
## Configuration
MCP configuration is defined in the `[mcp]` section of your `config.toml` file.
### Configuration Example
```toml
[mcp]
# SSE Servers - External servers that communicate via Server-Sent Events
sse_servers = [
# Basic SSE server with just a URL
"http://example.com:8080/mcp",
# SSE server with API key authentication
{url="https://secure-example.com/mcp", api_key="your-api-key"}
]
# Stdio Servers - Local processes that communicate via standard input/output
stdio_servers = [
# Basic stdio server
{name="fetch", command="uvx", args=["mcp-server-fetch"]},
# Stdio server with environment variables
# (TOML inline tables must stay on a single line and use key = value pairs)
{name="data-processor", command="python", args=["-m", "my_mcp_server"], env={DEBUG="true", PORT="8080"}}
]
```
## Configuration Options
### SSE Servers
SSE servers are configured using either a string URL or an object with the following properties:
- `url` (required)
- Type: `str`
- Description: The URL of the SSE server
- `api_key` (optional)
- Type: `str`
- Default: `None`
- Description: API key for authentication with the SSE server
### Stdio Servers
Stdio servers are configured using an object with the following properties:
- `name` (required)
- Type: `str`
- Description: A unique name for the server
- `command` (required)
- Type: `str`
- Description: The command to run the server
- `args` (optional)
- Type: `list of str`
- Default: `[]`
- Description: Command-line arguments to pass to the server
- `env` (optional)
- Type: `dict of str to str`
- Default: `{}`
- Description: Environment variables to set for the server process
## How MCP Works
When OpenHands starts, it:
1. Reads the MCP configuration from `config.toml`
2. Connects to any configured SSE servers
3. Starts any configured stdio servers
4. Registers the tools provided by these servers with the agent
The agent can then use these tools just like any built-in tool. When the agent calls an MCP tool:
1. OpenHands routes the call to the appropriate MCP server
2. The server processes the request and returns a response
3. OpenHands converts the response to an observation and presents it to the agent

View File

@@ -177,38 +177,30 @@ class CodeActAgent(Agent):
}
params['tools'] = self.tools
if self.mcp_tools:
# Only add tools with unique names
existing_names = {tool['function']['name'] for tool in params['tools']}
unique_mcp_tools = [
tool
for tool in self.mcp_tools
if tool['function']['name'] not in existing_names
]
if self.llm.config.model == 'gemini-2.5-pro-preview-03-25':
logger.info(
f'Removing the default fields from the MCP tools for {self.llm.config.model} '
"since it doesn't support them and the request would crash."
)
# prevent mutation of input tools
unique_mcp_tools = copy.deepcopy(unique_mcp_tools)
# Strip off default fields that cause errors with gemini-preview
for tool in unique_mcp_tools:
if 'function' in tool and 'parameters' in tool['function']:
if 'properties' in tool['function']['parameters']:
for prop_name, prop in tool['function']['parameters'][
'properties'
].items():
if 'default' in prop:
del prop['default']
params['tools'] += unique_mcp_tools
# Special handling for Gemini model which doesn't support default fields
if self.llm.config.model == 'gemini-2.5-pro-preview-03-25':
logger.info(
f'Removing the default fields from tools for {self.llm.config.model} '
"since it doesn't support them and the request would crash."
)
# prevent mutation of input tools
params['tools'] = copy.deepcopy(params['tools'])
# Strip off default fields that cause errors with gemini-preview
for tool in params['tools']:
if 'function' in tool and 'parameters' in tool['function']:
if 'properties' in tool['function']['parameters']:
for prop_name, prop in tool['function']['parameters'][
'properties'
].items():
if 'default' in prop:
del prop['default']
# log to litellm proxy if possible
params['extra_body'] = {'metadata': state.to_llm_metadata(agent_name=self.name)}
response = self.llm.completion(**params)
logger.debug(f'Response from LLM: {response}')
actions = self.response_to_actions_fn(response)
actions = self.response_to_actions_fn(
response, mcp_tool_names=list(self.mcp_tools.keys())
)
logger.debug(f'Actions after response_to_actions: {actions}')
for action in actions:
self.pending_actions.append(action)

View File

@@ -37,10 +37,9 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.tool import ToolCallMetadata
from openhands.mcp import MCPClientTool
def combine_thought(action: Action, thought: str) -> Action:
@@ -53,7 +52,9 @@ def combine_thought(action: Action, thought: str) -> Action:
return action
def response_to_actions(response: ModelResponse) -> list[Action]:
def response_to_actions(
response: ModelResponse, mcp_tool_names: list[str] | None = None
) -> list[Action]:
actions: list[Action] = []
assert len(response.choices) == 1, 'Only one choice is supported for now'
choice = response.choices[0]
@@ -195,12 +196,12 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
action = BrowseURLAction(url=arguments['url'])
# ================================================
# McpAction (MCP)
# MCPAction (MCP)
# ================================================
elif tool_call.function.name.endswith(MCPClientTool.postfix()):
action = McpAction(
name=tool_call.function.name.removesuffix(MCPClientTool.postfix()),
arguments=tool_call.function.arguments,
elif mcp_tool_names and tool_call.function.name in mcp_tool_names:
action = MCPAction(
name=tool_call.function.name,
arguments=arguments,
)
else:
raise FunctionCallNotExistsError(

View File

@@ -36,6 +36,7 @@ from openhands.events.action import (
BrowseURLAction,
CmdRunAction,
FileReadAction,
MCPAction,
MessageAction,
)
from openhands.events.event import FileReadSource
@@ -102,7 +103,9 @@ def glob_to_cmdrun(pattern: str, path: str = '.') -> str:
return echo_cmd + complete_cmd
def response_to_actions(response: ModelResponse) -> list[Action]:
def response_to_actions(
response: ModelResponse, mcp_tool_names: list[str] | None = None
) -> list[Action]:
actions: list[Action] = []
assert len(response.choices) == 1, 'Only one choice is supported for now'
choice = response.choices[0]
@@ -198,6 +201,15 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
)
action = BrowseURLAction(url=arguments['url'])
# ================================================
# MCPAction (MCP)
# ================================================
elif mcp_tool_names and tool_call.function.name in mcp_tool_names:
action = MCPAction(
name=tool_call.function.name,
arguments=arguments,
)
else:
raise FunctionCallNotExistsError(
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'

View File

@@ -8,6 +8,8 @@ if TYPE_CHECKING:
from openhands.core.config import AgentConfig
from openhands.events.action import Action
from openhands.events.action.message import SystemMessageAction
from litellm import ChatCompletionToolParam
from openhands.core.exceptions import (
AgentAlreadyRegisteredError,
AgentNotRegisteredError,
@@ -42,7 +44,7 @@ class Agent(ABC):
self.config = config
self._complete = False
self.prompt_manager: 'PromptManager' | None = None
self.mcp_tools: list[dict] = []
self.mcp_tools: dict[str, ChatCompletionToolParam] = {}
self.tools: list = []
def get_system_message(self) -> 'SystemMessageAction | None':
@@ -160,4 +162,18 @@ class Agent(ABC):
Args:
- mcp_tools (list[dict]): The list of MCP tools.
"""
self.mcp_tools = mcp_tools
logger.info(
f"Setting {len(mcp_tools)} MCP tools for agent {self.name}: {[tool['function']['name'] for tool in mcp_tools]}"
)
for tool in mcp_tools:
_tool = ChatCompletionToolParam(**tool)
if _tool['function']['name'] in self.mcp_tools:
logger.warning(
f"Tool {_tool['function']['name']} already exists, skipping"
)
continue
self.mcp_tools[_tool['function']['name']] = _tool
self.tools.append(_tool)
logger.info(
f"Tools updated for agent {self.name}, total {len(self.tools)}: {[tool['function']['name'] for tool in self.tools]}"
)

View File

@@ -54,7 +54,7 @@ from openhands.events.observation import (
AgentStateChangedObservation,
)
from openhands.io import read_task
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
LLMSummarizingCondenserConfig,
)
@@ -112,8 +112,6 @@ async def run_session(
)
agent = create_agent(config)
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
agent.set_mcp_tools(mcp_tools)
runtime = create_runtime(
config,
sid=sid,
@@ -209,6 +207,7 @@ async def run_session(
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
await runtime.connect()
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
# Initialize repository if needed
repo_directory = None

View File

@@ -7,6 +7,7 @@ from openhands.core.config.config_utils import (
)
from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.core.config.utils import (
@@ -26,6 +27,7 @@ __all__ = [
'OH_MAX_ITERATIONS',
'AgentConfig',
'AppConfig',
'MCPConfig',
'LLMConfig',
'SandboxConfig',
'SecurityConfig',

View File

@@ -1,28 +1,59 @@
from typing import List
from urllib.parse import urlparse
from pydantic import BaseModel, Field, ValidationError
class MCPSSEServerConfig(BaseModel):
    """Configuration for a single MCP server reached over SSE (Server-Sent Events).

    Attributes:
        url: The server URL.
        api_key: Optional API key for authentication; ``None`` (the default)
            means no key is configured for this server.
    """

    url: str
    api_key: str | None = None
class MCPStdioServerConfig(BaseModel):
    """Configuration for an MCP server that communicates over stdio.

    Attributes:
        name: The name of the server.
        command: The command used to run the server.
        args: The arguments to pass to the command; defaults to an empty list.
        env: The environment variables to set for the server; defaults to an
            empty dict.
    """

    name: str
    command: str
    # default_factory gives every instance its own fresh list / dict.
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] = Field(default_factory=dict)
class MCPConfig(BaseModel):
"""Configuration for MCP (Model Context Protocol) settings.
Attributes:
mcp_servers: List of MCP SSE (Server-Sent Events) server URLs.
sse_servers: List of MCP SSE server configs
stdio_servers: List of MCP stdio server configs. These servers will be added to the MCP Router running inside runtime container.
"""
mcp_servers: List[str] = Field(default_factory=list)
sse_servers: list[MCPSSEServerConfig] = Field(default_factory=list)
stdio_servers: list[MCPStdioServerConfig] = Field(default_factory=list)
model_config = {'extra': 'forbid'}
def validate_servers(self) -> None:
"""Validate that server URLs are valid and unique."""
urls = [server.url for server in self.sse_servers]
# Check for duplicate server URLs
if len(set(self.mcp_servers)) != len(self.mcp_servers):
if len(set(urls)) != len(urls):
raise ValueError('Duplicate MCP server URLs are not allowed')
# Validate URLs
for url in self.mcp_servers:
for url in urls:
try:
result = urlparse(url)
if not all([result.scheme, result.netloc]):
@@ -44,11 +75,32 @@ class MCPConfig(BaseModel):
mcp_mapping: dict[str, MCPConfig] = {}
try:
# Convert all entries in sse_servers to MCPSSEServerConfig objects
if 'sse_servers' in data:
servers = []
for server in data['sse_servers']:
if isinstance(server, dict):
servers.append(MCPSSEServerConfig(**server))
else:
# Convert string URLs to MCPSSEServerConfig objects with no API key
servers.append(MCPSSEServerConfig(url=server))
data['sse_servers'] = servers
# Convert all entries in stdio_servers to MCPStdioServerConfig objects
if 'stdio_servers' in data:
servers = []
for server in data['stdio_servers']:
servers.append(MCPStdioServerConfig(**server))
data['stdio_servers'] = servers
# Create SSE config if present
mcp_config = MCPConfig.model_validate(data)
mcp_config.validate_servers()
# Create the main MCP config
mcp_mapping['mcp'] = cls(mcp_servers=mcp_config.mcp_servers)
mcp_mapping['mcp'] = cls(
sse_servers=mcp_config.sse_servers,
stdio_servers=mcp_config.stdio_servers,
)
except ValidationError as e:
raise ValueError(f'Invalid MCP configuration: {e}')

View File

@@ -30,7 +30,7 @@ from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.io import read_input, read_task
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
@@ -96,8 +96,6 @@ async def run_controller(
if agent is None:
agent = create_agent(config)
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
agent.set_mcp_tools(mcp_tools)
# when the runtime is created, it will be connected and clone the selected repository
repo_directory = None
@@ -118,6 +116,8 @@ async def run_controller(
selected_repository=config.sandbox.selected_repo,
)
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
event_stream = runtime.event_stream
# when memory is created, it will load the microagents from the selected repository

View File

@@ -15,7 +15,7 @@ from openhands.events.action.files import (
FileReadAction,
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import MessageAction, SystemMessageAction
__all__ = [
@@ -37,5 +37,5 @@ __all__ = [
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',
'McpAction',
'MCPAction',
]

View File

@@ -1,14 +1,14 @@
from dataclasses import dataclass
from typing import ClassVar
from dataclasses import dataclass, field
from typing import Any, ClassVar
from openhands.core.schema import ActionType
from openhands.events.action.action import Action, ActionSecurityRisk
@dataclass
class McpAction(Action):
class MCPAction(Action):
name: str
arguments: str | None = None
arguments: dict[str, Any] = field(default_factory=dict)
thought: str = ''
action: str = ActionType.MCP
runnable: ClassVar[bool] = True
@@ -24,7 +24,7 @@ class McpAction(Action):
)
def __str__(self) -> str:
ret = '**McpAction**\n'
ret = '**MCPAction**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'NAME: {self.name}\n'

View File

@@ -21,6 +21,7 @@ from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation

View File

@@ -22,7 +22,7 @@ from openhands.events.action.files import (
FileReadAction,
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import MessageAction, SystemMessageAction
actions = (
@@ -43,7 +43,7 @@ actions = (
MessageAction,
SystemMessageAction,
CondensationAction,
McpAction,
MCPAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]

View File

@@ -1,9 +1,7 @@
from openhands.mcp.client import MCPClient
from openhands.mcp.tool import (
BaseTool,
MCPClientTool,
)
from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import (
add_mcp_tools_to_agent,
call_tool_mcp,
convert_mcp_clients_to_tools,
create_mcp_clients,
@@ -14,8 +12,8 @@ __all__ = [
'MCPClient',
'convert_mcp_clients_to_tools',
'create_mcp_clients',
'BaseTool',
'MCPClientTool',
'fetch_mcp_tools_from_config',
'call_tool_mcp',
'add_mcp_tools_to_agent',
]

View File

@@ -7,7 +7,7 @@ from mcp.client.sse import sse_client
from pydantic import BaseModel, Field
from openhands.core.logger import openhands_logger as logger
from openhands.mcp.tool import BaseTool, MCPClientTool
from openhands.mcp.tool import MCPClientTool
class MCPClient(BaseModel):
@@ -18,13 +18,15 @@ class MCPClient(BaseModel):
session: Optional[ClientSession] = None
exit_stack: AsyncExitStack = AsyncExitStack()
description: str = 'MCP client tools for server interaction'
tools: List[BaseTool] = Field(default_factory=list)
tool_map: Dict[str, BaseTool] = Field(default_factory=dict)
tools: List[MCPClientTool] = Field(default_factory=list)
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
async def connect_sse(self, server_url: str, timeout: float = 30.0) -> None:
async def connect_sse(
self, server_url: str, api_key: str | None = None, timeout: float = 30.0
) -> None:
"""Connect to an MCP server using SSE transport.
Args:
@@ -41,7 +43,8 @@ class MCPClient(BaseModel):
async def connect_with_timeout():
streams_context = sse_client(
url=server_url,
timeout=timeout, # Pass the timeout to sse_client
headers={'Authorization': f'Bearer {api_key}'} if api_key else None,
timeout=timeout,
)
streams = await self.exit_stack.enter_async_context(streams_context)
self.session = await self.exit_stack.enter_async_context(
@@ -92,7 +95,10 @@ class MCPClient(BaseModel):
"""Call a tool on the MCP server."""
if tool_name not in self.tool_map:
raise ValueError(f'Tool {tool_name} not found.')
return await self.tool_map[tool_name].execute(**args)
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
if not self.session:
raise RuntimeError('Client session is not available.')
return await self.session.call_tool(name=tool_name, arguments=args)
async def disconnect(self) -> None:
"""Disconnect from the MCP server and clean up resources."""

View File

@@ -1,54 +1,26 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
from typing import Dict
from mcp import ClientSession
from mcp.types import CallToolResult, TextContent, Tool
from mcp.types import Tool
class BaseTool(ABC, Tool):
@classmethod
def postfix(cls) -> str:
return '_mcp_tool_call'
class MCPClientTool(Tool):
"""
Represents a tool proxy that can be called on the MCP server from the client side.
This version doesn't store a session reference, as sessions are created on-demand
by the MCPClient for each operation.
"""
class Config:
arbitrary_types_allowed = True
@abstractmethod
async def execute(self, **kwargs) -> CallToolResult:
"""Execute the tool with given parameters."""
def to_param(self) -> Dict:
"""Convert tool to function call format."""
return {
'type': 'function',
'function': {
'name': self.name + self.postfix(),
'name': self.name,
'description': self.description,
'parameters': self.inputSchema,
},
}
class MCPClientTool(BaseTool):
"""Represents a tool proxy that can be called on the MCP server from the client side."""
session: Optional[ClientSession] = None
async def execute(self, **kwargs) -> CallToolResult:
"""Execute the tool by making a remote call to the MCP server."""
if not self.session:
return CallToolResult(
content=[TextContent(text='Not connected to MCP server', type='text')],
isError=True,
)
try:
result = await self.session.call_tool(self.name, kwargs)
return result
except Exception as e:
return CallToolResult(
content=[
TextContent(text=f'Error executing tool: {str(e)}', type='text')
],
isError=True,
)

View File

@@ -1,11 +1,16 @@
import json
from typing import TYPE_CHECKING
from openhands.core.config.mcp_config import MCPConfig
if TYPE_CHECKING:
from openhands.controller.agent import Agent
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.mcp.client import MCPClient
from openhands.runtime.base import Runtime
def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]:
@@ -38,19 +43,19 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
async def create_mcp_clients(
mcp_servers: list[str],
sse_servers: list[MCPSSEServerConfig],
) -> list[MCPClient]:
mcp_clients: list[MCPClient] = []
# Initialize SSE connections
if mcp_servers:
for server_url in mcp_servers:
if sse_servers:
for server_url in sse_servers:
logger.info(
f'Initializing MCP agent for {server_url} with SSE connection...'
)
client = MCPClient()
try:
await client.connect_sse(server_url)
await client.connect_sse(server_url.url, api_key=server_url.api_key)
# Only add the client to the list after a successful connection
mcp_clients.append(client)
logger.info(f'Connected to MCP server {server_url} via SSE')
@@ -77,14 +82,16 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
mcp_tools = []
try:
logger.debug(f'Creating MCP clients with config: {mcp_config}')
# Create clients - this will fetch tools but not maintain active connections
mcp_clients = await create_mcp_clients(
mcp_config.mcp_servers,
mcp_config.sse_servers,
)
if not mcp_clients:
logger.debug('No MCP clients were successfully connected')
return []
# Convert tools to the format expected by the agent
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
# Always disconnect clients to clean up resources
@@ -93,6 +100,7 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
await mcp_client.disconnect()
except Exception as disconnect_error:
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
except Exception as e:
logger.error(f'Error fetching MCP tools: {str(e)}')
return []
@@ -101,13 +109,13 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
return mcp_tools
async def call_tool_mcp(mcp_clients: list[MCPClient], action: McpAction) -> Observation:
async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Observation:
"""
Call a tool on an MCP server and return the observation.
Args:
mcp_clients: The list of MCP clients to execute the action on
action: The MCP action to execute
sse_mcp_servers: List of SSE MCP server URLs
Returns:
The observation from the MCP server
@@ -116,20 +124,55 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: McpAction) -> Obse
raise ValueError('No MCP clients found')
logger.debug(f'MCP action received: {action}')
# Find the MCP agent that has the matching tool name
# Find the MCP client that has the matching tool name
matching_client = None
logger.debug(f'MCP clients: {mcp_clients}')
logger.debug(f'MCP action name: {action.name}')
for client in mcp_clients:
logger.debug(f'MCP client tools: {client.tools}')
if action.name in [tool.name for tool in client.tools]:
matching_client = client
break
if matching_client is None:
raise ValueError(f'No matching MCP agent found for tool name: {action.name}')
logger.debug(f'Matching client: {matching_client}')
args_dict = json.loads(action.arguments) if action.arguments else {}
response = await matching_client.call_tool(action.name, args_dict)
# Call the tool - this will create a new connection internally
response = await matching_client.call_tool(action.name, action.arguments)
logger.debug(f'MCP response: {response}')
return MCPObservation(content=f'MCP result:{response.model_dump(mode="json")}')
return MCPObservation(content=json.dumps(response.model_dump(mode='json')))
async def add_mcp_tools_to_agent(
    agent: 'Agent', runtime: Runtime, mcp_config: MCPConfig
):
    """Fetch the MCP tools reachable through ``runtime`` and register them on ``agent``.

    The MCP configuration that is actually used is the one returned by
    ``runtime.get_updated_mcp_config()``; the ``mcp_config`` argument is not
    read inside this function.

    Args:
        agent: The agent that receives the tools via ``set_mcp_tools``.
        runtime: Must be an initialized ``ActionExecutionClient``.
        mcp_config: Accepted as part of the signature; unused in this body.

    Raises:
        AssertionError: If ``runtime`` is not an ``ActionExecutionClient`` or
            has not been initialized yet.
    """
    from openhands.runtime.impl.action_execution.action_execution_client import (
        ActionExecutionClient,  # inline import to avoid circular import
    )

    is_execution_client = isinstance(runtime, ActionExecutionClient)
    assert (
        is_execution_client
    ), 'Runtime must be an instance of ActionExecutionClient'
    assert (
        runtime.runtime_initialized
    ), 'Runtime must be initialized before adding MCP tools'

    # The runtime contributes itself as one more MCP server in the config.
    config_with_runtime = runtime.get_updated_mcp_config()

    fetched_tools = await fetch_mcp_tools_from_config(config_with_runtime)
    fetched_names = [tool['function']['name'] for tool in fetched_tools]
    logger.info(f'Loaded {len(fetched_tools)} MCP tools: {fetched_names}')

    # Hand the fetched tool definitions over to the agent.
    agent.set_mcp_tools(fetched_tools)

View File

@@ -19,7 +19,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
@@ -184,7 +184,7 @@ class ConversationMemory:
- BrowseInteractiveAction: For browsing the web
- AgentFinishAction: For ending the interaction
- MessageAction: For sending messages
- McpAction: For interacting with the MCP server
- MCPAction: For interacting with the MCP server
pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages.
Used in function calling mode to track tool calls that are waiting for their results.
@@ -210,7 +210,7 @@ class ConversationMemory:
FileReadAction,
BrowseInteractiveAction,
BrowseURLAction,
McpAction,
MCPAction,
),
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
tool_metadata = action.tool_call_metadata

View File

@@ -8,6 +8,8 @@ NOTE: this will be executed inside the docker sandbox.
import argparse
import asyncio
import base64
import json
import logging
import mimetypes
import os
import shutil
@@ -23,6 +25,8 @@ from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import APIKeyHeader
from mcpm import MCPRouter, RouterConfig
from mcpm.router.router import logger as mcp_router_logger
from openhands_aci.editor.editor import OHEditor
from openhands_aci.editor.exceptions import ToolError
from openhands_aci.editor.results import ToolResult
@@ -68,6 +72,9 @@ from openhands.runtime.utils.runtime_init import init_user_and_working_directory
from openhands.runtime.utils.system_stats import get_system_stats
from openhands.utils.async_utils import call_sync_from_async, wait_all
# Set MCP router logger to the same level as the main logger
mcp_router_logger.setLevel(logger.getEffectiveLevel())
class ActionRequest(BaseModel):
action: dict
@@ -572,10 +579,15 @@ if __name__ == '__main__':
plugins_to_load.append(ALL_PLUGINS[plugin]()) # type: ignore
client: ActionExecutor | None = None
mcp_router: MCPRouter | None = None
MCP_ROUTER_PROFILE_PATH = os.path.join(
os.path.dirname(__file__), 'mcp', 'config.json'
)
@asynccontextmanager
async def lifespan(app: FastAPI):
global client
global client, mcp_router
logger.info('Initializing ActionExecutor...')
client = ActionExecutor(
plugins_to_load,
work_dir=args.working_dir,
@@ -584,9 +596,70 @@ if __name__ == '__main__':
browsergym_eval_env=args.browsergym_eval_env,
)
await client.ainit()
logger.info('ActionExecutor initialized.')
# Initialize and mount MCP Router
logger.info('Initializing MCP Router...')
mcp_router = MCPRouter(
profile_path=MCP_ROUTER_PROFILE_PATH,
router_config=RouterConfig(
api_key=SESSION_API_KEY,
auth_enabled=bool(SESSION_API_KEY),
),
)
allowed_origins = ['*']
sse_app = await mcp_router.get_sse_server_app(
allow_origins=allowed_origins, include_lifespan=False
)
# Check for route conflicts before mounting
main_app_routes = {route.path for route in app.routes}
sse_app_routes = {route.path for route in sse_app.routes}
conflicting_routes = main_app_routes.intersection(sse_app_routes)
if conflicting_routes:
logger.error(f'Route conflicts detected: {conflicting_routes}')
raise RuntimeError(
f'Cannot mount SSE app - conflicting routes found: {conflicting_routes}'
)
app.mount('/', sse_app)
logger.info(
f'Mounted MCP Router SSE app at root path with allowed origins: {allowed_origins}'
)
# Additional debug logging
if logger.isEnabledFor(logging.DEBUG):
logger.debug('Main app routes:')
for route in main_app_routes:
logger.debug(f' {route}')
logger.debug('MCP SSE server app routes:')
for route in sse_app_routes:
logger.debug(f' {route}')
yield
# Clean up & release the resources
client.close()
logger.info('Shutting down MCP Router...')
if mcp_router:
try:
await mcp_router.shutdown()
logger.info('MCP Router shutdown successfully.')
except Exception as e:
logger.error(f'Error shutting down MCP Router: {e}', exc_info=True)
else:
logger.info('MCP Router instance not found for shutdown.')
logger.info('Closing ActionExecutor...')
if client:
try:
client.close()
logger.info('ActionExecutor closed successfully.')
except Exception as e:
logger.error(f'Error closing ActionExecutor: {e}', exc_info=True)
else:
logger.info('ActionExecutor instance not found for closing.')
logger.info('Shutdown complete.')
app = FastAPI(lifespan=lifespan)
@@ -663,6 +736,51 @@ if __name__ == '__main__':
detail=traceback.format_exc(),
)
@app.post('/update_mcp_server')
async def update_mcp_server(request: Request):
    """Overwrite the router profile's ``default`` entry with the posted list and reload.

    The request body must be a JSON list; anything else is rejected with a
    400 response. The profile file at ``MCP_ROUTER_PROFILE_PATH`` is read,
    its ``default`` entry is replaced by the posted list, the file is written
    back, and the MCP router is told to reload and update its servers.

    Raises:
        AssertionError: If the router is not set, the profile file does not
            exist, or the profile has no list-valued ``default`` entry.
        HTTPException: With status 400 when the body is not a list.
    """
    assert mcp_router is not None
    assert os.path.exists(MCP_ROUTER_PROFILE_PATH)

    # Plain blocking read of the profile; it happens before the body is parsed.
    with open(MCP_ROUTER_PROFILE_PATH, 'r') as profile_file:
        profile = json.load(profile_file)
    assert 'default' in profile
    assert isinstance(profile['default'], list)

    # The payload replaces the whole 'default' entry, so it must be a list.
    requested_servers = await request.json()
    if not isinstance(requested_servers, list):
        raise HTTPException(
            status_code=400, detail='Request must be a list of MCP tools to sync'
        )

    requested_dump = json.dumps(requested_servers, indent=2)
    previous_dump = json.dumps(profile, indent=2)
    logger.info(
        f'Updating MCP server to: {requested_dump}.\nPrevious profile: {previous_dump}'
    )

    # Persist the new profile (blocking write) before asking the router to reload.
    profile['default'] = requested_servers
    with open(MCP_ROUTER_PROFILE_PATH, 'w') as profile_file:
        json.dump(profile, profile_file)

    # Manually reload the profile and update the servers
    mcp_router.profile_manager.reload()
    unique_servers = mcp_router.get_unique_servers()
    await mcp_router.update_servers(unique_servers)
    logger.info(
        f'MCP router updated successfully with unique servers: {unique_servers}'
    )

    success_body = {'detail': 'MCP server updated successfully'}
    return JSONResponse(status_code=200, content=success_body)
@app.post('/upload_file')
async def upload_file(
file: UploadFile, destination: str = '/', recursive: bool = False

View File

@@ -30,7 +30,7 @@ from openhands.events.action import (
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentThinkObservation,
@@ -282,9 +282,8 @@ class Runtime(FileEditRuntimeMixin):
assert event.timeout is not None
try:
await self._export_latest_git_provider_tokens(event)
if isinstance(event, McpAction):
# we don't call call_tool_mcp impl directly because there can be other action ActionExecutionClient
observation: Observation = await getattr(self, McpAction.action)(event)
if isinstance(event, MCPAction):
observation: Observation = await self.call_tool_mcp(event)
else:
observation = await call_sync_from_async(self.run_action, event)
except Exception as e:
@@ -571,7 +570,7 @@ class Runtime(FileEditRuntimeMixin):
pass
@abstractmethod
async def call_tool_mcp(self, action: McpAction) -> Observation:
async def call_tool_mcp(self, action: MCPAction) -> Observation:
pass
# ====================================================================

View File

@@ -1,3 +1,4 @@
import asyncio
import os
import tempfile
import threading
@@ -10,6 +11,7 @@ import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from openhands.core.config import AppConfig
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
from openhands.core.exceptions import (
AgentRuntimeTimeoutError,
)
@@ -27,7 +29,7 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.action.files import FileEditSource
from openhands.events.action.mcp import McpAction
from openhands.events.action.mcp import MCPAction
from openhands.events.observation import (
AgentThinkObservation,
ErrorObservation,
@@ -38,16 +40,12 @@ from openhands.events.observation import (
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.mcp import MCPClient, create_mcp_clients
from openhands.mcp import call_tool_mcp as call_tool_mcp_handler
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit
def _is_retryable_error(exception):
return isinstance(
exception, (httpx.RemoteProtocolError, httpcore.RemoteProtocolError)
@@ -79,7 +77,6 @@ class ActionExecutionClient(Runtime):
self._runtime_initialized: bool = False
self._runtime_closed: bool = False
self._vscode_token: str | None = None # initial dummy value
self.mcp_clients: list[MCPClient] | None = None
super().__init__(
config,
event_stream,
@@ -329,19 +326,59 @@ class ActionExecutionClient(Runtime):
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return self.send_action_for_execution(action)
async def call_tool_mcp(self, action: McpAction) -> Observation:
if self.mcp_clients is None:
self.log(
'debug',
f'Creating MCP clients with servers: {self.config.mcp.mcp_servers}',
)
self.mcp_clients = await create_mcp_clients(self.config.mcp.mcp_servers)
return await call_tool_mcp_handler(self.mcp_clients, action)
def get_updated_mcp_config(self) -> MCPConfig:
    """Sync stdio MCP servers to the runtime and build the effective MCP config.

    The configured stdio servers are serialized and POSTed to the action
    execution server's ``/update_mcp_server`` endpoint. The returned config is
    a copy of ``self.config.mcp`` to which the runtime's own SSE endpoint is
    appended.

    Raises:
        RuntimeError: If the action execution server does not answer 200.
    """
    # Deep copy on purpose: ``model_copy()`` is shallow, so the copy would
    # share its ``sse_servers`` list with ``self.config.mcp`` and appending the
    # runtime's SSE server would grow the original config on every call.
    updated_mcp_config = self.config.mcp.model_copy(deep=True)

    # Send a request to the action execution server to update its MCP config
    stdio_tools = [
        server.model_dump(mode='json')
        for server in updated_mcp_config.stdio_servers
    ]
    self.log('debug', f'Updating MCP server to: {stdio_tools}')
    response = self._send_action_server_request(
        'POST',
        f'{self.action_execution_server_url}/update_mcp_server',
        json=stdio_tools,
        timeout=10,
    )
    if response.status_code != 200:
        raise RuntimeError(f'Failed to update MCP server: {response.text}')
async def aclose(self) -> None:
if self.mcp_clients:
for client in self.mcp_clients:
await client.disconnect()
# No API key by default. Child runtime can override this when appropriate
updated_mcp_config.sse_servers.append(
MCPSSEServerConfig(
url=self.action_execution_server_url.rstrip('/') + '/sse', api_key=None
)
)
self.log(
'debug',
f'Updated MCP config by adding runtime as another server: {updated_mcp_config}',
)
return updated_mcp_config
async def call_tool_mcp(self, action: MCPAction) -> Observation:
    """Execute an MCP tool call using freshly created, short-lived clients.

    Args:
        action: The MCP action naming the tool and its arguments.

    Returns:
        The observation produced by the MCP tool-call handler.
    """
    # Import here to avoid circular imports
    from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
    from openhands.mcp.utils import create_mcp_clients

    # Get the updated MCP config
    updated_mcp_config = self.get_updated_mcp_config()
    self.log(
        'debug',
        f'Creating MCP clients with servers: {updated_mcp_config.sse_servers}',
    )

    # Create clients for this specific operation
    mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers)

    try:
        return await call_tool_mcp_handler(mcp_clients, action)
    finally:
        # Reset client state even when the handler raises (for example when
        # no client exposes the requested tool), so clients are never left
        # half-initialized on the error path.
        for client in mcp_clients:
            await client.disconnect()
def close(self) -> None:
# Make sure we don't close the session multiple times
@@ -350,4 +387,3 @@ class ActionExecutionClient(Runtime):
return
self._runtime_closed = True
self.session.close()
call_async_from_sync(self.aclose)

View File

@@ -0,0 +1,3 @@
{
"default": []
}

View File

@@ -34,10 +34,12 @@ RUN apt-get update && \
{% if 'ubuntu' in base_image %}
RUN ln -s "$(dirname $(which node))/corepack" /usr/local/bin/corepack && \
npm install -g corepack && corepack enable yarn && \
curl -fsSL --compressed https://install.python-poetry.org | python - && \
curl -LsSf https://astral.sh/uv/install.sh | sh
curl -fsSL --compressed https://install.python-poetry.org | python -
{% endif %}
# Install uv (required by MCP)
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# Remove UID 1000 named pn or ubuntu, so the 'openhands' user can be created from ubuntu hosts
RUN (if getent passwd 1000 | grep -q pn; then userdel pn; fi) && \
(if getent passwd 1000 | grep -q ubuntu; then userdel ubuntu; fi)

View File

@@ -18,6 +18,7 @@ from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
from openhands.runtime import get_runtime_cls
@@ -124,6 +125,11 @@ class AgentSession:
selected_branch=selected_branch,
)
# NOTE: this needs to happen before controller is created
# so MCP tools can be included into the SystemMessageAction
if self.runtime and runtime_connected:
await add_mcp_tools_to_agent(agent, self.runtime, config.mcp)
if replay_json:
initial_message = self._run_replay(
initial_message,
@@ -148,6 +154,7 @@ class AgentSession:
repo_directory = None
if self.runtime and runtime_connected and selected_repository:
repo_directory = selected_repository.full_name.split('/')[-1]
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,

View File

@@ -25,7 +25,6 @@ from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.storage.data_models.settings import Settings
@@ -147,9 +146,7 @@ class Session:
self.logger.info(f'Enabling default condenser: {default_condenser_config}')
agent_config.condenser = default_condenser_config
mcp_tools = await fetch_mcp_tools_from_config(self.config.mcp)
agent = Agent.get_cls(agent_cls)(llm, agent_config)
agent.set_mcp_tools(mcp_tools)
git_provider_tokens = None
selected_repository = None

869
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -76,6 +76,8 @@ mcp = "1.6.0"
python-json-logger = "^3.2.1"
playwright = "^1.51.0"
prompt-toolkit = "^3.0.50"
mcpm = "1.8.0"
poetry = "^2.1.2"
anyio = "4.9.0"
[tool.poetry.group.dev.dependencies]
@@ -99,6 +101,7 @@ gevent = ">=24.2.1,<26.0.0"
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -128,6 +131,7 @@ ignore = ["D1"]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"

View File

@@ -7,7 +7,7 @@ import time
import pytest
from pytest import TempPathFactory
from openhands.core.config import AppConfig, load_app_config
from openhands.core.config import AppConfig, MCPConfig, load_app_config
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.runtime.base import Runtime
@@ -214,6 +214,7 @@ def _load_runtime(
force_rebuild_runtime: bool = False,
runtime_startup_env_vars: dict[str, str] | None = None,
docker_runtime_kwargs: dict[str, str] | None = None,
override_mcp_config: MCPConfig | None = None,
) -> tuple[Runtime, AppConfig]:
sid = 'rt_' + str(random.randint(100000, 999999))
@@ -256,6 +257,9 @@ def _load_runtime(
config.sandbox.base_container_image = base_container_image
config.sandbox.runtime_container_image = None
if override_mcp_config is not None:
config.mcp = override_mcp_config
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid, file_store)

View File

@@ -0,0 +1,78 @@
"""MCP-related tests for the DockerRuntime, which connects to the ActionExecutor running in the sandbox."""
import json
import os
import pytest
from conftest import (
_load_runtime,
)
import openhands
from openhands.core.config import MCPConfig
from openhands.core.config.mcp_config import MCPStdioServerConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import CmdRunAction, MCPAction
from openhands.events.observation import CmdOutputObservation, MCPObservation
# ============================================================================================================================
# MCP-specific tests
# ============================================================================================================================
def test_default_activated_tools():
    """The bundled MCP router profile exists and activates no tools by default."""
    package_dir = os.path.dirname(openhands.__file__)
    profile_path = os.path.join(package_dir, 'runtime', 'mcp', 'config.json')
    assert os.path.exists(profile_path), f'MCP config file not found at {profile_path}'

    with open(profile_path, 'r') as profile_file:
        profile = json.load(profile_file)

    assert 'default' in profile
    # no tools are always activated yet
    assert len(profile['default']) == 0
@pytest.mark.asyncio
async def test_fetch_mcp_via_stdio(temp_dir, runtime_cls, run_as_openhands):
    """Call the stdio ``fetch`` MCP server through the runtime against a local HTTP server."""
    mcp_stdio_server_config = MCPStdioServerConfig(
        name='fetch', command='uvx', args=['mcp-server-fetch']
    )
    override_mcp_config = MCPConfig(stdio_servers=[mcp_stdio_server_config])
    runtime, _ = _load_runtime(
        temp_dir, runtime_cls, run_as_openhands, override_mcp_config=override_mcp_config
    )

    # Close the runtime even when an assertion fails, so a failing test does
    # not leave the runtime open.
    try:
        # Start a local HTTP server for the fetch tool to request
        action_cmd = CmdRunAction(
            command='python3 -m http.server 8000 > server.log 2>&1 &'
        )
        logger.info(action_cmd, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action_cmd)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})

        assert isinstance(obs, CmdOutputObservation)
        assert obs.exit_code == 0
        assert '[1]' in obs.content

        # Give the server a moment to start, then check its log is readable
        action_cmd = CmdRunAction(command='sleep 3 && cat server.log')
        logger.info(action_cmd, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action_cmd)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert obs.exit_code == 0

        mcp_action = MCPAction(name='fetch', arguments={'url': 'http://localhost:8000'})
        obs = await runtime.call_tool_mcp(mcp_action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert isinstance(
            obs, MCPObservation
        ), 'The observation should be a MCPObservation.'

        result_json = json.loads(obs.content)
        assert not result_json['isError']
        assert len(result_json['content']) == 1
        assert result_json['content'][0]['type'] == 'text'
        assert (
            result_json['content'][0]['text']
            == 'Contents of http://localhost:8000/:\n---\n\n* <server.log>\n\n---'
        )
    finally:
        runtime.close()

View File

@@ -32,6 +32,9 @@ from openhands.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.storage.memory import InMemoryFileStore
@@ -83,8 +86,12 @@ def test_event_stream():
@pytest.fixture
def mock_runtime() -> Runtime:
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
runtime = MagicMock(
spec=Runtime,
spec=ActionExecutionClient,
event_stream=test_event_stream,
)
return runtime
@@ -233,7 +240,7 @@ async def test_run_controller_with_fatal_error(
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
@@ -298,7 +305,7 @@ async def test_run_controller_stop_with_stuck(
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
@@ -660,7 +667,7 @@ async def test_run_controller_max_iterations_has_metrics(
mock_agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
@@ -1064,7 +1071,7 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
mock_agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream
# Create a real Memory instance

View File

@@ -11,7 +11,9 @@ from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
@@ -58,7 +60,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
mock_runtime = MagicMock(spec=ActionExecutionClient)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
@@ -141,7 +143,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
mock_runtime = MagicMock(spec=ActionExecutionClient)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):

View File

@@ -123,7 +123,7 @@ def mock_settings_store():
@patch('openhands.core.cli.display_runtime_initialization_message')
@patch('openhands.core.cli.display_initialization_animation')
@patch('openhands.core.cli.create_agent')
@patch('openhands.core.cli.fetch_mcp_tools_from_config')
@patch('openhands.core.cli.add_mcp_tools_to_agent')
@patch('openhands.core.cli.create_runtime')
@patch('openhands.core.cli.create_controller')
@patch('openhands.core.cli.create_memory')
@@ -137,7 +137,7 @@ async def test_run_session_without_initial_action(
mock_create_memory,
mock_create_controller,
mock_create_runtime,
mock_fetch_mcp_tools,
mock_add_mcp_tools,
mock_create_agent,
mock_display_animation,
mock_display_runtime_init,
@@ -154,9 +154,6 @@ async def test_run_session_without_initial_action(
mock_agent = AsyncMock()
mock_create_agent.return_value = mock_agent
mock_mcp_tools = []
mock_fetch_mcp_tools.return_value = mock_mcp_tools
mock_runtime = AsyncMock()
mock_runtime.event_stream = MagicMock()
mock_create_runtime.return_value = mock_runtime
@@ -193,8 +190,9 @@ async def test_run_session_without_initial_action(
mock_display_runtime_init.assert_called_once_with('local')
mock_display_animation.assert_called_once()
mock_create_agent.assert_called_once_with(mock_config)
mock_fetch_mcp_tools.assert_called_once()
mock_agent.set_mcp_tools.assert_called_once_with(mock_mcp_tools)
mock_add_mcp_tools.assert_called_once_with(
mock_agent, mock_runtime, mock_config.mcp
)
mock_create_runtime.assert_called_once()
mock_create_controller.assert_called_once()
mock_create_memory.assert_called_once()
@@ -213,7 +211,7 @@ async def test_run_session_without_initial_action(
@patch('openhands.core.cli.display_runtime_initialization_message')
@patch('openhands.core.cli.display_initialization_animation')
@patch('openhands.core.cli.create_agent')
@patch('openhands.core.cli.fetch_mcp_tools_from_config')
@patch('openhands.core.cli.add_mcp_tools_to_agent')
@patch('openhands.core.cli.create_runtime')
@patch('openhands.core.cli.create_controller')
@patch('openhands.core.cli.create_memory')
@@ -227,7 +225,7 @@ async def test_run_session_with_initial_action(
mock_create_memory,
mock_create_controller,
mock_create_runtime,
mock_fetch_mcp_tools,
mock_add_mcp_tools,
mock_create_agent,
mock_display_animation,
mock_display_runtime_init,
@@ -244,9 +242,6 @@ async def test_run_session_with_initial_action(
mock_agent = AsyncMock()
mock_create_agent.return_value = mock_agent
mock_mcp_tools = []
mock_fetch_mcp_tools.return_value = mock_mcp_tools
mock_runtime = AsyncMock()
mock_runtime.event_stream = MagicMock()
mock_create_runtime.return_value = mock_runtime

View File

@@ -0,0 +1,108 @@
import json
from openhands.core.schema import ActionType, ObservationType
from openhands.events.action.mcp import MCPAction
from openhands.events.observation.mcp import MCPObservation
def test_mcp_action_creation():
    """A freshly built MCPAction exposes its inputs and the expected defaults."""
    expected_arguments = {'arg1': 'value1', 'arg2': 42}
    mcp_action = MCPAction(name='test_tool', arguments=dict(expected_arguments))

    # Values passed in by the caller
    assert mcp_action.name == 'test_tool'
    assert mcp_action.arguments == expected_arguments

    # Defaults supplied by the action class
    assert mcp_action.action == ActionType.MCP
    assert mcp_action.thought == ''
    assert mcp_action.runnable is True
    assert mcp_action.security_risk is None
def test_mcp_action_with_thought():
    """An MCPAction keeps the thought it was constructed with."""
    expected_arguments = {'arg1': 'value1', 'arg2': 42}
    mcp_action = MCPAction(
        name='test_tool',
        arguments=dict(expected_arguments),
        thought='This is a test thought',
    )

    assert mcp_action.name == 'test_tool'
    assert mcp_action.arguments == expected_arguments
    assert mcp_action.thought == 'This is a test thought'
def test_mcp_action_message():
    """The ``message`` property mentions the tool name and each argument."""
    rendered = MCPAction(
        name='test_tool', arguments={'arg1': 'value1', 'arg2': 42}
    ).message

    for fragment in ('test_tool', 'arg1', 'value1', '42'):
        assert fragment in rendered
def test_mcp_action_str_representation():
    """``str(action)`` contains the class name, thought, name and arguments."""
    mcp_action = MCPAction(
        name='test_tool',
        arguments={'arg1': 'value1', 'arg2': 42},
        thought='This is a test thought',
    )
    rendered = str(mcp_action)

    expected_fragments = (
        'MCPAction',
        'THOUGHT: This is a test thought',
        'NAME: test_tool',
        'ARGUMENTS:',
        'arg1',
        'value1',
        '42',
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_mcp_observation_creation():
    """An MCPObservation stores its content and reports the MCP observation type."""
    payload = json.dumps({'result': 'success', 'data': 'test data'})
    mcp_observation = MCPObservation(content=payload)

    assert mcp_observation.content == payload
    assert mcp_observation.observation == ObservationType.MCP
def test_mcp_observation_message():
    """The ``message`` property returns the observation content as given."""
    payload = json.dumps({'result': 'success', 'data': 'test data'})
    rendered = MCPObservation(content=payload).message

    assert rendered == payload
    for fragment in ('result', 'success', 'data', 'test data'):
        assert fragment in rendered
def test_mcp_action_with_complex_arguments():
    """Nested dict and list arguments are kept intact and show up in the message."""
    nested = {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]}
    complex_args = {
        'simple_arg': 'value',
        'number_arg': 42,
        'boolean_arg': True,
        'nested_arg': nested,
        'list_arg': ['a', 'b', 'c'],
    }
    mcp_action = MCPAction(name='complex_tool', arguments=complex_args)

    assert mcp_action.name == 'complex_tool'
    assert mcp_action.arguments == complex_args
    assert mcp_action.arguments['nested_arg']['inner_key'] == 'inner_value'
    assert mcp_action.arguments['list_arg'] == ['a', 'b', 'c']

    # Check that the message contains the complex arguments
    rendered = mcp_action.message
    for fragment in ('complex_tool', 'nested_arg', 'inner_key', 'inner_value'):
        assert fragment in rendered

View File

@@ -1,23 +1,33 @@
import pytest
from pydantic import ValidationError
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.mcp_config import (
MCPConfig,
MCPSSEServerConfig,
MCPStdioServerConfig,
)
def test_valid_sse_config():
"""Test a valid SSE configuration."""
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
config = MCPConfig(
sse_servers=[
MCPSSEServerConfig(url='http://server1:8080'),
MCPSSEServerConfig(url='http://server2:8080'),
]
)
config.validate_servers() # Should not raise any exception
def test_empty_sse_config():
"""Test SSE configuration with empty servers list."""
config = MCPConfig(mcp_servers=[])
config = MCPConfig(sse_servers=[])
config.validate_servers()
def test_invalid_sse_url():
"""Test SSE configuration with invalid URL format."""
config = MCPConfig(mcp_servers=['not_a_url'])
config = MCPConfig(sse_servers=[MCPSSEServerConfig(url='not_a_url')])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Invalid URL' in str(exc_info.value)
@@ -25,7 +35,12 @@ def test_invalid_sse_url():
def test_duplicate_sse_urls():
"""Test SSE configuration with duplicate server URLs."""
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
config = MCPConfig(
sse_servers=[
MCPSSEServerConfig(url='http://server1:8080'),
MCPSSEServerConfig(url='http://server1:8080'),
]
)
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value)
@@ -34,17 +49,18 @@ def test_duplicate_sse_urls():
def test_from_toml_section_valid():
"""Test creating config from valid TOML section."""
data = {
'mcp_servers': ['http://server1:8080'],
'sse_servers': ['http://server1:8080'],
}
result = MCPConfig.from_toml_section(data)
assert 'mcp' in result
assert result['mcp'].mcp_servers == ['http://server1:8080']
assert len(result['mcp'].sse_servers) == 1
assert result['mcp'].sse_servers[0].url == 'http://server1:8080'
def test_from_toml_section_invalid_sse():
"""Test creating config from TOML section with invalid SSE URL."""
data = {
'mcp_servers': ['not_a_url'],
'sse_servers': ['not_a_url'],
}
with pytest.raises(ValueError) as exc_info:
MCPConfig.from_toml_section(data)
@@ -54,10 +70,125 @@ def test_from_toml_section_invalid_sse():
def test_complex_urls():
"""Test SSE configuration with complex URLs."""
config = MCPConfig(
mcp_servers=[
'https://user:pass@server1:8080/path?query=1',
'wss://server2:8443/ws',
'http://subdomain.example.com:9090',
sse_servers=[
MCPSSEServerConfig(url='https://user:pass@server1:8080/path?query=1'),
MCPSSEServerConfig(url='wss://server2:8443/ws'),
MCPSSEServerConfig(url='http://subdomain.example.com:9090'),
]
)
config.validate_servers() # Should not raise any exception
def test_mcp_sse_server_config_with_api_key():
    """An explicitly supplied API key is stored next to the URL."""
    sse_server = MCPSSEServerConfig(api_key='test-api-key', url='http://server1:8080')

    assert sse_server.url == 'http://server1:8080'
    assert sse_server.api_key == 'test-api-key'
def test_mcp_sse_server_config_without_api_key():
    """Omitting the API key leaves it as ``None``."""
    sse_server = MCPSSEServerConfig(url='http://server1:8080')

    assert sse_server.api_key is None
    assert sse_server.url == 'http://server1:8080'
def test_mcp_stdio_server_config_basic():
    """With only name and command given, args and env default to empty."""
    stdio_server = MCPStdioServerConfig(command='python', name='test-server')

    assert stdio_server.name == 'test-server'
    assert stdio_server.command == 'python'
    # Optional fields fall back to empty containers
    assert stdio_server.args == []
    assert stdio_server.env == {}
def test_mcp_stdio_server_config_with_args_and_env():
    """Explicit args and env are stored as given."""
    stdio_server = MCPStdioServerConfig(
        name='test-server',
        command='python',
        args=['-m', 'server'],
        env={'DEBUG': 'true', 'PORT': '8080'},
    )

    assert stdio_server.name == 'test-server'
    assert stdio_server.command == 'python'
    assert stdio_server.args == ['-m', 'server']
    assert stdio_server.env == {'DEBUG': 'true', 'PORT': '8080'}
def test_mcp_config_with_stdio_servers():
    """MCPConfig keeps a stdio server entry with all of its fields."""
    mcp_config = MCPConfig(
        stdio_servers=[
            MCPStdioServerConfig(
                name='test-server',
                command='python',
                args=['-m', 'server'],
                env={'DEBUG': 'true'},
            )
        ]
    )

    assert len(mcp_config.stdio_servers) == 1
    stored = mcp_config.stdio_servers[0]
    assert stored.name == 'test-server'
    assert stored.command == 'python'
    assert stored.args == ['-m', 'server']
    assert stored.env == {'DEBUG': 'true'}
def test_from_toml_section_with_stdio_servers():
    """``from_toml_section`` parses SSE URLs and stdio server tables together."""
    stdio_entry = {
        'name': 'test-server',
        'command': 'python',
        'args': ['-m', 'server'],
        'env': {'DEBUG': 'true'},
    }
    section = {
        'sse_servers': ['http://server1:8080'],
        'stdio_servers': [stdio_entry],
    }

    parsed = MCPConfig.from_toml_section(section)
    assert 'mcp' in parsed

    # SSE side: the plain URL string ends up as a server entry with that URL
    assert len(parsed['mcp'].sse_servers) == 1
    assert parsed['mcp'].sse_servers[0].url == 'http://server1:8080'

    # Stdio side: every field of the table is carried over
    assert len(parsed['mcp'].stdio_servers) == 1
    assert parsed['mcp'].stdio_servers[0].name == 'test-server'
    assert parsed['mcp'].stdio_servers[0].command == 'python'
    assert parsed['mcp'].stdio_servers[0].args == ['-m', 'server']
    assert parsed['mcp'].stdio_servers[0].env == {'DEBUG': 'true'}
def test_mcp_config_with_both_server_types():
    """SSE and stdio servers can be configured side by side."""
    stdio_server = MCPStdioServerConfig(
        name='test-server',
        command='python',
        args=['-m', 'server'],
        env={'DEBUG': 'true'},
    )
    sse_server = MCPSSEServerConfig(url='http://server1:8080', api_key='test-api-key')
    mcp_config = MCPConfig(stdio_servers=[stdio_server], sse_servers=[sse_server])

    assert len(mcp_config.sse_servers) == 1
    assert mcp_config.sse_servers[0].url == 'http://server1:8080'
    assert mcp_config.sse_servers[0].api_key == 'test-api-key'

    assert len(mcp_config.stdio_servers) == 1
    assert mcp_config.stdio_servers[0].name == 'test-server'
    assert mcp_config.stdio_servers[0].command == 'python'
def test_mcp_config_model_validation_error():
    """Server configs built without their required fields fail validation."""
    # MCPSSEServerConfig requires 'url'; MCPStdioServerConfig requires
    # 'name' and 'command'.
    for server_model in (MCPSSEServerConfig, MCPStdioServerConfig):
        with pytest.raises(ValidationError):
            server_model()
def test_mcp_config_extra_fields_forbidden():
    """MCPConfig rejects keyword arguments it does not declare."""
    unknown_kwargs = {'extra_field': 'value'}
    with pytest.raises(ValidationError):
        MCPConfig(**unknown_kwargs)
    # Note: The nested models don't have 'extra': 'forbid' set, so they allow extra fields
    # We're only testing the main MCPConfig class here

View File

@@ -3,7 +3,7 @@ from unittest import mock
import pytest
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config
@@ -24,10 +24,13 @@ async def test_sse_connection_timeout():
# Mock the MCPClient constructor to return our mock
with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client):
# Create a list of server URLs to test
servers = ['http://server1:8080', 'http://server2:8080']
servers = [
MCPSSEServerConfig(url='http://server1:8080'),
MCPSSEServerConfig(url='http://server2:8080'),
]
# Call create_mcp_clients with the server URLs
clients = await create_mcp_clients(mcp_servers=servers)
clients = await create_mcp_clients(sse_servers=servers)
# Verify that no clients were successfully connected
assert len(clients) == 0
@@ -46,7 +49,7 @@ async def test_fetch_mcp_tools_with_timeout():
mock_config = mock.MagicMock(spec=MCPConfig)
# Configure the mock config
mock_config.mcp_servers = ['http://server1:8080']
mock_config.sse_servers = ['http://server1:8080']
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
@@ -64,7 +67,7 @@ async def test_mixed_connection_results():
mock_config = mock.MagicMock(spec=MCPConfig)
# Configure the mock config
mock_config.mcp_servers = ['http://server1:8080', 'http://server2:8080']
mock_config.sse_servers = ['http://server1:8080', 'http://server2:8080']
# Create a successful client
successful_client = mock.MagicMock(spec=MCPClient)

View File

@@ -0,0 +1,165 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Import the module, not the functions directly to avoid circular imports
import openhands.mcp.utils
from openhands.core.config.mcp_config import MCPSSEServerConfig
from openhands.events.action.mcp import MCPAction
from openhands.events.observation.mcp import MCPObservation
@pytest.mark.asyncio
async def test_create_mcp_clients_empty():
    """An empty server list produces an empty client list."""
    no_servers = []
    created = await openhands.mcp.utils.create_mcp_clients(no_servers)
    assert created == []
@pytest.mark.asyncio
@patch('openhands.mcp.utils.MCPClient')
async def test_create_mcp_clients_success(mock_mcp_client):
    """One client is created per SSE server and connected with its URL and API key."""
    # Every MCPClient() call hands back the same stub instance
    client_stub = AsyncMock()
    client_stub.connect_sse = AsyncMock()
    mock_mcp_client.return_value = client_stub

    created = await openhands.mcp.utils.create_mcp_clients(
        [
            MCPSSEServerConfig(url='http://server1:8080'),
            MCPSSEServerConfig(url='http://server2:8080', api_key='test-key'),
        ]
    )

    assert len(created) == 2
    assert mock_mcp_client.call_count == 2

    # Check that connect_sse was called with correct parameters
    expected_connections = (
        ('http://server1:8080', None),
        ('http://server2:8080', 'test-key'),
    )
    for server_url, server_key in expected_connections:
        client_stub.connect_sse.assert_any_call(server_url, api_key=server_key)
@pytest.mark.asyncio
@patch('openhands.mcp.utils.MCPClient')
async def test_create_mcp_clients_connection_failure(mock_mcp_client):
    """A server whose connection fails is dropped and its client disconnected."""
    client_stub = AsyncMock()
    # First connection succeeds, second fails
    client_stub.connect_sse.side_effect = [None, Exception('Connection failed')]
    client_stub.disconnect = AsyncMock()
    mock_mcp_client.return_value = client_stub

    created = await openhands.mcp.utils.create_mcp_clients(
        [
            MCPSSEServerConfig(url='http://server1:8080'),
            MCPSSEServerConfig(url='http://server2:8080'),
        ]
    )

    # Verify only one client was successfully created
    assert len(created) == 1
    assert client_stub.disconnect.call_count == 1
def test_convert_mcp_clients_to_tools_empty():
    """Both ``None`` and an empty client list convert to an empty tool list."""
    for no_clients in (None, []):
        converted = openhands.mcp.utils.convert_mcp_clients_to_tools(no_clients)
        assert converted == []
def test_convert_mcp_clients_to_tools():
    """Check that tools from all clients are returned in client, then tool, order."""

    def param_for(tool_name):
        # Build a fresh dict each time so stubs and expectations never share state.
        return {'function': {'name': tool_name}}

    tool_names = ('tool1', 'tool2', 'tool3')

    # One stub tool per name, each reporting its own parameter dict.
    stub_tools = []
    for tool_name in tool_names:
        stub_tool = MagicMock()
        stub_tool.to_param.return_value = param_for(tool_name)
        stub_tools.append(stub_tool)

    # The first client owns two tools, the second owns the remaining one.
    first_client = MagicMock()
    first_client.tools = stub_tools[:2]
    second_client = MagicMock()
    second_client.tools = stub_tools[2:]

    tools = openhands.mcp.utils.convert_mcp_clients_to_tools(
        [first_client, second_client]
    )

    assert len(tools) == 3
    for position, tool_name in enumerate(tool_names):
        assert tools[position] == param_for(tool_name)
@pytest.mark.asyncio
async def test_call_tool_mcp_no_clients():
    """Check that dispatching an action with no MCP clients raises ValueError."""
    requested_action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
    no_clients = []

    # The error message must identify the missing-clients condition.
    with pytest.raises(ValueError, match='No MCP clients found'):
        await openhands.mcp.utils.call_tool_mcp(no_clients, requested_action)
@pytest.mark.asyncio
async def test_call_tool_mcp_no_matching_client():
    """Test calling MCP tool when no client exposes the requested tool."""
    # Build a client whose only tool is named 'other_tool'. The name is
    # assigned after construction because MagicMock(name=...) only labels the
    # mock for its repr; it does not make `.name` return that string, so the
    # tool would otherwise carry an auto-generated child mock as its name.
    other_tool = MagicMock()
    other_tool.name = 'other_tool'
    mock_client = MagicMock()
    mock_client.tools = [other_tool]

    action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})

    # 'test_tool' is not offered by any client, so the lookup must fail loudly.
    with pytest.raises(ValueError, match='No matching MCP agent found for tool name'):
        await openhands.mcp.utils.call_tool_mcp([mock_client], action)
@pytest.mark.asyncio
async def test_call_tool_mcp_success():
    """Check that a matching tool call is forwarded and its result wrapped."""
    # `name` has to be assigned after construction: passed as a constructor
    # keyword it would only label the mock instead of setting the attribute.
    stub_tool = MagicMock()
    stub_tool.name = 'test_tool'

    # The awaited call_tool yields an object whose model_dump() is the payload.
    stub_response = MagicMock(**{'model_dump.return_value': {'result': 'success'}})
    stub_client = MagicMock(
        tools=[stub_tool],
        call_tool=AsyncMock(return_value=stub_response),
    )

    observation = await openhands.mcp.utils.call_tool_mcp(
        [stub_client],
        MCPAction(name='test_tool', arguments={'arg1': 'value1'}),
    )

    # The observation carries the dumped response as JSON text.
    assert isinstance(observation, MCPObservation)
    assert json.loads(observation.content) == {'result': 'success'}
    # The tool name and arguments are passed through to the client unchanged.
    stub_client.call_tool.assert_called_once_with('test_tool', {'arg1': 'value1'})

View File

@@ -21,7 +21,9 @@ from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.storage.memory import InMemoryFileStore
@@ -77,7 +79,7 @@ def mock_agent():
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream
# Mock Memory method to raise an exception
@@ -106,7 +108,7 @@ async def test_memory_on_workspace_context_recall_exception_handling(
):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=Runtime)
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream
# Mock Memory._on_workspace_context_recall to raise an exception