From 07e400b73d5b847f4acc0e82d014a5dbdba16c96 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Tue, 15 Apr 2025 23:04:21 -0400 Subject: [PATCH] refactor(mcp): simplify MCP config & fix timeout (#7820) Co-authored-by: ducphamle2 Co-authored-by: trungbach Co-authored-by: quangdz1704 Co-authored-by: Duc Pham <44611780+ducphamle2@users.noreply.github.com> Co-authored-by: openhands --- openhands/core/config/mcp_config.py | 25 ++---- openhands/mcp/client.py | 29 +++++-- openhands/mcp/utils.py | 10 +-- tests/unit/test_mcp_client_timeout.py | 28 +++++++ tests/unit/test_mcp_config.py | 14 ++-- tests/unit/test_mcp_create_clients_timeout.py | 76 +++++++++++++++++++ tests/unit/test_mcp_timeout.py | 12 ++- 7 files changed, 148 insertions(+), 46 deletions(-) create mode 100644 tests/unit/test_mcp_client_timeout.py create mode 100644 tests/unit/test_mcp_create_clients_timeout.py diff --git a/openhands/core/config/mcp_config.py b/openhands/core/config/mcp_config.py index 1a80f03322..2f362c81b0 100644 --- a/openhands/core/config/mcp_config.py +++ b/openhands/core/config/mcp_config.py @@ -4,11 +4,11 @@ from urllib.parse import urlparse from pydantic import BaseModel, Field, ValidationError -class MCPSSEConfig(BaseModel): - """Configuration for MCP SSE (Server-Sent Events) settings. +class MCPConfig(BaseModel): + """Configuration for MCP (Message Control Protocol) settings. Attributes: - mcp_servers: List of MCP server URLs. + mcp_servers: List of MCP SSE (Server-Sent Events) server URLs. """ mcp_servers: List[str] = Field(default_factory=list) @@ -30,18 +30,6 @@ class MCPSSEConfig(BaseModel): except Exception as e: raise ValueError(f'Invalid URL {url}: {str(e)}') - -class MCPConfig(BaseModel): - """Configuration for MCP (Message Control Protocol) settings. - - Attributes: - sse: SSE-specific configuration. - """ - - sse: MCPSSEConfig = Field(default_factory=MCPSSEConfig) - - model_config = {'extra': 'forbid'} - @classmethod def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']: """ @@ -57,11 +45,10 @@ class MCPConfig(BaseModel): try: # Create SSE config if present - sse_config = MCPSSEConfig.model_validate(data) - sse_config.validate_servers() - + mcp_config = MCPConfig.model_validate(data) + mcp_config.validate_servers() # Create the main MCP config - mcp_mapping['mcp'] = cls(sse=sse_config) + mcp_mapping['mcp'] = cls(mcp_servers=mcp_config.mcp_servers) except ValidationError as e: raise ValueError(f'Invalid MCP configuration: {e}') diff --git a/openhands/mcp/client.py b/openhands/mcp/client.py index 1a50aacc63..71d7623632 100644 --- a/openhands/mcp/client.py +++ b/openhands/mcp/client.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import AsyncExitStack from typing import Dict, List, Optional @@ -36,17 +37,29 @@ class MCPClient(BaseModel): await self.disconnect() try: - streams_context = sse_client( - url=server_url, - ) - streams = await self.exit_stack.enter_async_context(streams_context) - self.session = await self.exit_stack.enter_async_context( - ClientSession(*streams) - ) + # Use asyncio.wait_for to enforce the timeout + async def connect_with_timeout(): + streams_context = sse_client( + url=server_url, + timeout=timeout, # Pass the timeout to sse_client + ) + streams = await self.exit_stack.enter_async_context(streams_context) + self.session = await self.exit_stack.enter_async_context( + ClientSession(*streams) + ) + await self._initialize_and_list_tools() - await self._initialize_and_list_tools() + # Apply timeout to the entire connection process + await asyncio.wait_for(connect_with_timeout(), timeout=timeout) + except asyncio.TimeoutError: + logger.error( + f'Connection to {server_url} timed out after {timeout} seconds' + ) + await self.disconnect() # Clean up resources + raise # Re-raise the TimeoutError except Exception as e: logger.error(f'Error connecting to {server_url}: {str(e)}') + await self.disconnect() # Clean up resources raise async def _initialize_and_list_tools(self) -> None: diff --git a/openhands/mcp/utils.py b/openhands/mcp/utils.py index c8894f1ca9..e6fd1403ca 100644 --- a/openhands/mcp/utils.py +++ b/openhands/mcp/utils.py @@ -38,12 +38,12 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di async def create_mcp_clients( - sse_mcp_server: list[str], + mcp_servers: list[str], ) -> list[MCPClient]: mcp_clients: list[MCPClient] = [] # Initialize SSE connections - if sse_mcp_server: - for server_url in sse_mcp_server: + if mcp_servers: + for server_url in mcp_servers: logger.info( f'Initializing MCP agent for {server_url} with SSE connection...' ) @@ -78,11 +78,11 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]: try: logger.debug(f'Creating MCP clients with config: {mcp_config}') mcp_clients = await create_mcp_clients( - mcp_config.sse.mcp_servers, + mcp_config.mcp_servers, ) if not mcp_clients: - logger.warning('No MCP clients were successfully connected') + logger.debug('No MCP clients were successfully connected') return [] mcp_tools = convert_mcp_clients_to_tools(mcp_clients) diff --git a/tests/unit/test_mcp_client_timeout.py b/tests/unit/test_mcp_client_timeout.py new file mode 100644 index 0000000000..a55c81e6f3 --- /dev/null +++ b/tests/unit/test_mcp_client_timeout.py @@ -0,0 +1,28 @@ +import asyncio +from contextlib import asynccontextmanager +from unittest import mock + +import pytest + +from openhands.mcp.client import MCPClient + + +@pytest.mark.asyncio +async def test_connect_sse_timeout(): + """Test that connect_sse properly times out when server_url is invalid.""" + client = MCPClient() + + # Create a mock async context manager that simulates a timeout + @asynccontextmanager + async def mock_slow_context(*args, **kwargs): + # This will hang for longer than our timeout + await asyncio.sleep(10.0) + yield (mock.AsyncMock(), mock.AsyncMock()) + + # Patch the sse_client function to return our slow context manager + with mock.patch( + 'openhands.mcp.client.sse_client', return_value=mock_slow_context() + ): + # Test with a very short timeout + with pytest.raises(asyncio.TimeoutError): + await client.connect_sse('http://example.com', timeout=0.1) diff --git a/tests/unit/test_mcp_config.py b/tests/unit/test_mcp_config.py index c91574025e..f069b3f4da 100644 --- a/tests/unit/test_mcp_config.py +++ b/tests/unit/test_mcp_config.py @@ -1,23 +1,23 @@ import pytest -from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig +from openhands.core.config.mcp_config import MCPConfig def test_valid_sse_config(): """Test a valid SSE configuration.""" - config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server2:8080']) + config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server2:8080']) config.validate_servers() # Should not raise any exception def test_empty_sse_config(): """Test SSE configuration with empty servers list.""" - config = MCPSSEConfig(mcp_servers=[]) + config = MCPConfig(mcp_servers=[]) config.validate_servers() def test_invalid_sse_url(): """Test SSE configuration with invalid URL format.""" - config = MCPSSEConfig(mcp_servers=['not_a_url']) + config = MCPConfig(mcp_servers=['not_a_url']) with pytest.raises(ValueError) as exc_info: config.validate_servers() assert 'Invalid URL' in str(exc_info.value) @@ -25,7 +25,7 @@ def test_invalid_sse_url(): def test_duplicate_sse_urls(): """Test SSE configuration with duplicate server URLs.""" - config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server1:8080']) + config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server1:8080']) with pytest.raises(ValueError) as exc_info: config.validate_servers() assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value) @@ -38,7 +38,7 @@ def test_from_toml_section_valid(): } result = MCPConfig.from_toml_section(data) assert 'mcp' in result - assert result['mcp'].sse.mcp_servers == ['http://server1:8080'] + assert result['mcp'].mcp_servers == ['http://server1:8080'] def test_from_toml_section_invalid_sse(): @@ -53,7 +53,7 @@ def test_from_toml_section_invalid_sse(): def test_complex_urls(): """Test SSE configuration with complex URLs.""" - config = MCPSSEConfig( + config = MCPConfig( mcp_servers=[ 'https://user:pass@server1:8080/path?query=1', 'wss://server2:8443/ws', diff --git a/tests/unit/test_mcp_create_clients_timeout.py b/tests/unit/test_mcp_create_clients_timeout.py new file mode 100644 index 0000000000..d0a09355de --- /dev/null +++ b/tests/unit/test_mcp_create_clients_timeout.py @@ -0,0 +1,76 @@ +import asyncio + +import pytest + +from openhands.mcp.client import MCPClient +from openhands.mcp.utils import create_mcp_clients + + +@pytest.mark.asyncio +async def test_create_mcp_clients_timeout_with_invalid_url(): + """Test that create_mcp_clients properly times out when given an invalid URL.""" + # Use a non-existent domain that should cause a connection timeout + invalid_url = 'http://non-existent-domain-that-will-timeout.invalid' + + # Temporarily modify the default timeout for the MCPClient.connect_sse method + original_connect_sse = MCPClient.connect_sse + + # Create a wrapper that calls the original method but with a shorter timeout + async def connect_sse_with_short_timeout(self, server_url, timeout=30.0): + return await original_connect_sse(self, server_url, timeout=0.5) + + try: + # Replace the method with our wrapper + MCPClient.connect_sse = connect_sse_with_short_timeout + + # Call create_mcp_clients with the invalid URL + start_time = asyncio.get_event_loop().time() + clients = await create_mcp_clients([invalid_url]) + end_time = asyncio.get_event_loop().time() + + # Verify that no clients were successfully connected + assert len(clients) == 0 + + # Verify that the operation completed in a reasonable time (less than 5 seconds) + # This ensures the timeout is working properly + assert ( + end_time - start_time < 5.0 + ), 'Operation took too long, timeout may not be working' + finally: + # Restore the original method + MCPClient.connect_sse = original_connect_sse + + +@pytest.mark.asyncio +async def test_create_mcp_clients_with_unreachable_host(): + """Test that create_mcp_clients handles unreachable hosts properly.""" + # Use a URL with a valid format but pointing to a non-routable IP address + # This IP is in the TEST-NET-1 range (192.0.2.0/24) reserved for documentation and examples + unreachable_url = 'http://192.0.2.1:8080' + + # Temporarily modify the default timeout for the MCPClient.connect_sse method + original_connect_sse = MCPClient.connect_sse + + # Create a wrapper that calls the original method but with a shorter timeout + async def connect_sse_with_short_timeout(self, server_url, timeout=30.0): + return await original_connect_sse(self, server_url, timeout=1.0) + + try: + # Replace the method with our wrapper + MCPClient.connect_sse = connect_sse_with_short_timeout + + # Call create_mcp_clients with the unreachable URL + start_time = asyncio.get_event_loop().time() + clients = await create_mcp_clients([unreachable_url]) + end_time = asyncio.get_event_loop().time() + + # Verify that no clients were successfully connected + assert len(clients) == 0 + + # Verify that the operation completed in a reasonable time (less than 5 seconds) + assert ( + end_time - start_time < 5.0 + ), 'Operation took too long, timeout may not be working' + finally: + # Restore the original method + MCPClient.connect_sse = original_connect_sse diff --git a/tests/unit/test_mcp_timeout.py b/tests/unit/test_mcp_timeout.py index d50f44fb59..8bdc932294 100644 --- a/tests/unit/test_mcp_timeout.py +++ b/tests/unit/test_mcp_timeout.py @@ -3,7 +3,7 @@ from unittest import mock import pytest -from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig +from openhands.core.config.mcp_config import MCPConfig from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config @@ -24,10 +24,10 @@ async def test_sse_connection_timeout(): # Mock the MCPClient constructor to return our mock with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client): # Create a list of server URLs to test - sse_servers = ['http://server1:8080', 'http://server2:8080'] + servers = ['http://server1:8080', 'http://server2:8080'] # Call create_mcp_clients with the server URLs - clients = await create_mcp_clients(sse_mcp_server=sse_servers) + clients = await create_mcp_clients(mcp_servers=servers) # Verify that no clients were successfully connected assert len(clients) == 0 @@ -44,10 +44,9 @@ async def test_fetch_mcp_tools_with_timeout(): """Test that fetch_mcp_tools_from_config handles timeouts gracefully.""" # Create a mock MCPConfig mock_config = mock.MagicMock(spec=MCPConfig) - mock_config.sse = mock.MagicMock(spec=MCPSSEConfig) # Configure the mock config - mock_config.sse.mcp_servers = ['http://server1:8080'] + mock_config.mcp_servers = ['http://server1:8080'] # Mock create_mcp_clients to return an empty list (simulating all connections failing) with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]): @@ -63,10 +62,9 @@ async def test_mixed_connection_results(): """Test that fetch_mcp_tools_from_config returns tools even when some connections fail.""" # Create a mock MCPConfig mock_config = mock.MagicMock(spec=MCPConfig) - mock_config.sse = mock.MagicMock(spec=MCPSSEConfig) # Configure the mock config - mock_config.sse.mcp_servers = ['http://server1:8080', 'http://server2:8080'] + mock_config.mcp_servers = ['http://server1:8080', 'http://server2:8080'] # Create a successful client successful_client = mock.MagicMock(spec=MCPClient)