refactor(mcp): simplify MCP config & fix timeout (#7820)

Co-authored-by: ducphamle2 <ducphamle212@gmail.com>
Co-authored-by: trungbach <trunga2k29@gmail.com>
Co-authored-by: quangdz1704 <Ntq.1704@gmail.com>
Co-authored-by: Duc Pham <44611780+ducphamle2@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Xingyao Wang 2025-04-15 23:04:21 -04:00 committed by GitHub
parent 7e14a512e0
commit 07e400b73d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 148 additions and 46 deletions

View File

@ -4,11 +4,11 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, ValidationError
class MCPSSEConfig(BaseModel):
"""Configuration for MCP SSE (Server-Sent Events) settings.
class MCPConfig(BaseModel):
"""Configuration for MCP (Message Control Protocol) settings.
Attributes:
mcp_servers: List of MCP server URLs.
mcp_servers: List of MCP SSE (Server-Sent Events) server URLs.
"""
mcp_servers: List[str] = Field(default_factory=list)
@ -30,18 +30,6 @@ class MCPSSEConfig(BaseModel):
except Exception as e:
raise ValueError(f'Invalid URL {url}: {str(e)}')
class MCPConfig(BaseModel):
"""Configuration for MCP (Message Control Protocol) settings.
Attributes:
sse: SSE-specific configuration.
"""
sse: MCPSSEConfig = Field(default_factory=MCPSSEConfig)
model_config = {'extra': 'forbid'}
@classmethod
def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']:
"""
@ -57,11 +45,10 @@ class MCPConfig(BaseModel):
try:
# Create SSE config if present
sse_config = MCPSSEConfig.model_validate(data)
sse_config.validate_servers()
mcp_config = MCPConfig.model_validate(data)
mcp_config.validate_servers()
# Create the main MCP config
mcp_mapping['mcp'] = cls(sse=sse_config)
mcp_mapping['mcp'] = cls(mcp_servers=mcp_config.mcp_servers)
except ValidationError as e:
raise ValueError(f'Invalid MCP configuration: {e}')

View File

@ -1,3 +1,4 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
@ -36,17 +37,29 @@ class MCPClient(BaseModel):
await self.disconnect()
try:
streams_context = sse_client(
url=server_url,
)
streams = await self.exit_stack.enter_async_context(streams_context)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*streams)
)
# Use asyncio.wait_for to enforce the timeout
async def connect_with_timeout():
streams_context = sse_client(
url=server_url,
timeout=timeout, # Pass the timeout to sse_client
)
streams = await self.exit_stack.enter_async_context(streams_context)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*streams)
)
await self._initialize_and_list_tools()
await self._initialize_and_list_tools()
# Apply timeout to the entire connection process
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
except asyncio.TimeoutError:
logger.error(
f'Connection to {server_url} timed out after {timeout} seconds'
)
await self.disconnect() # Clean up resources
raise # Re-raise the TimeoutError
except Exception as e:
logger.error(f'Error connecting to {server_url}: {str(e)}')
await self.disconnect() # Clean up resources
raise
async def _initialize_and_list_tools(self) -> None:

View File

@ -38,12 +38,12 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
async def create_mcp_clients(
sse_mcp_server: list[str],
mcp_servers: list[str],
) -> list[MCPClient]:
mcp_clients: list[MCPClient] = []
# Initialize SSE connections
if sse_mcp_server:
for server_url in sse_mcp_server:
if mcp_servers:
for server_url in mcp_servers:
logger.info(
f'Initializing MCP agent for {server_url} with SSE connection...'
)
@ -78,11 +78,11 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
try:
logger.debug(f'Creating MCP clients with config: {mcp_config}')
mcp_clients = await create_mcp_clients(
mcp_config.sse.mcp_servers,
mcp_config.mcp_servers,
)
if not mcp_clients:
logger.warning('No MCP clients were successfully connected')
logger.debug('No MCP clients were successfully connected')
return []
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)

View File

@ -0,0 +1,28 @@
import asyncio
from contextlib import asynccontextmanager
from unittest import mock
import pytest
from openhands.mcp.client import MCPClient
@pytest.mark.asyncio
async def test_connect_sse_timeout():
"""Test that connect_sse properly times out when server_url is invalid."""
client = MCPClient()
# Create a mock async context manager that simulates a timeout
@asynccontextmanager
async def mock_slow_context(*args, **kwargs):
# This will hang for longer than our timeout
await asyncio.sleep(10.0)
yield (mock.AsyncMock(), mock.AsyncMock())
# Patch the sse_client function to return our slow context manager
with mock.patch(
'openhands.mcp.client.sse_client', return_value=mock_slow_context()
):
# Test with a very short timeout
with pytest.raises(asyncio.TimeoutError):
await client.connect_sse('http://example.com', timeout=0.1)

View File

