mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
feat(MCP): MCP refactor, support stdio, and running MCP server in runtime (#7911)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
96
docs/modules/usage/mcp.md
Normal file
96
docs/modules/usage/mcp.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# Model Context Protocol (MCP)
|
||||
|
||||
:::note
|
||||
This page outlines how to configure and use the Model Context Protocol (MCP) in OpenHands, allowing you to extend the agent's capabilities with custom tools.
|
||||
:::
|
||||
|
||||
## Overview
|
||||
|
||||
Model Context Protocol (MCP) is a mechanism that allows OpenHands to communicate with external tool servers. These servers can provide additional functionality to the agent, such as specialized data processing, external API access, or custom tools. MCP is based on the open standard defined at [modelcontextprotocol.io](https://modelcontextprotocol.io).
|
||||
|
||||
## Configuration
|
||||
|
||||
MCP configuration is defined in the `[mcp]` section of your `config.toml` file.
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```toml
|
||||
[mcp]
|
||||
# SSE Servers - External servers that communicate via Server-Sent Events
|
||||
sse_servers = [
|
||||
# Basic SSE server with just a URL
|
||||
"http://example.com:8080/mcp",
|
||||
|
||||
# SSE server with API key authentication
|
||||
{url="https://secure-example.com/mcp", api_key="your-api-key"}
|
||||
]
|
||||
|
||||
# Stdio Servers - Local processes that communicate via standard input/output
|
||||
stdio_servers = [
|
||||
# Basic stdio server
|
||||
{name="fetch", command="uvx", args=["mcp-server-fetch"]},
|
||||
|
||||
# Stdio server with environment variables
|
||||
{
|
||||
name="data-processor",
|
||||
command="python",
|
||||
args=["-m", "my_mcp_server"],
|
||||
env={
|
||||
"DEBUG": "true",
|
||||
"PORT": "8080"
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### SSE Servers
|
||||
|
||||
SSE servers are configured using either a string URL or an object with the following properties:
|
||||
|
||||
- `url` (required)
|
||||
- Type: `str`
|
||||
- Description: The URL of the SSE server
|
||||
|
||||
- `api_key` (optional)
|
||||
- Type: `str`
|
||||
- Default: `None`
|
||||
- Description: API key for authentication with the SSE server
|
||||
|
||||
### Stdio Servers
|
||||
|
||||
Stdio servers are configured using an object with the following properties:
|
||||
|
||||
- `name` (required)
|
||||
- Type: `str`
|
||||
- Description: A unique name for the server
|
||||
|
||||
- `command` (required)
|
||||
- Type: `str`
|
||||
- Description: The command to run the server
|
||||
|
||||
- `args` (optional)
|
||||
- Type: `list of str`
|
||||
- Default: `[]`
|
||||
- Description: Command-line arguments to pass to the server
|
||||
|
||||
- `env` (optional)
|
||||
- Type: `dict of str to str`
|
||||
- Default: `{}`
|
||||
- Description: Environment variables to set for the server process
|
||||
|
||||
## How MCP Works
|
||||
|
||||
When OpenHands starts, it:
|
||||
|
||||
1. Reads the MCP configuration from `config.toml`
|
||||
2. Connects to any configured SSE servers
|
||||
3. Starts any configured stdio servers
|
||||
4. Registers the tools provided by these servers with the agent
|
||||
|
||||
The agent can then use these tools just like any built-in tool. When the agent calls an MCP tool:
|
||||
|
||||
1. OpenHands routes the call to the appropriate MCP server
|
||||
2. The server processes the request and returns a response
|
||||
3. OpenHands converts the response to an observation and presents it to the agent
|
||||
@@ -177,38 +177,30 @@ class CodeActAgent(Agent):
|
||||
}
|
||||
params['tools'] = self.tools
|
||||
|
||||
if self.mcp_tools:
|
||||
# Only add tools with unique names
|
||||
existing_names = {tool['function']['name'] for tool in params['tools']}
|
||||
unique_mcp_tools = [
|
||||
tool
|
||||
for tool in self.mcp_tools
|
||||
if tool['function']['name'] not in existing_names
|
||||
]
|
||||
|
||||
if self.llm.config.model == 'gemini-2.5-pro-preview-03-25':
|
||||
logger.info(
|
||||
f'Removing the default fields from the MCP tools for {self.llm.config.model} '
|
||||
"since it doesn't support them and the request would crash."
|
||||
)
|
||||
# prevent mutation of input tools
|
||||
unique_mcp_tools = copy.deepcopy(unique_mcp_tools)
|
||||
# Strip off default fields that cause errors with gemini-preview
|
||||
for tool in unique_mcp_tools:
|
||||
if 'function' in tool and 'parameters' in tool['function']:
|
||||
if 'properties' in tool['function']['parameters']:
|
||||
for prop_name, prop in tool['function']['parameters'][
|
||||
'properties'
|
||||
].items():
|
||||
if 'default' in prop:
|
||||
del prop['default']
|
||||
|
||||
params['tools'] += unique_mcp_tools
|
||||
# Special handling for Gemini model which doesn't support default fields
|
||||
if self.llm.config.model == 'gemini-2.5-pro-preview-03-25':
|
||||
logger.info(
|
||||
f'Removing the default fields from tools for {self.llm.config.model} '
|
||||
"since it doesn't support them and the request would crash."
|
||||
)
|
||||
# prevent mutation of input tools
|
||||
params['tools'] = copy.deepcopy(params['tools'])
|
||||
# Strip off default fields that cause errors with gemini-preview
|
||||
for tool in params['tools']:
|
||||
if 'function' in tool and 'parameters' in tool['function']:
|
||||
if 'properties' in tool['function']['parameters']:
|
||||
for prop_name, prop in tool['function']['parameters'][
|
||||
'properties'
|
||||
].items():
|
||||
if 'default' in prop:
|
||||
del prop['default']
|
||||
# log to litellm proxy if possible
|
||||
params['extra_body'] = {'metadata': state.to_llm_metadata(agent_name=self.name)}
|
||||
response = self.llm.completion(**params)
|
||||
logger.debug(f'Response from LLM: {response}')
|
||||
actions = self.response_to_actions_fn(response)
|
||||
actions = self.response_to_actions_fn(
|
||||
response, mcp_tool_names=list(self.mcp_tools.keys())
|
||||
)
|
||||
logger.debug(f'Actions after response_to_actions: {actions}')
|
||||
for action in actions:
|
||||
self.pending_actions.append(action)
|
||||
|
||||
@@ -37,10 +37,9 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.mcp import MCPClientTool
|
||||
|
||||
|
||||
def combine_thought(action: Action, thought: str) -> Action:
|
||||
@@ -53,7 +52,9 @@ def combine_thought(action: Action, thought: str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
def response_to_actions(
|
||||
response: ModelResponse, mcp_tool_names: list[str] | None = None
|
||||
) -> list[Action]:
|
||||
actions: list[Action] = []
|
||||
assert len(response.choices) == 1, 'Only one choice is supported for now'
|
||||
choice = response.choices[0]
|
||||
@@ -195,12 +196,12 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
action = BrowseURLAction(url=arguments['url'])
|
||||
|
||||
# ================================================
|
||||
# McpAction (MCP)
|
||||
# MCPAction (MCP)
|
||||
# ================================================
|
||||
elif tool_call.function.name.endswith(MCPClientTool.postfix()):
|
||||
action = McpAction(
|
||||
name=tool_call.function.name.removesuffix(MCPClientTool.postfix()),
|
||||
arguments=tool_call.function.arguments,
|
||||
elif mcp_tool_names and tool_call.function.name in mcp_tool_names:
|
||||
action = MCPAction(
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
else:
|
||||
raise FunctionCallNotExistsError(
|
||||
|
||||
@@ -36,6 +36,7 @@ from openhands.events.action import (
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileReadAction,
|
||||
MCPAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import FileReadSource
|
||||
@@ -102,7 +103,9 @@ def glob_to_cmdrun(pattern: str, path: str = '.') -> str:
|
||||
return echo_cmd + complete_cmd
|
||||
|
||||
|
||||
def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
def response_to_actions(
|
||||
response: ModelResponse, mcp_tool_names: list[str] | None = None
|
||||
) -> list[Action]:
|
||||
actions: list[Action] = []
|
||||
assert len(response.choices) == 1, 'Only one choice is supported for now'
|
||||
choice = response.choices[0]
|
||||
@@ -198,6 +201,15 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
)
|
||||
action = BrowseURLAction(url=arguments['url'])
|
||||
|
||||
# ================================================
|
||||
# MCPAction (MCP)
|
||||
# ================================================
|
||||
elif mcp_tool_names and tool_call.function.name in mcp_tool_names:
|
||||
action = MCPAction(
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
else:
|
||||
raise FunctionCallNotExistsError(
|
||||
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'
|
||||
|
||||
@@ -8,6 +8,8 @@ if TYPE_CHECKING:
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from litellm import ChatCompletionToolParam
|
||||
|
||||
from openhands.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentNotRegisteredError,
|
||||
@@ -42,7 +44,7 @@ class Agent(ABC):
|
||||
self.config = config
|
||||
self._complete = False
|
||||
self.prompt_manager: 'PromptManager' | None = None
|
||||
self.mcp_tools: list[dict] = []
|
||||
self.mcp_tools: dict[str, ChatCompletionToolParam] = {}
|
||||
self.tools: list = []
|
||||
|
||||
def get_system_message(self) -> 'SystemMessageAction | None':
|
||||
@@ -160,4 +162,18 @@ class Agent(ABC):
|
||||
Args:
|
||||
- mcp_tools (list[dict]): The list of MCP tools.
|
||||
"""
|
||||
self.mcp_tools = mcp_tools
|
||||
logger.info(
|
||||
f"Setting {len(mcp_tools)} MCP tools for agent {self.name}: {[tool['function']['name'] for tool in mcp_tools]}"
|
||||
)
|
||||
for tool in mcp_tools:
|
||||
_tool = ChatCompletionToolParam(**tool)
|
||||
if _tool['function']['name'] in self.mcp_tools:
|
||||
logger.warning(
|
||||
f"Tool {_tool['function']['name']} already exists, skipping"
|
||||
)
|
||||
continue
|
||||
self.mcp_tools[_tool['function']['name']] = _tool
|
||||
self.tools.append(_tool)
|
||||
logger.info(
|
||||
f"Tools updated for agent {self.name}, total {len(self.tools)}: {[tool['function']['name'] for tool in self.tools]}"
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
from openhands.io import read_task
|
||||
from openhands.mcp import fetch_mcp_tools_from_config
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
|
||||
LLMSummarizingCondenserConfig,
|
||||
)
|
||||
@@ -112,8 +112,6 @@ async def run_session(
|
||||
)
|
||||
|
||||
agent = create_agent(config)
|
||||
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
runtime = create_runtime(
|
||||
config,
|
||||
sid=sid,
|
||||
@@ -209,6 +207,7 @@ async def run_session(
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
|
||||
await runtime.connect()
|
||||
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
|
||||
|
||||
# Initialize repository if needed
|
||||
repo_directory = None
|
||||
|
||||
@@ -7,6 +7,7 @@ from openhands.core.config.config_utils import (
|
||||
)
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.core.config.utils import (
|
||||
@@ -26,6 +27,7 @@ __all__ = [
|
||||
'OH_MAX_ITERATIONS',
|
||||
'AgentConfig',
|
||||
'AppConfig',
|
||||
'MCPConfig',
|
||||
'LLMConfig',
|
||||
'SandboxConfig',
|
||||
'SecurityConfig',
|
||||
|
||||
@@ -1,28 +1,59 @@
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
class MCPSSEServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server.
|
||||
|
||||
Attributes:
|
||||
url: The server URL
|
||||
api_key: Optional API key for authentication
|
||||
"""
|
||||
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
class MCPStdioServerConfig(BaseModel):
|
||||
"""Configuration for a MCP server that uses stdio.
|
||||
|
||||
Attributes:
|
||||
name: The name of the server
|
||||
command: The command to run the server
|
||||
args: The arguments to pass to the server
|
||||
env: The environment variables to set for the server
|
||||
"""
|
||||
|
||||
name: str
|
||||
command: str
|
||||
args: list[str] = Field(default_factory=list)
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP (Message Control Protocol) settings.
|
||||
|
||||
Attributes:
|
||||
mcp_servers: List of MCP SSE (Server-Sent Events) server URLs.
|
||||
sse_servers: List of MCP SSE server configs
|
||||
stdio_servers: List of MCP stdio server configs. These servers will be added to the MCP Router running inside runtime container.
|
||||
"""
|
||||
|
||||
mcp_servers: List[str] = Field(default_factory=list)
|
||||
sse_servers: list[MCPSSEServerConfig] = Field(default_factory=list)
|
||||
stdio_servers: list[MCPStdioServerConfig] = Field(default_factory=list)
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
def validate_servers(self) -> None:
|
||||
"""Validate that server URLs are valid and unique."""
|
||||
urls = [server.url for server in self.sse_servers]
|
||||
|
||||
# Check for duplicate server URLs
|
||||
if len(set(self.mcp_servers)) != len(self.mcp_servers):
|
||||
if len(set(urls)) != len(urls):
|
||||
raise ValueError('Duplicate MCP server URLs are not allowed')
|
||||
|
||||
# Validate URLs
|
||||
for url in self.mcp_servers:
|
||||
for url in urls:
|
||||
try:
|
||||
result = urlparse(url)
|
||||
if not all([result.scheme, result.netloc]):
|
||||
@@ -44,11 +75,32 @@ class MCPConfig(BaseModel):
|
||||
mcp_mapping: dict[str, MCPConfig] = {}
|
||||
|
||||
try:
|
||||
# Convert all entries in sse_servers to MCPSSEServerConfig objects
|
||||
if 'sse_servers' in data:
|
||||
servers = []
|
||||
for server in data['sse_servers']:
|
||||
if isinstance(server, dict):
|
||||
servers.append(MCPSSEServerConfig(**server))
|
||||
else:
|
||||
# Convert string URLs to MCPSSEServerConfig objects with no API key
|
||||
servers.append(MCPSSEServerConfig(url=server))
|
||||
data['sse_servers'] = servers
|
||||
|
||||
# Convert all entries in stdio_servers to MCPStdioServerConfig objects
|
||||
if 'stdio_servers' in data:
|
||||
servers = []
|
||||
for server in data['stdio_servers']:
|
||||
servers.append(MCPStdioServerConfig(**server))
|
||||
data['stdio_servers'] = servers
|
||||
|
||||
# Create SSE config if present
|
||||
mcp_config = MCPConfig.model_validate(data)
|
||||
mcp_config.validate_servers()
|
||||
# Create the main MCP config
|
||||
mcp_mapping['mcp'] = cls(mcp_servers=mcp_config.mcp_servers)
|
||||
mcp_mapping['mcp'] = cls(
|
||||
sse_servers=mcp_config.sse_servers,
|
||||
stdio_servers=mcp_config.stdio_servers,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f'Invalid MCP configuration: {e}')
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.mcp import fetch_mcp_tools_from_config
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
@@ -96,8 +96,6 @@ async def run_controller(
|
||||
|
||||
if agent is None:
|
||||
agent = create_agent(config)
|
||||
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
# when the runtime is created, it will be connected and clone the selected repository
|
||||
repo_directory = None
|
||||
@@ -118,6 +116,8 @@ async def run_controller(
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
|
||||
@@ -15,7 +15,7 @@ from openhands.events.action.files import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
__all__ = [
|
||||
@@ -37,5 +37,5 @@ __all__ = [
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
'McpAction',
|
||||
'MCPAction',
|
||||
]
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class McpAction(Action):
|
||||
class MCPAction(Action):
|
||||
name: str
|
||||
arguments: str | None = None
|
||||
arguments: dict[str, Any] = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.MCP
|
||||
runnable: ClassVar[bool] = True
|
||||
@@ -24,7 +24,7 @@ class McpAction(Action):
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**McpAction**\n'
|
||||
ret = '**MCPAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'NAME: {self.name}\n'
|
||||
|
||||
@@ -21,6 +21,7 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
|
||||
@@ -22,7 +22,7 @@ from openhands.events.action.files import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
actions = (
|
||||
@@ -43,7 +43,7 @@ actions = (
|
||||
MessageAction,
|
||||
SystemMessageAction,
|
||||
CondensationAction,
|
||||
McpAction,
|
||||
MCPAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.tool import (
|
||||
BaseTool,
|
||||
MCPClientTool,
|
||||
)
|
||||
from openhands.mcp.tool import MCPClientTool
|
||||
from openhands.mcp.utils import (
|
||||
add_mcp_tools_to_agent,
|
||||
call_tool_mcp,
|
||||
convert_mcp_clients_to_tools,
|
||||
create_mcp_clients,
|
||||
@@ -14,8 +12,8 @@ __all__ = [
|
||||
'MCPClient',
|
||||
'convert_mcp_clients_to_tools',
|
||||
'create_mcp_clients',
|
||||
'BaseTool',
|
||||
'MCPClientTool',
|
||||
'fetch_mcp_tools_from_config',
|
||||
'call_tool_mcp',
|
||||
'add_mcp_tools_to_agent',
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@ from mcp.client.sse import sse_client
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.mcp.tool import BaseTool, MCPClientTool
|
||||
from openhands.mcp.tool import MCPClientTool
|
||||
|
||||
|
||||
class MCPClient(BaseModel):
|
||||
@@ -18,13 +18,15 @@ class MCPClient(BaseModel):
|
||||
session: Optional[ClientSession] = None
|
||||
exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
description: str = 'MCP client tools for server interaction'
|
||||
tools: List[BaseTool] = Field(default_factory=list)
|
||||
tool_map: Dict[str, BaseTool] = Field(default_factory=dict)
|
||||
tools: List[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def connect_sse(self, server_url: str, timeout: float = 30.0) -> None:
|
||||
async def connect_sse(
|
||||
self, server_url: str, api_key: str | None = None, timeout: float = 30.0
|
||||
) -> None:
|
||||
"""Connect to an MCP server using SSE transport.
|
||||
|
||||
Args:
|
||||
@@ -41,7 +43,8 @@ class MCPClient(BaseModel):
|
||||
async def connect_with_timeout():
|
||||
streams_context = sse_client(
|
||||
url=server_url,
|
||||
timeout=timeout, # Pass the timeout to sse_client
|
||||
headers={'Authorization': f'Bearer {api_key}'} if api_key else None,
|
||||
timeout=timeout,
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(streams_context)
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
@@ -92,7 +95,10 @@ class MCPClient(BaseModel):
|
||||
"""Call a tool on the MCP server."""
|
||||
if tool_name not in self.tool_map:
|
||||
raise ValueError(f'Tool {tool_name} not found.')
|
||||
return await self.tool_map[tool_name].execute(**args)
|
||||
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
|
||||
if not self.session:
|
||||
raise RuntimeError('Client session is not available.')
|
||||
return await self.session.call_tool(name=tool_name, arguments=args)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server and clean up resources."""
|
||||
|
||||
@@ -1,54 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolResult, TextContent, Tool
|
||||
from mcp.types import Tool
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool):
|
||||
@classmethod
|
||||
def postfix(cls) -> str:
|
||||
return '_mcp_tool_call'
|
||||
class MCPClientTool(Tool):
|
||||
"""
|
||||
Represents a tool proxy that can be called on the MCP server from the client side.
|
||||
|
||||
This version doesn't store a session reference, as sessions are created on-demand
|
||||
by the MCPClient for each operation.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> CallToolResult:
|
||||
"""Execute the tool with given parameters."""
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
"""Convert tool to function call format."""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': self.name + self.postfix(),
|
||||
'name': self.name,
|
||||
'description': self.description,
|
||||
'parameters': self.inputSchema,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MCPClientTool(BaseTool):
|
||||
"""Represents a tool proxy that can be called on the MCP server from the client side."""
|
||||
|
||||
session: Optional[ClientSession] = None
|
||||
|
||||
async def execute(self, **kwargs) -> CallToolResult:
|
||||
"""Execute the tool by making a remote call to the MCP server."""
|
||||
if not self.session:
|
||||
return CallToolResult(
|
||||
content=[TextContent(text='Not connected to MCP server', type='text')],
|
||||
isError=True,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.session.call_tool(self.name, kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
return CallToolResult(
|
||||
content=[
|
||||
TextContent(text=f'Error executing tool: {str(e)}', type='text')
|
||||
],
|
||||
isError=True,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
if TYPE_CHECKING:
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]:
|
||||
@@ -38,19 +43,19 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
|
||||
|
||||
|
||||
async def create_mcp_clients(
|
||||
mcp_servers: list[str],
|
||||
sse_servers: list[MCPSSEServerConfig],
|
||||
) -> list[MCPClient]:
|
||||
mcp_clients: list[MCPClient] = []
|
||||
# Initialize SSE connections
|
||||
if mcp_servers:
|
||||
for server_url in mcp_servers:
|
||||
if sse_servers:
|
||||
for server_url in sse_servers:
|
||||
logger.info(
|
||||
f'Initializing MCP agent for {server_url} with SSE connection...'
|
||||
)
|
||||
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_sse(server_url)
|
||||
await client.connect_sse(server_url.url, api_key=server_url.api_key)
|
||||
# Only add the client to the list after a successful connection
|
||||
mcp_clients.append(client)
|
||||
logger.info(f'Connected to MCP server {server_url} via SSE')
|
||||
@@ -77,14 +82,16 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
|
||||
mcp_tools = []
|
||||
try:
|
||||
logger.debug(f'Creating MCP clients with config: {mcp_config}')
|
||||
# Create clients - this will fetch tools but not maintain active connections
|
||||
mcp_clients = await create_mcp_clients(
|
||||
mcp_config.mcp_servers,
|
||||
mcp_config.sse_servers,
|
||||
)
|
||||
|
||||
if not mcp_clients:
|
||||
logger.debug('No MCP clients were successfully connected')
|
||||
return []
|
||||
|
||||
# Convert tools to the format expected by the agent
|
||||
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
|
||||
|
||||
# Always disconnect clients to clean up resources
|
||||
@@ -93,6 +100,7 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
|
||||
await mcp_client.disconnect()
|
||||
except Exception as disconnect_error:
|
||||
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error fetching MCP tools: {str(e)}')
|
||||
return []
|
||||
@@ -101,13 +109,13 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
|
||||
return mcp_tools
|
||||
|
||||
|
||||
async def call_tool_mcp(mcp_clients: list[MCPClient], action: McpAction) -> Observation:
|
||||
async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Observation:
|
||||
"""
|
||||
Call a tool on an MCP server and return the observation.
|
||||
|
||||
Args:
|
||||
mcp_clients: The list of MCP clients to execute the action on
|
||||
action: The MCP action to execute
|
||||
sse_mcp_servers: List of SSE MCP server URLs
|
||||
|
||||
Returns:
|
||||
The observation from the MCP server
|
||||
@@ -116,20 +124,55 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: McpAction) -> Obse
|
||||
raise ValueError('No MCP clients found')
|
||||
|
||||
logger.debug(f'MCP action received: {action}')
|
||||
# Find the MCP agent that has the matching tool name
|
||||
|
||||
# Find the MCP client that has the matching tool name
|
||||
matching_client = None
|
||||
logger.debug(f'MCP clients: {mcp_clients}')
|
||||
logger.debug(f'MCP action name: {action.name}')
|
||||
|
||||
for client in mcp_clients:
|
||||
logger.debug(f'MCP client tools: {client.tools}')
|
||||
if action.name in [tool.name for tool in client.tools]:
|
||||
matching_client = client
|
||||
break
|
||||
|
||||
if matching_client is None:
|
||||
raise ValueError(f'No matching MCP agent found for tool name: {action.name}')
|
||||
|
||||
logger.debug(f'Matching client: {matching_client}')
|
||||
args_dict = json.loads(action.arguments) if action.arguments else {}
|
||||
response = await matching_client.call_tool(action.name, args_dict)
|
||||
|
||||
# Call the tool - this will create a new connection internally
|
||||
response = await matching_client.call_tool(action.name, action.arguments)
|
||||
logger.debug(f'MCP response: {response}')
|
||||
|
||||
return MCPObservation(content=f'MCP result:{response.model_dump(mode="json")}')
|
||||
return MCPObservation(content=json.dumps(response.model_dump(mode='json')))
|
||||
|
||||
|
||||
async def add_mcp_tools_to_agent(
|
||||
agent: 'Agent', runtime: Runtime, mcp_config: MCPConfig
|
||||
):
|
||||
"""
|
||||
Add MCP tools to an agent.
|
||||
"""
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient, # inline import to avoid circular import
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
runtime, ActionExecutionClient
|
||||
), 'Runtime must be an instance of ActionExecutionClient'
|
||||
assert (
|
||||
runtime.runtime_initialized
|
||||
), 'Runtime must be initialized before adding MCP tools'
|
||||
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = runtime.get_updated_mcp_config()
|
||||
# Fetch the MCP tools
|
||||
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(mcp_tools)} MCP tools: {[tool['function']['name'] for tool in mcp_tools]}"
|
||||
)
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
@@ -19,7 +19,7 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation import (
|
||||
@@ -184,7 +184,7 @@ class ConversationMemory:
|
||||
- BrowseInteractiveAction: For browsing the web
|
||||
- AgentFinishAction: For ending the interaction
|
||||
- MessageAction: For sending messages
|
||||
- McpAction: For interacting with the MCP server
|
||||
- MCPAction: For interacting with the MCP server
|
||||
pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages.
|
||||
Used in function calling mode to track tool calls that are waiting for their results.
|
||||
|
||||
@@ -210,7 +210,7 @@ class ConversationMemory:
|
||||
FileReadAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
McpAction,
|
||||
MCPAction,
|
||||
),
|
||||
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
|
||||
tool_metadata = action.tool_call_metadata
|
||||
|
||||
@@ -8,6 +8,8 @@ NOTE: this will be executed inside the docker sandbox.
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
@@ -23,6 +25,8 @@ from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from mcpm import MCPRouter, RouterConfig
|
||||
from mcpm.router.router import logger as mcp_router_logger
|
||||
from openhands_aci.editor.editor import OHEditor
|
||||
from openhands_aci.editor.exceptions import ToolError
|
||||
from openhands_aci.editor.results import ToolResult
|
||||
@@ -68,6 +72,9 @@ from openhands.runtime.utils.runtime_init import init_user_and_working_directory
|
||||
from openhands.runtime.utils.system_stats import get_system_stats
|
||||
from openhands.utils.async_utils import call_sync_from_async, wait_all
|
||||
|
||||
# Set MCP router logger to the same level as the main logger
|
||||
mcp_router_logger.setLevel(logger.getEffectiveLevel())
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
|
||||
action: dict
|
||||
@@ -572,10 +579,15 @@ if __name__ == '__main__':
|
||||
plugins_to_load.append(ALL_PLUGINS[plugin]()) # type: ignore
|
||||
|
||||
client: ActionExecutor | None = None
|
||||
mcp_router: MCPRouter | None = None
|
||||
MCP_ROUTER_PROFILE_PATH = os.path.join(
|
||||
os.path.dirname(__file__), 'mcp', 'config.json'
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global client
|
||||
global client, mcp_router
|
||||
logger.info('Initializing ActionExecutor...')
|
||||
client = ActionExecutor(
|
||||
plugins_to_load,
|
||||
work_dir=args.working_dir,
|
||||
@@ -584,9 +596,70 @@ if __name__ == '__main__':
|
||||
browsergym_eval_env=args.browsergym_eval_env,
|
||||
)
|
||||
await client.ainit()
|
||||
logger.info('ActionExecutor initialized.')
|
||||
|
||||
# Initialize and mount MCP Router
|
||||
logger.info('Initializing MCP Router...')
|
||||
mcp_router = MCPRouter(
|
||||
profile_path=MCP_ROUTER_PROFILE_PATH,
|
||||
router_config=RouterConfig(
|
||||
api_key=SESSION_API_KEY,
|
||||
auth_enabled=bool(SESSION_API_KEY),
|
||||
),
|
||||
)
|
||||
allowed_origins = ['*']
|
||||
sse_app = await mcp_router.get_sse_server_app(
|
||||
allow_origins=allowed_origins, include_lifespan=False
|
||||
)
|
||||
|
||||
# Check for route conflicts before mounting
|
||||
main_app_routes = {route.path for route in app.routes}
|
||||
sse_app_routes = {route.path for route in sse_app.routes}
|
||||
conflicting_routes = main_app_routes.intersection(sse_app_routes)
|
||||
|
||||
if conflicting_routes:
|
||||
logger.error(f'Route conflicts detected: {conflicting_routes}')
|
||||
raise RuntimeError(
|
||||
f'Cannot mount SSE app - conflicting routes found: {conflicting_routes}'
|
||||
)
|
||||
|
||||
app.mount('/', sse_app)
|
||||
logger.info(
|
||||
f'Mounted MCP Router SSE app at root path with allowed origins: {allowed_origins}'
|
||||
)
|
||||
|
||||
# Additional debug logging
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug('Main app routes:')
|
||||
for route in main_app_routes:
|
||||
logger.debug(f' {route}')
|
||||
logger.debug('MCP SSE server app routes:')
|
||||
for route in sse_app_routes:
|
||||
logger.debug(f' {route}')
|
||||
|
||||
yield
|
||||
|
||||
# Clean up & release the resources
|
||||
client.close()
|
||||
logger.info('Shutting down MCP Router...')
|
||||
if mcp_router:
|
||||
try:
|
||||
await mcp_router.shutdown()
|
||||
logger.info('MCP Router shutdown successfully.')
|
||||
except Exception as e:
|
||||
logger.error(f'Error shutting down MCP Router: {e}', exc_info=True)
|
||||
else:
|
||||
logger.info('MCP Router instance not found for shutdown.')
|
||||
|
||||
logger.info('Closing ActionExecutor...')
|
||||
if client:
|
||||
try:
|
||||
client.close()
|
||||
logger.info('ActionExecutor closed successfully.')
|
||||
except Exception as e:
|
||||
logger.error(f'Error closing ActionExecutor: {e}', exc_info=True)
|
||||
else:
|
||||
logger.info('ActionExecutor instance not found for closing.')
|
||||
logger.info('Shutdown complete.')
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -663,6 +736,51 @@ if __name__ == '__main__':
|
||||
detail=traceback.format_exc(),
|
||||
)
|
||||
|
||||
@app.post('/update_mcp_server')
|
||||
async def update_mcp_server(request: Request):
|
||||
assert mcp_router is not None
|
||||
assert os.path.exists(MCP_ROUTER_PROFILE_PATH)
|
||||
|
||||
# Use synchronous file operations outside of async function
|
||||
def read_profile():
|
||||
with open(MCP_ROUTER_PROFILE_PATH, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
current_profile = read_profile()
|
||||
assert 'default' in current_profile
|
||||
assert isinstance(current_profile['default'], list)
|
||||
|
||||
# Get the request body
|
||||
mcp_tools_to_sync = await request.json()
|
||||
if not isinstance(mcp_tools_to_sync, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail='Request must be a list of MCP tools to sync'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Updating MCP server to: {json.dumps(mcp_tools_to_sync, indent=2)}.\nPrevious profile: {json.dumps(current_profile, indent=2)}'
|
||||
)
|
||||
current_profile['default'] = mcp_tools_to_sync
|
||||
|
||||
# Use synchronous file operations outside of async function
|
||||
def write_profile(profile):
|
||||
with open(MCP_ROUTER_PROFILE_PATH, 'w') as f:
|
||||
json.dump(profile, f)
|
||||
|
||||
write_profile(current_profile)
|
||||
|
||||
# Manually reload the profile and update the servers
|
||||
mcp_router.profile_manager.reload()
|
||||
servers_wait_for_update = mcp_router.get_unique_servers()
|
||||
await mcp_router.update_servers(servers_wait_for_update)
|
||||
logger.info(
|
||||
f'MCP router updated successfully with unique servers: {servers_wait_for_update}'
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200, content={'detail': 'MCP server updated successfully'}
|
||||
)
|
||||
|
||||
@app.post('/upload_file')
|
||||
async def upload_file(
|
||||
file: UploadFile, destination: str = '/', recursive: bool = False
|
||||
|
||||
@@ -30,7 +30,7 @@ from openhands.events.action import (
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentThinkObservation,
|
||||
@@ -282,9 +282,8 @@ class Runtime(FileEditRuntimeMixin):
|
||||
assert event.timeout is not None
|
||||
try:
|
||||
await self._export_latest_git_provider_tokens(event)
|
||||
if isinstance(event, McpAction):
|
||||
# we don't call call_tool_mcp impl directly because there can be other action ActionExecutionClient
|
||||
observation: Observation = await getattr(self, McpAction.action)(event)
|
||||
if isinstance(event, MCPAction):
|
||||
observation: Observation = await self.call_tool_mcp(event)
|
||||
else:
|
||||
observation = await call_sync_from_async(self.run_action, event)
|
||||
except Exception as e:
|
||||
@@ -571,7 +570,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool_mcp(self, action: McpAction) -> Observation:
|
||||
async def call_tool_mcp(self, action: MCPAction) -> Observation:
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
@@ -10,6 +11,7 @@ import httpx
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
|
||||
from openhands.core.exceptions import (
|
||||
AgentRuntimeTimeoutError,
|
||||
)
|
||||
@@ -27,7 +29,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.files import FileEditSource
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation import (
|
||||
AgentThinkObservation,
|
||||
ErrorObservation,
|
||||
@@ -38,16 +40,12 @@ from openhands.events.observation import (
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.mcp import MCPClient, create_mcp_clients
|
||||
from openhands.mcp import call_tool_mcp as call_tool_mcp_handler
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
from openhands.utils.http_session import HttpSession
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
def _is_retryable_error(exception):
|
||||
return isinstance(
|
||||
exception, (httpx.RemoteProtocolError, httpcore.RemoteProtocolError)
|
||||
@@ -79,7 +77,6 @@ class ActionExecutionClient(Runtime):
|
||||
self._runtime_initialized: bool = False
|
||||
self._runtime_closed: bool = False
|
||||
self._vscode_token: str | None = None # initial dummy value
|
||||
self.mcp_clients: list[MCPClient] | None = None
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
@@ -329,19 +326,59 @@ class ActionExecutionClient(Runtime):
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return self.send_action_for_execution(action)
|
||||
|
||||
async def call_tool_mcp(self, action: McpAction) -> Observation:
|
||||
if self.mcp_clients is None:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Creating MCP clients with servers: {self.config.mcp.mcp_servers}',
|
||||
)
|
||||
self.mcp_clients = await create_mcp_clients(self.config.mcp.mcp_servers)
|
||||
return await call_tool_mcp_handler(self.mcp_clients, action)
|
||||
def get_updated_mcp_config(self) -> MCPConfig:
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = self.config.mcp.model_copy()
|
||||
# Send a request to the action execution server to updated MCP config
|
||||
stdio_tools = [
|
||||
server.model_dump(mode='json')
|
||||
for server in updated_mcp_config.stdio_servers
|
||||
]
|
||||
self.log('debug', f'Updating MCP server to: {stdio_tools}')
|
||||
response = self._send_action_server_request(
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/update_mcp_server',
|
||||
json=stdio_tools,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f'Failed to update MCP server: {response.text}')
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.mcp_clients:
|
||||
for client in self.mcp_clients:
|
||||
await client.disconnect()
|
||||
# No API key by default. Child runtime can override this when appropriate
|
||||
updated_mcp_config.sse_servers.append(
|
||||
MCPSSEServerConfig(
|
||||
url=self.action_execution_server_url.rstrip('/') + '/sse', api_key=None
|
||||
)
|
||||
)
|
||||
self.log(
|
||||
'debug',
|
||||
f'Updated MCP config by adding runtime as another server: {updated_mcp_config}',
|
||||
)
|
||||
return updated_mcp_config
|
||||
|
||||
async def call_tool_mcp(self, action: MCPAction) -> Observation:
|
||||
# Import here to avoid circular imports
|
||||
from openhands.mcp.utils import create_mcp_clients, call_tool_mcp as call_tool_mcp_handler
|
||||
|
||||
# Get the updated MCP config
|
||||
updated_mcp_config = self.get_updated_mcp_config()
|
||||
self.log(
|
||||
'debug',
|
||||
f'Creating MCP clients with servers: {updated_mcp_config.sse_servers}',
|
||||
)
|
||||
|
||||
# Create clients for this specific operation
|
||||
mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers)
|
||||
|
||||
# Call the tool and return the result
|
||||
# No need for try/finally since disconnect() is now just resetting state
|
||||
result = await call_tool_mcp_handler(mcp_clients, action)
|
||||
|
||||
# Reset client state (no active connections to worry about)
|
||||
for client in mcp_clients:
|
||||
await client.disconnect()
|
||||
|
||||
return result
|
||||
|
||||
def close(self) -> None:
|
||||
# Make sure we don't close the session multiple times
|
||||
@@ -350,4 +387,3 @@ class ActionExecutionClient(Runtime):
|
||||
return
|
||||
self._runtime_closed = True
|
||||
self.session.close()
|
||||
call_async_from_sync(self.aclose)
|
||||
|
||||
3
openhands/runtime/mcp/config.json
Normal file
3
openhands/runtime/mcp/config.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"default": []
|
||||
}
|
||||
@@ -34,10 +34,12 @@ RUN apt-get update && \
|
||||
{% if 'ubuntu' in base_image %}
|
||||
RUN ln -s "$(dirname $(which node))/corepack" /usr/local/bin/corepack && \
|
||||
npm install -g corepack && corepack enable yarn && \
|
||||
curl -fsSL --compressed https://install.python-poetry.org | python - && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
curl -fsSL --compressed https://install.python-poetry.org | python -
|
||||
{% endif %}
|
||||
|
||||
# Install uv (required by MCP)
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Remove UID 1000 named pn or ubuntu, so the 'openhands' user can be created from ubuntu hosts
|
||||
RUN (if getent passwd 1000 | grep -q pn; then userdel pn; fi) && \
|
||||
(if getent passwd 1000 | grep -q ubuntu; then userdel ubuntu; fi)
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.events.event import Event, EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroagent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@@ -124,6 +125,11 @@ class AgentSession:
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
|
||||
# NOTE: this needs to happen before controller is created
|
||||
# so MCP tools can be included into the SystemMessageAction
|
||||
if self.runtime and runtime_connected:
|
||||
await add_mcp_tools_to_agent(agent, self.runtime, config.mcp)
|
||||
|
||||
if replay_json:
|
||||
initial_message = self._run_replay(
|
||||
initial_message,
|
||||
@@ -148,6 +154,7 @@ class AgentSession:
|
||||
repo_directory = None
|
||||
if self.runtime and runtime_connected and selected_repository:
|
||||
repo_directory = selected_repository.full_name.split('/')[-1]
|
||||
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
repo_directory=repo_directory,
|
||||
|
||||
@@ -25,7 +25,6 @@ from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.mcp import fetch_mcp_tools_from_config
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -147,9 +146,7 @@ class Session:
|
||||
self.logger.info(f'Enabling default condenser: {default_condenser_config}')
|
||||
agent_config.condenser = default_condenser_config
|
||||
|
||||
mcp_tools = await fetch_mcp_tools_from_config(self.config.mcp)
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
git_provider_tokens = None
|
||||
selected_repository = None
|
||||
|
||||
869
poetry.lock
generated
869
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -76,6 +76,8 @@ mcp = "1.6.0"
|
||||
python-json-logger = "^3.2.1"
|
||||
playwright = "^1.51.0"
|
||||
prompt-toolkit = "^3.0.50"
|
||||
mcpm = "1.8.0"
|
||||
poetry = "^2.1.2"
|
||||
anyio = "4.9.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -99,6 +101,7 @@ gevent = ">=24.2.1,<26.0.0"
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -128,6 +131,7 @@ ignore = ["D1"]
|
||||
convention = "google"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
import pytest
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from openhands.core.config import AppConfig, load_app_config
|
||||
from openhands.core.config import AppConfig, MCPConfig, load_app_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.runtime.base import Runtime
|
||||
@@ -214,6 +214,7 @@ def _load_runtime(
|
||||
force_rebuild_runtime: bool = False,
|
||||
runtime_startup_env_vars: dict[str, str] | None = None,
|
||||
docker_runtime_kwargs: dict[str, str] | None = None,
|
||||
override_mcp_config: MCPConfig | None = None,
|
||||
) -> tuple[Runtime, AppConfig]:
|
||||
sid = 'rt_' + str(random.randint(100000, 999999))
|
||||
|
||||
@@ -256,6 +257,9 @@ def _load_runtime(
|
||||
config.sandbox.base_container_image = base_container_image
|
||||
config.sandbox.runtime_container_image = None
|
||||
|
||||
if override_mcp_config is not None:
|
||||
config.mcp = override_mcp_config
|
||||
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(sid, file_store)
|
||||
|
||||
|
||||
78
tests/runtime/test_mcp_action.py
Normal file
78
tests/runtime/test_mcp_action.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Bash-related tests for the DockerRuntime, which connects to the ActionExecutor running in the sandbox."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from conftest import (
|
||||
_load_runtime,
|
||||
)
|
||||
|
||||
import openhands
|
||||
from openhands.core.config import MCPConfig
|
||||
from openhands.core.config.mcp_config import MCPStdioServerConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import CmdRunAction, MCPAction
|
||||
from openhands.events.observation import CmdOutputObservation, MCPObservation
|
||||
|
||||
# ============================================================================================================================
|
||||
# Bash-specific tests
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
def test_default_activated_tools():
|
||||
project_root = os.path.dirname(openhands.__file__)
|
||||
mcp_config_path = os.path.join(project_root, 'runtime', 'mcp', 'config.json')
|
||||
assert os.path.exists(
|
||||
mcp_config_path
|
||||
), f'MCP config file not found at {mcp_config_path}'
|
||||
with open(mcp_config_path, 'r') as f:
|
||||
mcp_config = json.load(f)
|
||||
assert 'default' in mcp_config
|
||||
# no tools are always activated yet
|
||||
assert len(mcp_config['default']) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_mcp_via_stdio(temp_dir, runtime_cls, run_as_openhands):
|
||||
mcp_stdio_server_config = MCPStdioServerConfig(
|
||||
name='fetch', command='uvx', args=['mcp-server-fetch']
|
||||
)
|
||||
override_mcp_config = MCPConfig(stdio_servers=[mcp_stdio_server_config])
|
||||
runtime, config = _load_runtime(
|
||||
temp_dir, runtime_cls, run_as_openhands, override_mcp_config=override_mcp_config
|
||||
)
|
||||
|
||||
# Test browser server
|
||||
action_cmd = CmdRunAction(command='python3 -m http.server 8000 > server.log 2>&1 &')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert '[1]' in obs.content
|
||||
|
||||
action_cmd = CmdRunAction(command='sleep 3 && cat server.log')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
mcp_action = MCPAction(name='fetch', arguments={'url': 'http://localhost:8000'})
|
||||
obs = await runtime.call_tool_mcp(mcp_action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(
|
||||
obs, MCPObservation
|
||||
), 'The observation should be a MCPObservation.'
|
||||
|
||||
result_json = json.loads(obs.content)
|
||||
assert not result_json['isError']
|
||||
assert len(result_json['content']) == 1
|
||||
assert result_json['content'][0]['type'] == 'text'
|
||||
assert (
|
||||
result_json['content'][0]['text']
|
||||
== 'Contents of http://localhost:8000/:\n---\n\n* <server.log>\n\n---'
|
||||
)
|
||||
|
||||
runtime.close()
|
||||
@@ -32,6 +32,9 @@ from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@@ -83,8 +86,12 @@ def test_event_stream():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime() -> Runtime:
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
|
||||
runtime = MagicMock(
|
||||
spec=Runtime,
|
||||
spec=ActionExecutionClient,
|
||||
event_stream=test_event_stream,
|
||||
)
|
||||
return runtime
|
||||
@@ -233,7 +240,7 @@ async def test_run_controller_with_fatal_error(
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, CmdRunAction):
|
||||
@@ -298,7 +305,7 @@ async def test_run_controller_stop_with_stuck(
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, CmdRunAction):
|
||||
@@ -660,7 +667,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, CmdRunAction):
|
||||
@@ -1064,7 +1071,7 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Create a real Memory instance
|
||||
|
||||
@@ -11,7 +11,9 @@ from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@@ -58,7 +60,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
mock_runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
@@ -141,7 +143,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
mock_runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
|
||||
@@ -123,7 +123,7 @@ def mock_settings_store():
|
||||
@patch('openhands.core.cli.display_runtime_initialization_message')
|
||||
@patch('openhands.core.cli.display_initialization_animation')
|
||||
@patch('openhands.core.cli.create_agent')
|
||||
@patch('openhands.core.cli.fetch_mcp_tools_from_config')
|
||||
@patch('openhands.core.cli.add_mcp_tools_to_agent')
|
||||
@patch('openhands.core.cli.create_runtime')
|
||||
@patch('openhands.core.cli.create_controller')
|
||||
@patch('openhands.core.cli.create_memory')
|
||||
@@ -137,7 +137,7 @@ async def test_run_session_without_initial_action(
|
||||
mock_create_memory,
|
||||
mock_create_controller,
|
||||
mock_create_runtime,
|
||||
mock_fetch_mcp_tools,
|
||||
mock_add_mcp_tools,
|
||||
mock_create_agent,
|
||||
mock_display_animation,
|
||||
mock_display_runtime_init,
|
||||
@@ -154,9 +154,6 @@ async def test_run_session_without_initial_action(
|
||||
mock_agent = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent
|
||||
|
||||
mock_mcp_tools = []
|
||||
mock_fetch_mcp_tools.return_value = mock_mcp_tools
|
||||
|
||||
mock_runtime = AsyncMock()
|
||||
mock_runtime.event_stream = MagicMock()
|
||||
mock_create_runtime.return_value = mock_runtime
|
||||
@@ -193,8 +190,9 @@ async def test_run_session_without_initial_action(
|
||||
mock_display_runtime_init.assert_called_once_with('local')
|
||||
mock_display_animation.assert_called_once()
|
||||
mock_create_agent.assert_called_once_with(mock_config)
|
||||
mock_fetch_mcp_tools.assert_called_once()
|
||||
mock_agent.set_mcp_tools.assert_called_once_with(mock_mcp_tools)
|
||||
mock_add_mcp_tools.assert_called_once_with(
|
||||
mock_agent, mock_runtime, mock_config.mcp
|
||||
)
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_create_controller.assert_called_once()
|
||||
mock_create_memory.assert_called_once()
|
||||
@@ -213,7 +211,7 @@ async def test_run_session_without_initial_action(
|
||||
@patch('openhands.core.cli.display_runtime_initialization_message')
|
||||
@patch('openhands.core.cli.display_initialization_animation')
|
||||
@patch('openhands.core.cli.create_agent')
|
||||
@patch('openhands.core.cli.fetch_mcp_tools_from_config')
|
||||
@patch('openhands.core.cli.add_mcp_tools_to_agent')
|
||||
@patch('openhands.core.cli.create_runtime')
|
||||
@patch('openhands.core.cli.create_controller')
|
||||
@patch('openhands.core.cli.create_memory')
|
||||
@@ -227,7 +225,7 @@ async def test_run_session_with_initial_action(
|
||||
mock_create_memory,
|
||||
mock_create_controller,
|
||||
mock_create_runtime,
|
||||
mock_fetch_mcp_tools,
|
||||
mock_add_mcp_tools,
|
||||
mock_create_agent,
|
||||
mock_display_animation,
|
||||
mock_display_runtime_init,
|
||||
@@ -244,9 +242,6 @@ async def test_run_session_with_initial_action(
|
||||
mock_agent = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent
|
||||
|
||||
mock_mcp_tools = []
|
||||
mock_fetch_mcp_tools.return_value = mock_mcp_tools
|
||||
|
||||
mock_runtime = AsyncMock()
|
||||
mock_runtime.event_stream = MagicMock()
|
||||
mock_create_runtime.return_value = mock_runtime
|
||||
|
||||
108
tests/unit/test_mcp_action_observation.py
Normal file
108
tests/unit/test_mcp_action_observation.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
|
||||
from openhands.core.schema import ActionType, ObservationType
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
|
||||
|
||||
def test_mcp_action_creation():
|
||||
"""Test creating an MCPAction."""
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1', 'arg2': 42})
|
||||
|
||||
assert action.name == 'test_tool'
|
||||
assert action.arguments == {'arg1': 'value1', 'arg2': 42}
|
||||
assert action.action == ActionType.MCP
|
||||
assert action.thought == ''
|
||||
assert action.runnable is True
|
||||
assert action.security_risk is None
|
||||
|
||||
|
||||
def test_mcp_action_with_thought():
|
||||
"""Test creating an MCPAction with a thought."""
|
||||
action = MCPAction(
|
||||
name='test_tool',
|
||||
arguments={'arg1': 'value1', 'arg2': 42},
|
||||
thought='This is a test thought',
|
||||
)
|
||||
|
||||
assert action.name == 'test_tool'
|
||||
assert action.arguments == {'arg1': 'value1', 'arg2': 42}
|
||||
assert action.thought == 'This is a test thought'
|
||||
|
||||
|
||||
def test_mcp_action_message():
|
||||
"""Test the message property of MCPAction."""
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1', 'arg2': 42})
|
||||
|
||||
message = action.message
|
||||
assert 'test_tool' in message
|
||||
assert 'arg1' in message
|
||||
assert 'value1' in message
|
||||
assert '42' in message
|
||||
|
||||
|
||||
def test_mcp_action_str_representation():
|
||||
"""Test the string representation of MCPAction."""
|
||||
action = MCPAction(
|
||||
name='test_tool',
|
||||
arguments={'arg1': 'value1', 'arg2': 42},
|
||||
thought='This is a test thought',
|
||||
)
|
||||
|
||||
str_repr = str(action)
|
||||
assert 'MCPAction' in str_repr
|
||||
assert 'THOUGHT: This is a test thought' in str_repr
|
||||
assert 'NAME: test_tool' in str_repr
|
||||
assert 'ARGUMENTS:' in str_repr
|
||||
assert 'arg1' in str_repr
|
||||
assert 'value1' in str_repr
|
||||
assert '42' in str_repr
|
||||
|
||||
|
||||
def test_mcp_observation_creation():
|
||||
"""Test creating an MCPObservation."""
|
||||
observation = MCPObservation(
|
||||
content=json.dumps({'result': 'success', 'data': 'test data'})
|
||||
)
|
||||
|
||||
assert observation.content == json.dumps({'result': 'success', 'data': 'test data'})
|
||||
assert observation.observation == ObservationType.MCP
|
||||
|
||||
|
||||
def test_mcp_observation_message():
|
||||
"""Test the message property of MCPObservation."""
|
||||
observation = MCPObservation(
|
||||
content=json.dumps({'result': 'success', 'data': 'test data'})
|
||||
)
|
||||
|
||||
message = observation.message
|
||||
assert message == json.dumps({'result': 'success', 'data': 'test data'})
|
||||
assert 'result' in message
|
||||
assert 'success' in message
|
||||
assert 'data' in message
|
||||
assert 'test data' in message
|
||||
|
||||
|
||||
def test_mcp_action_with_complex_arguments():
|
||||
"""Test MCPAction with complex nested arguments."""
|
||||
complex_args = {
|
||||
'simple_arg': 'value',
|
||||
'number_arg': 42,
|
||||
'boolean_arg': True,
|
||||
'nested_arg': {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]},
|
||||
'list_arg': ['a', 'b', 'c'],
|
||||
}
|
||||
|
||||
action = MCPAction(name='complex_tool', arguments=complex_args)
|
||||
|
||||
assert action.name == 'complex_tool'
|
||||
assert action.arguments == complex_args
|
||||
assert action.arguments['nested_arg']['inner_key'] == 'inner_value'
|
||||
assert action.arguments['list_arg'] == ['a', 'b', 'c']
|
||||
|
||||
# Check that the message contains the complex arguments
|
||||
message = action.message
|
||||
assert 'complex_tool' in message
|
||||
assert 'nested_arg' in message
|
||||
assert 'inner_key' in message
|
||||
assert 'inner_value' in message
|
||||
@@ -1,23 +1,33 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_valid_sse_config():
|
||||
"""Test a valid SSE configuration."""
|
||||
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
|
||||
config = MCPConfig(
|
||||
sse_servers=[
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
MCPSSEServerConfig(url='http://server2:8080'),
|
||||
]
|
||||
)
|
||||
config.validate_servers() # Should not raise any exception
|
||||
|
||||
|
||||
def test_empty_sse_config():
|
||||
"""Test SSE configuration with empty servers list."""
|
||||
config = MCPConfig(mcp_servers=[])
|
||||
config = MCPConfig(sse_servers=[])
|
||||
config.validate_servers()
|
||||
|
||||
|
||||
def test_invalid_sse_url():
|
||||
"""Test SSE configuration with invalid URL format."""
|
||||
config = MCPConfig(mcp_servers=['not_a_url'])
|
||||
config = MCPConfig(sse_servers=[MCPSSEServerConfig(url='not_a_url')])
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config.validate_servers()
|
||||
assert 'Invalid URL' in str(exc_info.value)
|
||||
@@ -25,7 +35,12 @@ def test_invalid_sse_url():
|
||||
|
||||
def test_duplicate_sse_urls():
|
||||
"""Test SSE configuration with duplicate server URLs."""
|
||||
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
|
||||
config = MCPConfig(
|
||||
sse_servers=[
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config.validate_servers()
|
||||
assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value)
|
||||
@@ -34,17 +49,18 @@ def test_duplicate_sse_urls():
|
||||
def test_from_toml_section_valid():
|
||||
"""Test creating config from valid TOML section."""
|
||||
data = {
|
||||
'mcp_servers': ['http://server1:8080'],
|
||||
'sse_servers': ['http://server1:8080'],
|
||||
}
|
||||
result = MCPConfig.from_toml_section(data)
|
||||
assert 'mcp' in result
|
||||
assert result['mcp'].mcp_servers == ['http://server1:8080']
|
||||
assert len(result['mcp'].sse_servers) == 1
|
||||
assert result['mcp'].sse_servers[0].url == 'http://server1:8080'
|
||||
|
||||
|
||||
def test_from_toml_section_invalid_sse():
|
||||
"""Test creating config from TOML section with invalid SSE URL."""
|
||||
data = {
|
||||
'mcp_servers': ['not_a_url'],
|
||||
'sse_servers': ['not_a_url'],
|
||||
}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
MCPConfig.from_toml_section(data)
|
||||
@@ -54,10 +70,125 @@ def test_from_toml_section_invalid_sse():
|
||||
def test_complex_urls():
|
||||
"""Test SSE configuration with complex URLs."""
|
||||
config = MCPConfig(
|
||||
mcp_servers=[
|
||||
'https://user:pass@server1:8080/path?query=1',
|
||||
'wss://server2:8443/ws',
|
||||
'http://subdomain.example.com:9090',
|
||||
sse_servers=[
|
||||
MCPSSEServerConfig(url='https://user:pass@server1:8080/path?query=1'),
|
||||
MCPSSEServerConfig(url='wss://server2:8443/ws'),
|
||||
MCPSSEServerConfig(url='http://subdomain.example.com:9090'),
|
||||
]
|
||||
)
|
||||
config.validate_servers() # Should not raise any exception
|
||||
|
||||
|
||||
def test_mcp_sse_server_config_with_api_key():
|
||||
"""Test MCPSSEServerConfig with API key."""
|
||||
config = MCPSSEServerConfig(url='http://server1:8080', api_key='test-api-key')
|
||||
assert config.url == 'http://server1:8080'
|
||||
assert config.api_key == 'test-api-key'
|
||||
|
||||
|
||||
def test_mcp_sse_server_config_without_api_key():
|
||||
"""Test MCPSSEServerConfig without API key."""
|
||||
config = MCPSSEServerConfig(url='http://server1:8080')
|
||||
assert config.url == 'http://server1:8080'
|
||||
assert config.api_key is None
|
||||
|
||||
|
||||
def test_mcp_stdio_server_config_basic():
|
||||
"""Test basic MCPStdioServerConfig."""
|
||||
config = MCPStdioServerConfig(name='test-server', command='python')
|
||||
assert config.name == 'test-server'
|
||||
assert config.command == 'python'
|
||||
assert config.args == []
|
||||
assert config.env == {}
|
||||
|
||||
|
||||
def test_mcp_stdio_server_config_with_args_and_env():
|
||||
"""Test MCPStdioServerConfig with args and env."""
|
||||
config = MCPStdioServerConfig(
|
||||
name='test-server',
|
||||
command='python',
|
||||
args=['-m', 'server'],
|
||||
env={'DEBUG': 'true', 'PORT': '8080'},
|
||||
)
|
||||
assert config.name == 'test-server'
|
||||
assert config.command == 'python'
|
||||
assert config.args == ['-m', 'server']
|
||||
assert config.env == {'DEBUG': 'true', 'PORT': '8080'}
|
||||
|
||||
|
||||
def test_mcp_config_with_stdio_servers():
|
||||
"""Test MCPConfig with stdio servers."""
|
||||
stdio_server = MCPStdioServerConfig(
|
||||
name='test-server',
|
||||
command='python',
|
||||
args=['-m', 'server'],
|
||||
env={'DEBUG': 'true'},
|
||||
)
|
||||
config = MCPConfig(stdio_servers=[stdio_server])
|
||||
assert len(config.stdio_servers) == 1
|
||||
assert config.stdio_servers[0].name == 'test-server'
|
||||
assert config.stdio_servers[0].command == 'python'
|
||||
assert config.stdio_servers[0].args == ['-m', 'server']
|
||||
assert config.stdio_servers[0].env == {'DEBUG': 'true'}
|
||||
|
||||
|
||||
def test_from_toml_section_with_stdio_servers():
|
||||
"""Test creating config from TOML section with stdio servers."""
|
||||
data = {
|
||||
'sse_servers': ['http://server1:8080'],
|
||||
'stdio_servers': [
|
||||
{
|
||||
'name': 'test-server',
|
||||
'command': 'python',
|
||||
'args': ['-m', 'server'],
|
||||
'env': {'DEBUG': 'true'},
|
||||
}
|
||||
],
|
||||
}
|
||||
result = MCPConfig.from_toml_section(data)
|
||||
assert 'mcp' in result
|
||||
assert len(result['mcp'].sse_servers) == 1
|
||||
assert result['mcp'].sse_servers[0].url == 'http://server1:8080'
|
||||
assert len(result['mcp'].stdio_servers) == 1
|
||||
assert result['mcp'].stdio_servers[0].name == 'test-server'
|
||||
assert result['mcp'].stdio_servers[0].command == 'python'
|
||||
assert result['mcp'].stdio_servers[0].args == ['-m', 'server']
|
||||
assert result['mcp'].stdio_servers[0].env == {'DEBUG': 'true'}
|
||||
|
||||
|
||||
def test_mcp_config_with_both_server_types():
|
||||
"""Test MCPConfig with both SSE and stdio servers."""
|
||||
sse_server = MCPSSEServerConfig(url='http://server1:8080', api_key='test-api-key')
|
||||
stdio_server = MCPStdioServerConfig(
|
||||
name='test-server',
|
||||
command='python',
|
||||
args=['-m', 'server'],
|
||||
env={'DEBUG': 'true'},
|
||||
)
|
||||
config = MCPConfig(sse_servers=[sse_server], stdio_servers=[stdio_server])
|
||||
assert len(config.sse_servers) == 1
|
||||
assert config.sse_servers[0].url == 'http://server1:8080'
|
||||
assert config.sse_servers[0].api_key == 'test-api-key'
|
||||
assert len(config.stdio_servers) == 1
|
||||
assert config.stdio_servers[0].name == 'test-server'
|
||||
assert config.stdio_servers[0].command == 'python'
|
||||
|
||||
|
||||
def test_mcp_config_model_validation_error():
|
||||
"""Test MCPConfig validation error with invalid data."""
|
||||
with pytest.raises(ValidationError):
|
||||
# Missing required 'url' field
|
||||
MCPSSEServerConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
# Missing required 'name' and 'command' fields
|
||||
MCPStdioServerConfig()
|
||||
|
||||
|
||||
def test_mcp_config_extra_fields_forbidden():
|
||||
"""Test that extra fields are forbidden in MCPConfig."""
|
||||
with pytest.raises(ValidationError):
|
||||
MCPConfig(extra_field='value')
|
||||
|
||||
# Note: The nested models don't have 'extra': 'forbid' set, so they allow extra fields
|
||||
# We're only testing the main MCPConfig class here
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig, MCPSSEServerConfig
|
||||
from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config
|
||||
|
||||
|
||||
@@ -24,10 +24,13 @@ async def test_sse_connection_timeout():
|
||||
# Mock the MCPClient constructor to return our mock
|
||||
with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client):
|
||||
# Create a list of server URLs to test
|
||||
servers = ['http://server1:8080', 'http://server2:8080']
|
||||
servers = [
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
MCPSSEServerConfig(url='http://server2:8080'),
|
||||
]
|
||||
|
||||
# Call create_mcp_clients with the server URLs
|
||||
clients = await create_mcp_clients(mcp_servers=servers)
|
||||
clients = await create_mcp_clients(sse_servers=servers)
|
||||
|
||||
# Verify that no clients were successfully connected
|
||||
assert len(clients) == 0
|
||||
@@ -46,7 +49,7 @@ async def test_fetch_mcp_tools_with_timeout():
|
||||
mock_config = mock.MagicMock(spec=MCPConfig)
|
||||
|
||||
# Configure the mock config
|
||||
mock_config.mcp_servers = ['http://server1:8080']
|
||||
mock_config.sse_servers = ['http://server1:8080']
|
||||
|
||||
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
|
||||
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
|
||||
@@ -64,7 +67,7 @@ async def test_mixed_connection_results():
|
||||
mock_config = mock.MagicMock(spec=MCPConfig)
|
||||
|
||||
# Configure the mock config
|
||||
mock_config.mcp_servers = ['http://server1:8080', 'http://server2:8080']
|
||||
mock_config.sse_servers = ['http://server1:8080', 'http://server2:8080']
|
||||
|
||||
# Create a successful client
|
||||
successful_client = mock.MagicMock(spec=MCPClient)
|
||||
|
||||
165
tests/unit/test_mcp_utils.py
Normal file
165
tests/unit/test_mcp_utils.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the module, not the functions directly to avoid circular imports
|
||||
import openhands.mcp.utils
|
||||
from openhands.core.config.mcp_config import MCPSSEServerConfig
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mcp_clients_empty():
|
||||
"""Test creating MCP clients with empty server list."""
|
||||
clients = await openhands.mcp.utils.create_mcp_clients([])
|
||||
assert clients == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.MCPClient')
|
||||
async def test_create_mcp_clients_success(mock_mcp_client):
|
||||
"""Test successful creation of MCP clients."""
|
||||
# Setup mock
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_mcp_client.return_value = mock_client_instance
|
||||
mock_client_instance.connect_sse = AsyncMock()
|
||||
|
||||
# Test with two servers
|
||||
server_configs = [
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
MCPSSEServerConfig(url='http://server2:8080', api_key='test-key'),
|
||||
]
|
||||
|
||||
clients = await openhands.mcp.utils.create_mcp_clients(server_configs)
|
||||
|
||||
# Verify
|
||||
assert len(clients) == 2
|
||||
assert mock_mcp_client.call_count == 2
|
||||
|
||||
# Check that connect_sse was called with correct parameters
|
||||
mock_client_instance.connect_sse.assert_any_call(
|
||||
'http://server1:8080', api_key=None
|
||||
)
|
||||
mock_client_instance.connect_sse.assert_any_call(
|
||||
'http://server2:8080', api_key='test-key'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.mcp.utils.MCPClient')
|
||||
async def test_create_mcp_clients_connection_failure(mock_mcp_client):
|
||||
"""Test handling of connection failures when creating MCP clients."""
|
||||
# Setup mock
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_mcp_client.return_value = mock_client_instance
|
||||
|
||||
# First connection succeeds, second fails
|
||||
mock_client_instance.connect_sse.side_effect = [
|
||||
None, # Success
|
||||
Exception('Connection failed'), # Failure
|
||||
]
|
||||
mock_client_instance.disconnect = AsyncMock()
|
||||
|
||||
server_configs = [
|
||||
MCPSSEServerConfig(url='http://server1:8080'),
|
||||
MCPSSEServerConfig(url='http://server2:8080'),
|
||||
]
|
||||
|
||||
clients = await openhands.mcp.utils.create_mcp_clients(server_configs)
|
||||
|
||||
# Verify only one client was successfully created
|
||||
assert len(clients) == 1
|
||||
assert mock_client_instance.disconnect.call_count == 1
|
||||
|
||||
|
||||
def test_convert_mcp_clients_to_tools_empty():
|
||||
"""Test converting empty MCP clients list to tools."""
|
||||
tools = openhands.mcp.utils.convert_mcp_clients_to_tools(None)
|
||||
assert tools == []
|
||||
|
||||
tools = openhands.mcp.utils.convert_mcp_clients_to_tools([])
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_convert_mcp_clients_to_tools():
|
||||
"""Test converting MCP clients to tools."""
|
||||
# Create mock clients with tools
|
||||
mock_client1 = MagicMock()
|
||||
mock_client2 = MagicMock()
|
||||
|
||||
# Create mock tools
|
||||
mock_tool1 = MagicMock()
|
||||
mock_tool1.to_param.return_value = {'function': {'name': 'tool1'}}
|
||||
|
||||
mock_tool2 = MagicMock()
|
||||
mock_tool2.to_param.return_value = {'function': {'name': 'tool2'}}
|
||||
|
||||
mock_tool3 = MagicMock()
|
||||
mock_tool3.to_param.return_value = {'function': {'name': 'tool3'}}
|
||||
|
||||
# Set up the clients with their tools
|
||||
mock_client1.tools = [mock_tool1, mock_tool2]
|
||||
mock_client2.tools = [mock_tool3]
|
||||
|
||||
# Convert to tools
|
||||
tools = openhands.mcp.utils.convert_mcp_clients_to_tools(
|
||||
[mock_client1, mock_client2]
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(tools) == 3
|
||||
assert tools[0] == {'function': {'name': 'tool1'}}
|
||||
assert tools[1] == {'function': {'name': 'tool2'}}
|
||||
assert tools[2] == {'function': {'name': 'tool3'}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_mcp_no_clients():
|
||||
"""Test calling MCP tool with no clients."""
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with pytest.raises(ValueError, match='No MCP clients found'):
|
||||
await openhands.mcp.utils.call_tool_mcp([], action)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_mcp_no_matching_client():
|
||||
"""Test calling MCP tool with no matching client."""
|
||||
# Create mock client without the requested tool
|
||||
mock_client = MagicMock()
|
||||
mock_client.tools = [MagicMock(name='other_tool')]
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
with pytest.raises(ValueError, match='No matching MCP agent found for tool name'):
|
||||
await openhands.mcp.utils.call_tool_mcp([mock_client], action)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_mcp_success():
|
||||
"""Test successful MCP tool call."""
|
||||
# Create mock client with the requested tool
|
||||
mock_client = MagicMock()
|
||||
mock_tool = MagicMock()
|
||||
# Set the name attribute properly for the tool
|
||||
mock_tool.name = 'test_tool'
|
||||
mock_client.tools = [mock_tool]
|
||||
|
||||
# Setup response
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {'result': 'success'}
|
||||
|
||||
# Setup call_tool method
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_response)
|
||||
|
||||
action = MCPAction(name='test_tool', arguments={'arg1': 'value1'})
|
||||
|
||||
# Call the function
|
||||
observation = await openhands.mcp.utils.call_tool_mcp([mock_client], action)
|
||||
|
||||
# Verify
|
||||
assert isinstance(observation, MCPObservation)
|
||||
assert json.loads(observation.content) == {'result': 'success'}
|
||||
mock_client.call_tool.assert_called_once_with('test_tool', {'arg1': 'value1'})
|
||||
@@ -21,7 +21,9 @@ from openhands.events.stream import EventStream
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@@ -77,7 +79,7 @@ def mock_agent():
|
||||
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
|
||||
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Mock Memory method to raise an exception
|
||||
@@ -106,7 +108,7 @@ async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
):
|
||||
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Mock Memory._on_workspace_context_recall to raise an exception
|
||||
|
||||
Reference in New Issue
Block a user