diff --git a/openhands/cli/main.py b/openhands/cli/main.py index 798c0ee0df..c9beae8d57 100644 --- a/openhands/cli/main.py +++ b/openhands/cli/main.py @@ -263,14 +263,12 @@ async def run_session( # Add MCP tools to the agent if agent.config.enable_mcp: # Add OpenHands' MCP server by default - openhands_mcp_server, openhands_mcp_stdio_servers = ( + _, openhands_mcp_stdio_servers = ( OpenHandsMCPConfigImpl.create_default_mcp_server_config( config.mcp_host, config, None ) ) - # FIXME: OpenHands' SSE server may not be running when CLI mode is started - # if openhands_mcp_server: - # config.mcp.sse_servers.append(openhands_mcp_server) + config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers) await add_mcp_tools_to_agent(agent, runtime, memory, config) diff --git a/openhands/core/config/mcp_config.py b/openhands/core/config/mcp_config.py index 58d368060c..a3fdc93c33 100644 --- a/openhands/core/config/mcp_config.py +++ b/openhands/core/config/mcp_config.py @@ -54,6 +54,10 @@ class MCPStdioServerConfig(BaseModel): and set(self.env.items()) == set(other.env.items()) ) +class MCPSHTTPServerConfig(BaseModel): + url: str + api_key: str | None = None + class MCPConfig(BaseModel): """Configuration for MCP (Message Control Protocol) settings. @@ -65,11 +69,12 @@ class MCPConfig(BaseModel): sse_servers: list[MCPSSEServerConfig] = Field(default_factory=list) stdio_servers: list[MCPStdioServerConfig] = Field(default_factory=list) + shttp_servers: list[MCPSHTTPServerConfig] = Field(default_factory=list) model_config = {'extra': 'forbid'} @staticmethod - def _normalize_sse_servers(servers_data: list[dict | str]) -> list[dict]: + def _normalize_servers(servers_data: list[dict | str]) -> list[dict]: """Helper method to normalize SSE server configurations.""" normalized = [] for server in servers_data: @@ -82,8 +87,13 @@ class MCPConfig(BaseModel): @model_validator(mode='before') def convert_string_urls(cls, data): """Convert string URLs to MCPSSEServerConfig objects.""" - if isinstance(data, dict) and 'sse_servers' in data: - data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers']) + if isinstance(data, dict): + if 'sse_servers' in data: + data['sse_servers'] = cls._normalize_servers(data['sse_servers']) + + if 'shttp_servers' in data: + data['shttp_servers'] = cls._normalize_servers(data['shttp_servers']) + return data def validate_servers(self) -> None: @@ -119,7 +129,7 @@ class MCPConfig(BaseModel): try: # Convert all entries in sse_servers to MCPSSEServerConfig objects if 'sse_servers' in data: - data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers']) + data['sse_servers'] = cls._normalize_servers(data['sse_servers']) servers = [] for server in data['sse_servers']: servers.append(MCPSSEServerConfig(**server)) @@ -132,6 +142,13 @@ class MCPConfig(BaseModel): servers.append(MCPStdioServerConfig(**server)) data['stdio_servers'] = servers + if 'shttp_servers' in data: + data['shttp_servers'] = cls._normalize_servers(data['shttp_servers']) + servers = [] + for server in data['shttp_servers']: + servers.append(MCPSHTTPServerConfig(**server)) + data['shttp_servers'] = servers + # Create SSE config if present mcp_config = MCPConfig.model_validate(data) mcp_config.validate_servers() @@ -169,7 +186,7 @@ class OpenHandsMCPConfig: @staticmethod def create_default_mcp_server_config( host: str, config: 'OpenHandsConfig', user_id: str | None = None - ) -> tuple[MCPSSEServerConfig, list[MCPStdioServerConfig]]: + ) -> tuple[MCPSHTTPServerConfig, list[MCPStdioServerConfig]]: """ Create a default MCP server configuration. @@ -179,12 +196,13 @@ class OpenHandsMCPConfig: Returns: tuple[MCPSSEServerConfig, list[MCPStdioServerConfig]]: A tuple containing the default SSE server configuration and a list of MCP stdio server configurations """ - sse_server = MCPSSEServerConfig(url=f'http://{host}/mcp/sse', api_key=None) stdio_servers = [] search_engine_stdio_server = OpenHandsMCPConfig.add_search_engine(config) if search_engine_stdio_server: stdio_servers.append(search_engine_stdio_server) - return sse_server, stdio_servers + + shttp_servers = MCPSHTTPServerConfig(url=f'http://{host}/mcp/mcp', api_key=None) + return shttp_servers, stdio_servers openhands_mcp_config_cls = os.environ.get( diff --git a/openhands/core/main.py b/openhands/core/main.py index 8303cfebc3..885241ca3e 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -134,14 +134,11 @@ async def run_controller( # Add MCP tools to the agent if agent.config.enable_mcp: # Add OpenHands' MCP server by default - openhands_mcp_server, openhands_mcp_stdio_servers = ( + _, openhands_mcp_stdio_servers = ( OpenHandsMCPConfigImpl.create_default_mcp_server_config( config.mcp_host, config, None ) ) - # FIXME: OpenHands' SSE server may not be running when headless mode is started - # if openhands_mcp_server: - # config.mcp.sse_servers.append(openhands_mcp_server) config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers) await add_mcp_tools_to_agent(agent, runtime, memory, config) diff --git a/openhands/mcp/client.py b/openhands/mcp/client.py index d4062a1d36..14957716ca 100644 --- a/openhands/mcp/client.py +++ b/openhands/mcp/client.py @@ -1,9 +1,11 @@ import asyncio +import datetime from contextlib import AsyncExitStack from typing import Optional from mcp import ClientSession from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel, Field from openhands.core.logger import openhands_logger as logger @@ -58,14 +60,21 @@ class MCPClient(BaseModel): if conversation_id: headers['X-OpenHands-Conversation-ID'] = conversation_id + # Convert float timeout to datetime.timedelta for consistency + timeout_delta = datetime.timedelta(seconds=timeout) + streams_context = sse_client( url=server_url, headers=headers if headers else None, timeout=timeout, ) streams = await self.exit_stack.enter_async_context(streams_context) + # For SSE client, we only get read_stream and write_stream (2 values) + read_stream, write_stream = streams self.session = await self.exit_stack.enter_async_context( - ClientSession(*streams) + ClientSession( + read_stream, write_stream, read_timeout_seconds=timeout_delta + ) ) await self._initialize_and_list_tools() @@ -117,6 +126,77 @@ class MCPClient(BaseModel): raise RuntimeError('Client session is not available.') return await self.session.call_tool(name=tool_name, arguments=args) + async def connect_shttp( + self, + server_url: str, + api_key: str | None = None, + conversation_id: str | None = None, + timeout: float = 30.0, + ) -> None: + """Connect to an MCP server using StreamableHTTP transport. + + Args: + server_url: The URL of the StreamableHTTP server to connect to. + api_key: Optional API key for authentication. + conversation_id: Optional conversation ID for session tracking. + timeout: Connection timeout in seconds. Default is 30 seconds. + """ + if not server_url: + raise ValueError('Server URL is required.') + if self.session: + await self.disconnect() + + try: + # Use asyncio.wait_for to enforce the timeout + async def connect_with_timeout(): + headers = ( + { + 'Authorization': f'Bearer {api_key}', + 's': api_key, # We need this for action execution server's MCP Router + 'X-Session-API-Key': api_key, # We need this for Remote Runtime + } + if api_key + else {} + ) + + if conversation_id: + headers['X-OpenHands-Conversation-ID'] = conversation_id + + # Convert float timeout to datetime.timedelta + timeout_delta = datetime.timedelta(seconds=timeout) + sse_read_timeout_delta = datetime.timedelta( + seconds=timeout * 10 + ) # 10x longer for read timeout + + streams_context = streamablehttp_client( + url=server_url, + headers=headers if headers else None, + timeout=timeout_delta, + sse_read_timeout=sse_read_timeout_delta, + ) + streams = await self.exit_stack.enter_async_context(streams_context) + # For StreamableHTTP client, we get read_stream, write_stream, and get_session_id (3 values) + read_stream, write_stream, _ = streams + self.session = await self.exit_stack.enter_async_context( + ClientSession( + read_stream, write_stream, read_timeout_seconds=timeout_delta + ) + ) + await self._initialize_and_list_tools() + + # Apply timeout to the entire connection process + await asyncio.wait_for(connect_with_timeout(), timeout=timeout) + except asyncio.TimeoutError: + logger.error( + f'Connection to {server_url} timed out after {timeout} seconds' + ) + await self.disconnect() # Clean up resources + raise # Re-raise the TimeoutError + except Exception as e: + logger.error(f'Error connecting to {server_url}: {str(e)}') + await self.disconnect() # Clean up resources + raise + async def disconnect(self) -> None: """Disconnect from the MCP server and clean up resources.""" if self.session: diff --git a/openhands/mcp/utils.py b/openhands/mcp/utils.py index fb932edee3..122fee46e4 100644 --- a/openhands/mcp/utils.py +++ b/openhands/mcp/utils.py @@ -4,8 +4,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from openhands.controller.agent import Agent + from openhands.core.config.mcp_config import ( MCPConfig, + MCPSHTTPServerConfig, MCPSSEServerConfig, ) from openhands.core.config.openhands_config import OpenHandsConfig @@ -48,7 +50,9 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di async def create_mcp_clients( - sse_servers: list[MCPSSEServerConfig], conversation_id: str | None = None + sse_servers: list[MCPSSEServerConfig], + shttp_servers: list[MCPSHTTPServerConfig], + conversation_id: str | None = None, ) -> list[MCPClient]: import sys @@ -59,42 +63,60 @@ async def create_mcp_clients( ) return [] - mcp_clients: list[MCPClient] = [] - # Initialize SSE connections - if sse_servers: - for server_url in sse_servers: - logger.info( - f'Initializing MCP agent for {server_url} with SSE connection...' - ) + servers: list[MCPSSEServerConfig | MCPSHTTPServerConfig] = sse_servers.copy() + servers.extend(shttp_servers.copy()) - client = MCPClient() - try: + if not servers: + return [] + + mcp_clients = [] + + for server in servers: + is_sse = isinstance(server, MCPSSEServerConfig) + connection_type = 'SSE' if is_sse else 'SHTTP' + logger.info( + f'Initializing MCP agent for {server} with {connection_type} connection...' + ) + client = MCPClient() + + try: + if is_sse: await client.connect_sse( - server_url.url, - api_key=server_url.api_key, + server.url, + api_key=server.api_key, conversation_id=conversation_id, ) - # Only add the client to the list after a successful connection - mcp_clients.append(client) - logger.info(f'Connected to MCP server {server_url} via SSE') - except Exception as e: - logger.error( - f'Failed to connect to {server_url}: {str(e)}', exc_info=True + else: + await client.connect_shttp( + server.url, + api_key=server.api_key, + conversation_id=conversation_id, ) - try: - await client.disconnect() - except Exception as disconnect_error: - logger.error( - f'Error during disconnect after failed connection: {str(disconnect_error)}' - ) + # Only add the client to the list after a successful connection + mcp_clients.append(client) + + except Exception as e: + logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True) + try: + await client.disconnect() + except Exception as disconnect_error: + logger.error( + f'Error during disconnect after failed connection: {str(disconnect_error)}' + ) return mcp_clients -async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]: +async def fetch_mcp_tools_from_config( + mcp_config: MCPConfig, conversation_id: str | None = None +) -> list[dict]: """ Retrieves the list of MCP tools from the MCP clients. + Args: + mcp_config: The MCP configuration + conversation_id: Optional conversation ID to associate with the MCP clients + Returns: A list of tool dictionaries. Returns an empty list if no connections could be established. """ @@ -111,7 +133,7 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]: logger.debug(f'Creating MCP clients with config: {mcp_config}') # Create clients - this will fetch tools but not maintain active connections mcp_clients = await create_mcp_clients( - mcp_config.sse_servers, + mcp_config.sse_servers, mcp_config.shttp_servers, conversation_id ) if not mcp_clients: diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 47c1534a76..acc003cd02 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -464,7 +464,7 @@ class ActionExecutionClient(Runtime): ) # Create clients for this specific operation - mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers, self.sid) + mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid) # Call the tool and return the result # No need for try/finally since disconnect() is now just resetting state diff --git a/openhands/server/app.py b/openhands/server/app.py index 82fab085a8..f1c0ad7073 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -1,3 +1,4 @@ +import contextlib import warnings from contextlib import asynccontextmanager from typing import AsyncIterator @@ -29,6 +30,20 @@ from openhands.server.routes.settings import app as settings_router from openhands.server.routes.trajectory import app as trajectory_router from openhands.server.shared import conversation_manager +mcp_app = mcp_server.http_app(path='/mcp') + + +def combine_lifespans(*lifespans): + # Create a combined lifespan to manage multiple session managers + @contextlib.asynccontextmanager + async def combined_lifespan(app): + async with contextlib.AsyncExitStack() as stack: + for lifespan in lifespans: + await stack.enter_async_context(lifespan(app)) + yield + + return combined_lifespan + @asynccontextmanager async def _lifespan(app: FastAPI) -> AsyncIterator[None]: @@ -40,8 +55,8 @@ app = FastAPI( title='OpenHands', description='OpenHands: Code Less, Make More', version=__version__, - lifespan=_lifespan, - routes=[Mount(path='/mcp', app=mcp_server.sse_app())], + lifespan=combine_lifespans(_lifespan, mcp_app.lifespan), + routes=[Mount(path='/mcp', app=mcp_app)], ) diff --git a/openhands/server/routes/mcp.py b/openhands/server/routes/mcp.py index 7a3a61b3f3..780824174a 100644 --- a/openhands/server/routes/mcp.py +++ b/openhands/server/routes/mcp.py @@ -19,8 +19,7 @@ from openhands.server.user_auth import ( ) from openhands.storage.data_models.conversation_metadata import ConversationMetadata -mcp_server = FastMCP('mcp', dependencies=get_dependencies()) - +mcp_server = FastMCP('mcp', stateless_http=True, dependencies=get_dependencies()) async def save_pr_metadata( user_id: str, conversation_id: str, tool_result: str diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 4d2c760f94..586f48ea2a 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -138,7 +138,7 @@ class Session: ) ) if openhands_mcp_server: - self.config.mcp.sse_servers.append(openhands_mcp_server) + self.config.mcp.shttp_servers.append(openhands_mcp_server) self.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers) # TODO: override other LLM config & agent config groups (#2075) diff --git a/tests/unit/test_mcp_client_timeout.py b/tests/unit/test_mcp_client_timeout.py index a55c81e6f3..602a217d26 100644 --- a/tests/unit/test_mcp_client_timeout.py +++ b/tests/unit/test_mcp_client_timeout.py @@ -26,3 +26,24 @@ async def test_connect_sse_timeout(): # Test with a very short timeout with pytest.raises(asyncio.TimeoutError): await client.connect_sse('http://example.com', timeout=0.1) + + +@pytest.mark.asyncio +async def test_connect_streamable_http_timeout(): + """Test that connect_streamable_http properly times out when server_url is invalid.""" + client = MCPClient() + + # Create a mock async context manager that simulates a timeout + @asynccontextmanager + async def mock_slow_context(*args, **kwargs): + # This will hang for longer than our timeout + await asyncio.sleep(10.0) + yield (mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock()) + + # Patch the streamablehttp_client function to return our slow context manager + with mock.patch( + 'openhands.mcp.client.streamablehttp_client', return_value=mock_slow_context() + ): + # Test with a very short timeout + with pytest.raises(asyncio.TimeoutError): + await client.connect_shttp('http://example.com', timeout=0.1) diff --git a/tests/unit/test_mcp_create_clients_timeout.py b/tests/unit/test_mcp_create_clients_timeout.py index fe73b9027c..53454c2b5f 100644 --- a/tests/unit/test_mcp_create_clients_timeout.py +++ b/tests/unit/test_mcp_create_clients_timeout.py @@ -25,7 +25,7 @@ async def test_create_mcp_clients_timeout_with_invalid_url(): # Call create_mcp_clients with the invalid URL start_time = asyncio.get_event_loop().time() - clients = await create_mcp_clients([invalid_url]) + clients = await create_mcp_clients([invalid_url], []) end_time = asyncio.get_event_loop().time() # Verify that no clients were successfully connected @@ -61,7 +61,7 @@ async def test_create_mcp_clients_with_unreachable_host(): # Call create_mcp_clients with the unreachable URL start_time = asyncio.get_event_loop().time() - clients = await create_mcp_clients([unreachable_url]) + clients = await create_mcp_clients([unreachable_url], []) end_time = asyncio.get_event_loop().time() # Verify that no clients were successfully connected diff --git a/tests/unit/test_mcp_timeout.py b/tests/unit/test_mcp_timeout.py index 8d0ed4b7f5..898820276f 100644 --- a/tests/unit/test_mcp_timeout.py +++ b/tests/unit/test_mcp_timeout.py @@ -30,7 +30,7 @@ async def test_sse_connection_timeout(): ] # Call create_mcp_clients with the server URLs - clients = await create_mcp_clients(sse_servers=servers) + clients = await create_mcp_clients(sse_servers=servers, shttp_servers=[]) # Verify that no clients were successfully connected assert len(clients) == 0 @@ -50,11 +50,12 @@ async def test_fetch_mcp_tools_with_timeout(): # Configure the mock config mock_config.sse_servers = ['http://server1:8080'] + mock_config.shttp_servers = [] # Mock create_mcp_clients to return an empty list (simulating all connections failing) with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]): # Call fetch_mcp_tools_from_config - tools = await fetch_mcp_tools_from_config(mock_config) + tools = await fetch_mcp_tools_from_config(mock_config, None) # Verify that an empty list of tools is returned assert tools == [] @@ -68,6 +69,7 @@ async def test_mixed_connection_results(): # Configure the mock config mock_config.sse_servers = ['http://server1:8080', 'http://server2:8080'] + mock_config.shttp_servers = [] # Create a successful client successful_client = mock.MagicMock(spec=MCPClient) @@ -78,7 +80,7 @@ async def test_mixed_connection_results(): 'openhands.mcp.utils.create_mcp_clients', return_value=[successful_client] ): # Call fetch_mcp_tools_from_config - tools = await fetch_mcp_tools_from_config(mock_config) + tools = await fetch_mcp_tools_from_config(mock_config, None) # Verify that tools were returned assert len(tools) > 0 diff --git a/tests/unit/test_mcp_utils.py b/tests/unit/test_mcp_utils.py index 42c427f8da..91e479a447 100644 --- a/tests/unit/test_mcp_utils.py +++ b/tests/unit/test_mcp_utils.py @@ -13,7 +13,7 @@ from openhands.events.observation.mcp import MCPObservation @pytest.mark.asyncio async def test_create_mcp_clients_empty(): """Test creating MCP clients with empty server list.""" - clients = await openhands.mcp.utils.create_mcp_clients([]) + clients = await openhands.mcp.utils.create_mcp_clients([], []) assert clients == [] @@ -32,7 +32,7 @@ async def test_create_mcp_clients_success(mock_mcp_client): MCPSSEServerConfig(url='http://server2:8080', api_key='test-key'), ] - clients = await openhands.mcp.utils.create_mcp_clients(server_configs) + clients = await openhands.mcp.utils.create_mcp_clients(server_configs, []) # Verify assert len(clients) == 2 @@ -67,7 +67,7 @@ async def test_create_mcp_clients_connection_failure(mock_mcp_client): MCPSSEServerConfig(url='http://server2:8080'), ] - clients = await openhands.mcp.utils.create_mcp_clients(server_configs) + clients = await openhands.mcp.utils.create_mcp_clients(server_configs, []) # Verify only one client was successfully created assert len(clients) == 1