@ -1,23 +1,23 @@
import pytest
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
from openhands.core.config.mcp_config import MCPConfig
def test_valid_sse_config():
"""Test a valid SSE configuration."""
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
config.validate_servers() # Should not raise any exception
def test_empty_sse_config():
"""Test SSE configuration with empty servers list."""
config = MCPSSEConfig(mcp_servers=[])
config = MCPConfig(mcp_servers=[])
config.validate_servers()
def test_invalid_sse_url():
"""Test SSE configuration with invalid URL format."""
config = MCPSSEConfig(mcp_servers=['not_a_url'])
config = MCPConfig(mcp_servers=['not_a_url'])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Invalid URL' in str(exc_info.value)
@ -25,7 +25,7 @@ def test_invalid_sse_url():
def test_duplicate_sse_urls():
"""Test SSE configuration with duplicate server URLs."""
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value)
@ -38,7 +38,7 @@ def test_from_toml_section_valid():
}
result = MCPConfig.from_toml_section(data)
assert 'mcp' in result
assert result['mcp'].sse.mcp_servers == ['http://server1:8080']
assert result['mcp'].mcp_servers == ['http://server1:8080']
def test_from_toml_section_invalid_sse():
@ -53,7 +53,7 @@ def test_from_toml_section_invalid_sse():
def test_complex_urls():
"""Test SSE configuration with complex URLs."""
config = MCPSSEConfig(
config = MCPConfig(
mcp_servers=[
'https://user:pass@server1:8080/path?query=1',
'wss://server2:8443/ws',

View File

@ -0,0 +1,76 @@
import asyncio
import pytest
from openhands.mcp.client import MCPClient
from openhands.mcp.utils import create_mcp_clients
@pytest.mark.asyncio
async def test_create_mcp_clients_timeout_with_invalid_url():
"""Test that create_mcp_clients properly times out when given an invalid URL."""
# Use a non-existent domain that should cause a connection timeout
invalid_url = 'http://non-existent-domain-that-will-timeout.invalid'
# Temporarily modify the default timeout for the MCPClient.connect_sse method
original_connect_sse = MCPClient.connect_sse
# Create a wrapper that calls the original method but with a shorter timeout
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_sse(self, server_url, timeout=0.5)
try:
# Replace the method with our wrapper
MCPClient.connect_sse = connect_sse_with_short_timeout
# Call create_mcp_clients with the invalid URL
start_time = asyncio.get_event_loop().time()
clients = await create_mcp_clients([invalid_url])
end_time = asyncio.get_event_loop().time()
# Verify that no clients were successfully connected
assert len(clients) == 0
# Verify that the operation completed in a reasonable time (less than 5 seconds)
# This ensures the timeout is working properly
assert (
end_time - start_time < 5.0
), 'Operation took too long, timeout may not be working'
finally:
# Restore the original method
MCPClient.connect_sse = original_connect_sse
@pytest.mark.asyncio
async def test_create_mcp_clients_with_unreachable_host():
"""Test that create_mcp_clients handles unreachable hosts properly."""
# Use a URL with a valid format but pointing to a non-routable IP address
# This IP is in the TEST-NET-1 range (192.0.2.0/24) reserved for documentation and examples
unreachable_url = 'http://192.0.2.1:8080'
# Temporarily modify the default timeout for the MCPClient.connect_sse method
original_connect_sse = MCPClient.connect_sse
# Create a wrapper that calls the original method but with a shorter timeout
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_sse(self, server_url, timeout=1.0)
try:
# Replace the method with our wrapper
MCPClient.connect_sse = connect_sse_with_short_timeout
# Call create_mcp_clients with the unreachable URL
start_time = asyncio.get_event_loop().time()
clients = await create_mcp_clients([unreachable_url])
end_time = asyncio.get_event_loop().time()
# Verify that no clients were successfully connected
assert len(clients) == 0
# Verify that the operation completed in a reasonable time (less than 5 seconds)
assert (
end_time - start_time < 5.0
), 'Operation took too long, timeout may not be working'
finally:
# Restore the original method
MCPClient.connect_sse = original_connect_sse

View File

@ -3,7 +3,7 @@ from unittest import mock
import pytest
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config
@ -24,10 +24,10 @@ async def test_sse_connection_timeout():
# Mock the MCPClient constructor to return our mock
with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client):
# Create a list of server URLs to test
sse_servers = ['http://server1:8080', 'http://server2:8080']
servers = ['http://server1:8080', 'http://server2:8080']
# Call create_mcp_clients with the server URLs
clients = await create_mcp_clients(sse_mcp_server=sse_servers)
clients = await create_mcp_clients(mcp_servers=servers)
# Verify that no clients were successfully connected
assert len(clients) == 0
@ -44,10 +44,9 @@ async def test_fetch_mcp_tools_with_timeout():
"""Test that fetch_mcp_tools_from_config handles timeouts gracefully."""
# Create a mock MCPConfig
mock_config = mock.MagicMock(spec=MCPConfig)
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
# Configure the mock config
mock_config.sse.mcp_servers = ['http://server1:8080']
mock_config.mcp_servers = ['http://server1:8080']
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
@ -63,10 +62,9 @@ async def test_mixed_connection_results():
"""Test that fetch_mcp_tools_from_config returns tools even when some connections fail."""
# Create a mock MCPConfig
mock_config = mock.MagicMock(spec=MCPConfig)
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
# Configure the mock config
mock_config.sse.mcp_servers = ['http://server1:8080', 'http://server2:8080']
mock_config.mcp_servers = ['http://server1:8080', 'http://server2:8080']
# Create a successful client
successful_client = mock.MagicMock(spec=MCPClient)