[Feat]: support streamable http mcp (#8864)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-06-03 13:06:44 -04:00 committed by GitHub
parent 1850d572b5
commit a348840534
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 208 additions and 56 deletions

View File

@ -263,14 +263,12 @@ async def run_session(
# Add MCP tools to the agent
if agent.config.enable_mcp:
# Add OpenHands' MCP server by default
openhands_mcp_server, openhands_mcp_stdio_servers = (
_, openhands_mcp_stdio_servers = (
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
config.mcp_host, config, None
)
)
# FIXME: OpenHands' SSE server may not be running when CLI mode is started
# if openhands_mcp_server:
# config.mcp.sse_servers.append(openhands_mcp_server)
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
await add_mcp_tools_to_agent(agent, runtime, memory, config)

View File

@ -54,6 +54,10 @@ class MCPStdioServerConfig(BaseModel):
and set(self.env.items()) == set(other.env.items())
)
class MCPSHTTPServerConfig(BaseModel):
url: str
api_key: str | None = None
class MCPConfig(BaseModel):
"""Configuration for MCP (Message Control Protocol) settings.
@ -65,11 +69,12 @@ class MCPConfig(BaseModel):
sse_servers: list[MCPSSEServerConfig] = Field(default_factory=list)
stdio_servers: list[MCPStdioServerConfig] = Field(default_factory=list)
shttp_servers: list[MCPSHTTPServerConfig] = Field(default_factory=list)
model_config = {'extra': 'forbid'}
@staticmethod
def _normalize_sse_servers(servers_data: list[dict | str]) -> list[dict]:
def _normalize_servers(servers_data: list[dict | str]) -> list[dict]:
"""Helper method to normalize SSE server configurations."""
normalized = []
for server in servers_data:
@ -82,8 +87,13 @@ class MCPConfig(BaseModel):
@model_validator(mode='before')
def convert_string_urls(cls, data):
"""Convert string URLs to MCPSSEServerConfig objects."""
if isinstance(data, dict) and 'sse_servers' in data:
data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers'])
if isinstance(data, dict):
if 'sse_servers' in data:
data['sse_servers'] = cls._normalize_servers(data['sse_servers'])
if 'shttp_servers' in data:
data['shttp_servers'] = cls._normalize_servers(data['shttp_servers'])
return data
def validate_servers(self) -> None:
@ -119,7 +129,7 @@ class MCPConfig(BaseModel):
try:
# Convert all entries in sse_servers to MCPSSEServerConfig objects
if 'sse_servers' in data:
data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers'])
data['sse_servers'] = cls._normalize_servers(data['sse_servers'])
servers = []
for server in data['sse_servers']:
servers.append(MCPSSEServerConfig(**server))
@ -132,6 +142,13 @@ class MCPConfig(BaseModel):
servers.append(MCPStdioServerConfig(**server))
data['stdio_servers'] = servers
if 'shttp_servers' in data:
data['shttp_servers'] = cls._normalize_servers(data['shttp_servers'])
servers = []
for server in data['shttp_servers']:
servers.append(MCPSHTTPServerConfig(**server))
data['shttp_servers'] = servers
# Create SSE config if present
mcp_config = MCPConfig.model_validate(data)
mcp_config.validate_servers()
@ -169,7 +186,7 @@ class OpenHandsMCPConfig:
@staticmethod
def create_default_mcp_server_config(
host: str, config: 'OpenHandsConfig', user_id: str | None = None
) -> tuple[MCPSSEServerConfig, list[MCPStdioServerConfig]]:
) -> tuple[MCPSHTTPServerConfig, list[MCPStdioServerConfig]]:
"""
Create a default MCP server configuration.
@ -179,12 +196,13 @@ class OpenHandsMCPConfig:
Returns:
tuple[MCPSSEServerConfig, list[MCPStdioServerConfig]]: A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
"""
sse_server = MCPSSEServerConfig(url=f'http://{host}/mcp/sse', api_key=None)
stdio_servers = []
search_engine_stdio_server = OpenHandsMCPConfig.add_search_engine(config)
if search_engine_stdio_server:
stdio_servers.append(search_engine_stdio_server)
return sse_server, stdio_servers
shttp_servers = MCPSHTTPServerConfig(url=f'http://{host}/mcp/mcp', api_key=None)
return shttp_servers, stdio_servers
openhands_mcp_config_cls = os.environ.get(

View File

@ -134,14 +134,11 @@ async def run_controller(
# Add MCP tools to the agent
if agent.config.enable_mcp:
# Add OpenHands' MCP server by default
openhands_mcp_server, openhands_mcp_stdio_servers = (
_, openhands_mcp_stdio_servers = (
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
config.mcp_host, config, None
)
)
# FIXME: OpenHands' SSE server may not be running when headless mode is started
# if openhands_mcp_server:
# config.mcp.sse_servers.append(openhands_mcp_server)
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
await add_mcp_tools_to_agent(agent, runtime, memory, config)

View File

@ -1,9 +1,11 @@
import asyncio
import datetime
from contextlib import AsyncExitStack
from typing import Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, Field
from openhands.core.logger import openhands_logger as logger
@ -58,14 +60,21 @@ class MCPClient(BaseModel):
if conversation_id:
headers['X-OpenHands-Conversation-ID'] = conversation_id
# Convert float timeout to datetime.timedelta for consistency
timeout_delta = datetime.timedelta(seconds=timeout)
streams_context = sse_client(
url=server_url,
headers=headers if headers else None,
timeout=timeout,
)
streams = await self.exit_stack.enter_async_context(streams_context)
# For SSE client, we only get read_stream and write_stream (2 values)
read_stream, write_stream = streams
self.session = await self.exit_stack.enter_async_context(
ClientSession(*streams)
ClientSession(
read_stream, write_stream, read_timeout_seconds=timeout_delta
)
)
await self._initialize_and_list_tools()
@ -117,6 +126,77 @@ class MCPClient(BaseModel):
raise RuntimeError('Client session is not available.')
return await self.session.call_tool(name=tool_name, arguments=args)
async def connect_shttp(
self,
server_url: str,
api_key: str | None = None,
conversation_id: str | None = None,
timeout: float = 30.0,
) -> None:
"""Connect to an MCP server using StreamableHTTP transport.
Args:
server_url: The URL of the StreamableHTTP server to connect to.
api_key: Optional API key for authentication.
conversation_id: Optional conversation ID for session tracking.
timeout: Connection timeout in seconds. Default is 30 seconds.
"""
if not server_url:
raise ValueError('Server URL is required.')
if self.session:
await self.disconnect()
try:
# Use asyncio.wait_for to enforce the timeout
async def connect_with_timeout():
headers = (
{
'Authorization': f'Bearer {api_key}',
's': api_key, # We need this for action execution server's MCP Router
'X-Session-API-Key': api_key, # We need this for Remote Runtime
}
if api_key
else {}
)
if conversation_id:
headers['X-OpenHands-Conversation-ID'] = conversation_id
# Convert float timeout to datetime.timedelta
timeout_delta = datetime.timedelta(seconds=timeout)
sse_read_timeout_delta = datetime.timedelta(
seconds=timeout * 10
) # 10x longer for read timeout
streams_context = streamablehttp_client(
url=server_url,
headers=headers if headers else None,
timeout=timeout_delta,
sse_read_timeout=sse_read_timeout_delta,
)
streams = await self.exit_stack.enter_async_context(streams_context)
# For StreamableHTTP client, we get read_stream, write_stream, and get_session_id (3 values)
read_stream, write_stream, _ = streams
self.session = await self.exit_stack.enter_async_context(
ClientSession(
read_stream, write_stream, read_timeout_seconds=timeout_delta
)
)
await self._initialize_and_list_tools()
# Apply timeout to the entire connection process
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
except asyncio.TimeoutError:
logger.error(
f'Connection to {server_url} timed out after {timeout} seconds'
)
await self.disconnect() # Clean up resources
raise # Re-raise the TimeoutError
except Exception as e:
logger.error(f'Error connecting to {server_url}: {str(e)}')
await self.disconnect() # Clean up resources
raise
async def disconnect(self) -> None:
"""Disconnect from the MCP server and clean up resources."""
if self.session:

View File

@ -4,8 +4,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.controller.agent import Agent
from openhands.core.config.mcp_config import (
MCPConfig,
MCPSHTTPServerConfig,
MCPSSEServerConfig,
)
from openhands.core.config.openhands_config import OpenHandsConfig
@ -48,7 +50,9 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
async def create_mcp_clients(
sse_servers: list[MCPSSEServerConfig], conversation_id: str | None = None
sse_servers: list[MCPSSEServerConfig],
shttp_servers: list[MCPSHTTPServerConfig],
conversation_id: str | None = None,
) -> list[MCPClient]:
import sys
@ -59,42 +63,60 @@ async def create_mcp_clients(
)
return []
mcp_clients: list[MCPClient] = []
# Initialize SSE connections
if sse_servers:
for server_url in sse_servers:
logger.info(
f'Initializing MCP agent for {server_url} with SSE connection...'
)
servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig] = sse_servers.copy()
servers.extend(shttp_servers.copy())
client = MCPClient()
try:
if not servers:
return []
mcp_clients = []
for server in servers:
is_sse = isinstance(server, MCPSSEServerConfig)
connection_type = 'SSE' if is_sse else 'SHTTP'
logger.info(
f'Initializing MCP agent for {server} with {connection_type} connection...'
)
client = MCPClient()
try:
if is_sse:
await client.connect_sse(
server_url.url,
api_key=server_url.api_key,
server.url,
api_key=server.api_key,
conversation_id=conversation_id,
)
# Only add the client to the list after a successful connection
mcp_clients.append(client)
logger.info(f'Connected to MCP server {server_url} via SSE')
except Exception as e:
logger.error(
f'Failed to connect to {server_url}: {str(e)}', exc_info=True
else:
await client.connect_shttp(
server.url,
api_key=server.api_key,
conversation_id=conversation_id,
)
try:
await client.disconnect()
except Exception as disconnect_error:
logger.error(
f'Error during disconnect after failed connection: {str(disconnect_error)}'
)
# Only add the client to the list after a successful connection
mcp_clients.append(client)
except Exception as e:
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
try:
await client.disconnect()
except Exception as disconnect_error:
logger.error(
f'Error during disconnect after failed connection: {str(disconnect_error)}'
)
return mcp_clients
async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
async def fetch_mcp_tools_from_config(
mcp_config: MCPConfig, conversation_id: str | None = None
) -> list[dict]:
"""
Retrieves the list of MCP tools from the MCP clients.
Args:
mcp_config: The MCP configuration
conversation_id: Optional conversation ID to associate with the MCP clients
Returns:
A list of tool dictionaries. Returns an empty list if no connections could be established.
"""
@ -111,7 +133,7 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
logger.debug(f'Creating MCP clients with config: {mcp_config}')
# Create clients - this will fetch tools but not maintain active connections
mcp_clients = await create_mcp_clients(
mcp_config.sse_servers,
mcp_config.sse_servers, mcp_config.shttp_servers, conversation_id
)
if not mcp_clients:

View File

@ -464,7 +464,7 @@ class ActionExecutionClient(Runtime):
)
# Create clients for this specific operation
mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers, self.sid)
mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid)
# Call the tool and return the result
# No need for try/finally since disconnect() is now just resetting state

View File

@ -1,3 +1,4 @@
import contextlib
import warnings
from contextlib import asynccontextmanager
from typing import AsyncIterator
@ -29,6 +30,20 @@ from openhands.server.routes.settings import app as settings_router
from openhands.server.routes.trajectory import app as trajectory_router
from openhands.server.shared import conversation_manager
mcp_app = mcp_server.http_app(path='/mcp')
def combine_lifespans(*lifespans):
# Create a combined lifespan to manage multiple session managers
@contextlib.asynccontextmanager
async def combined_lifespan(app):
async with contextlib.AsyncExitStack() as stack:
for lifespan in lifespans:
await stack.enter_async_context(lifespan(app))
yield
return combined_lifespan
@asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
@ -40,8 +55,8 @@ app = FastAPI(
title='OpenHands',
description='OpenHands: Code Less, Make More',
version=__version__,
lifespan=_lifespan,
routes=[Mount(path='/mcp', app=mcp_server.sse_app())],
lifespan=combine_lifespans(_lifespan, mcp_app.lifespan),
routes=[Mount(path='/mcp', app=mcp_app)],
)

View File

@ -19,8 +19,7 @@ from openhands.server.user_auth import (
)
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
mcp_server = FastMCP('mcp', dependencies=get_dependencies())
mcp_server = FastMCP('mcp', stateless_http=True, dependencies=get_dependencies())
async def save_pr_metadata(
user_id: str, conversation_id: str, tool_result: str

View File

@ -138,7 +138,7 @@ class Session:
)
)
if openhands_mcp_server:
self.config.mcp.sse_servers.append(openhands_mcp_server)
self.config.mcp.shttp_servers.append(openhands_mcp_server)
self.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
# TODO: override other LLM config & agent config groups (#2075)

View File

@ -26,3 +26,24 @@ async def test_connect_sse_timeout():
# Test with a very short timeout
with pytest.raises(asyncio.TimeoutError):
await client.connect_sse('http://example.com', timeout=0.1)
@pytest.mark.asyncio
async def test_connect_streamable_http_timeout():
"""Test that connect_streamable_http properly times out when server_url is invalid."""
client = MCPClient()
# Create a mock async context manager that simulates a timeout
@asynccontextmanager
async def mock_slow_context(*args, **kwargs):
# This will hang for longer than our timeout
await asyncio.sleep(10.0)
yield (mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock())
# Patch the streamablehttp_client function to return our slow context manager
with mock.patch(
'openhands.mcp.client.streamablehttp_client', return_value=mock_slow_context()
):
# Test with a very short timeout
with pytest.raises(asyncio.TimeoutError):
await client.connect_shttp('http://example.com', timeout=0.1)

View File

@ -25,7 +25,7 @@ async def test_create_mcp_clients_timeout_with_invalid_url():
# Call create_mcp_clients with the invalid URL
start_time = asyncio.get_event_loop().time()
clients = await create_mcp_clients([invalid_url])
clients = await create_mcp_clients([invalid_url], [])
end_time = asyncio.get_event_loop().time()
# Verify that no clients were successfully connected
@ -61,7 +61,7 @@ async def test_create_mcp_clients_with_unreachable_host():
# Call create_mcp_clients with the unreachable URL
start_time = asyncio.get_event_loop().time()
clients = await create_mcp_clients([unreachable_url])
clients = await create_mcp_clients([unreachable_url], [])
end_time = asyncio.get_event_loop().time()
# Verify that no clients were successfully connected

View File

@ -30,7 +30,7 @@ async def test_sse_connection_timeout():
]
# Call create_mcp_clients with the server URLs
clients = await create_mcp_clients(sse_servers=servers)
clients = await create_mcp_clients(sse_servers=servers, shttp_servers=[])
# Verify that no clients were successfully connected
assert len(clients) == 0
@ -50,11 +50,12 @@ async def test_fetch_mcp_tools_with_timeout():
# Configure the mock config
mock_config.sse_servers = ['http://server1:8080']
mock_config.shttp_servers = []
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
# Call fetch_mcp_tools_from_config
tools = await fetch_mcp_tools_from_config(mock_config)
tools = await fetch_mcp_tools_from_config(mock_config, None)
# Verify that an empty list of tools is returned
assert tools == []
@ -68,6 +69,7 @@ async def test_mixed_connection_results():
# Configure the mock config
mock_config.sse_servers = ['http://server1:8080', 'http://server2:8080']
mock_config.shttp_servers = []
# Create a successful client
successful_client = mock.MagicMock(spec=MCPClient)
@ -78,7 +80,7 @@ async def test_mixed_connection_results():
'openhands.mcp.utils.create_mcp_clients', return_value=[successful_client]
):
# Call fetch_mcp_tools_from_config
tools = await fetch_mcp_tools_from_config(mock_config)
tools = await fetch_mcp_tools_from_config(mock_config, None)
# Verify that tools were returned
assert len(tools) > 0

View File

@ -13,7 +13,7 @@ from openhands.events.observation.mcp import MCPObservation
@pytest.mark.asyncio
async def test_create_mcp_clients_empty():
"""Test creating MCP clients with empty server list."""
clients = await openhands.mcp.utils.create_mcp_clients([])
clients = await openhands.mcp.utils.create_mcp_clients([], [])
assert clients == []
@ -32,7 +32,7 @@ async def test_create_mcp_clients_success(mock_mcp_client):
MCPSSEServerConfig(url='http://server2:8080', api_key='test-key'),
]
clients = await openhands.mcp.utils.create_mcp_clients(server_configs)
clients = await openhands.mcp.utils.create_mcp_clients(server_configs, [])
# Verify
assert len(clients) == 2
@ -67,7 +67,7 @@ async def test_create_mcp_clients_connection_failure(mock_mcp_client):
MCPSSEServerConfig(url='http://server2:8080'),
]
clients = await openhands.mcp.utils.create_mcp_clients(server_configs)
clients = await openhands.mcp.utils.create_mcp_clients(server_configs, [])
# Verify only one client was successfully created
assert len(clients) == 1