mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
refactor(mcp): simplify MCP config & fix timeout (#7820)
Co-authored-by: ducphamle2 <ducphamle212@gmail.com> Co-authored-by: trungbach <trunga2k29@gmail.com> Co-authored-by: quangdz1704 <Ntq.1704@gmail.com> Co-authored-by: Duc Pham <44611780+ducphamle2@users.noreply.github.com> Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
7e14a512e0
commit
07e400b73d
@ -4,11 +4,11 @@ from urllib.parse import urlparse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
class MCPSSEConfig(BaseModel):
|
||||
"""Configuration for MCP SSE (Server-Sent Events) settings.
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP (Message Control Protocol) settings.
|
||||
|
||||
Attributes:
|
||||
mcp_servers: List of MCP server URLs.
|
||||
mcp_servers: List of MCP SSE (Server-Sent Events) server URLs.
|
||||
"""
|
||||
|
||||
mcp_servers: List[str] = Field(default_factory=list)
|
||||
@ -30,18 +30,6 @@ class MCPSSEConfig(BaseModel):
|
||||
except Exception as e:
|
||||
raise ValueError(f'Invalid URL {url}: {str(e)}')
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP (Message Control Protocol) settings.
|
||||
|
||||
Attributes:
|
||||
sse: SSE-specific configuration.
|
||||
"""
|
||||
|
||||
sse: MCPSSEConfig = Field(default_factory=MCPSSEConfig)
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
@classmethod
|
||||
def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']:
|
||||
"""
|
||||
@ -57,11 +45,10 @@ class MCPConfig(BaseModel):
|
||||
|
||||
try:
|
||||
# Create SSE config if present
|
||||
sse_config = MCPSSEConfig.model_validate(data)
|
||||
sse_config.validate_servers()
|
||||
|
||||
mcp_config = MCPConfig.model_validate(data)
|
||||
mcp_config.validate_servers()
|
||||
# Create the main MCP config
|
||||
mcp_mapping['mcp'] = cls(sse=sse_config)
|
||||
mcp_mapping['mcp'] = cls(mcp_servers=mcp_config.mcp_servers)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f'Invalid MCP configuration: {e}')
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -36,17 +37,29 @@ class MCPClient(BaseModel):
|
||||
await self.disconnect()
|
||||
|
||||
try:
|
||||
streams_context = sse_client(
|
||||
url=server_url,
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(streams_context)
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(*streams)
|
||||
)
|
||||
# Use asyncio.wait_for to enforce the timeout
|
||||
async def connect_with_timeout():
|
||||
streams_context = sse_client(
|
||||
url=server_url,
|
||||
timeout=timeout, # Pass the timeout to sse_client
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(streams_context)
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(*streams)
|
||||
)
|
||||
await self._initialize_and_list_tools()
|
||||
|
||||
await self._initialize_and_list_tools()
|
||||
# Apply timeout to the entire connection process
|
||||
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f'Connection to {server_url} timed out after {timeout} seconds'
|
||||
)
|
||||
await self.disconnect() # Clean up resources
|
||||
raise # Re-raise the TimeoutError
|
||||
except Exception as e:
|
||||
logger.error(f'Error connecting to {server_url}: {str(e)}')
|
||||
await self.disconnect() # Clean up resources
|
||||
raise
|
||||
|
||||
async def _initialize_and_list_tools(self) -> None:
|
||||
|
||||
@ -38,12 +38,12 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
|
||||
|
||||
|
||||
async def create_mcp_clients(
|
||||
sse_mcp_server: list[str],
|
||||
mcp_servers: list[str],
|
||||
) -> list[MCPClient]:
|
||||
mcp_clients: list[MCPClient] = []
|
||||
# Initialize SSE connections
|
||||
if sse_mcp_server:
|
||||
for server_url in sse_mcp_server:
|
||||
if mcp_servers:
|
||||
for server_url in mcp_servers:
|
||||
logger.info(
|
||||
f'Initializing MCP agent for {server_url} with SSE connection...'
|
||||
)
|
||||
@ -78,11 +78,11 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
|
||||
try:
|
||||
logger.debug(f'Creating MCP clients with config: {mcp_config}')
|
||||
mcp_clients = await create_mcp_clients(
|
||||
mcp_config.sse.mcp_servers,
|
||||
mcp_config.mcp_servers,
|
||||
)
|
||||
|
||||
if not mcp_clients:
|
||||
logger.warning('No MCP clients were successfully connected')
|
||||
logger.debug('No MCP clients were successfully connected')
|
||||
return []
|
||||
|
||||
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
|
||||
|
||||
28
tests/unit/test_mcp_client_timeout.py
Normal file
28
tests/unit/test_mcp_client_timeout.py
Normal file
@ -0,0 +1,28 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.mcp.client import MCPClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_sse_timeout():
|
||||
"""Test that connect_sse properly times out when server_url is invalid."""
|
||||
client = MCPClient()
|
||||
|
||||
# Create a mock async context manager that simulates a timeout
|
||||
@asynccontextmanager
|
||||
async def mock_slow_context(*args, **kwargs):
|
||||
# This will hang for longer than our timeout
|
||||
await asyncio.sleep(10.0)
|
||||
yield (mock.AsyncMock(), mock.AsyncMock())
|
||||
|
||||
# Patch the sse_client function to return our slow context manager
|
||||
with mock.patch(
|
||||
'openhands.mcp.client.sse_client', return_value=mock_slow_context()
|
||||
):
|
||||
# Test with a very short timeout
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await client.connect_sse('http://example.com', timeout=0.1)
|
||||
@ -1,23 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
|
||||
|
||||
def test_valid_sse_config():
|
||||
"""Test a valid SSE configuration."""
|
||||
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
|
||||
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
|
||||
config.validate_servers() # Should not raise any exception
|
||||
|
||||
|
||||
def test_empty_sse_config():
|
||||
"""Test SSE configuration with empty servers list."""
|
||||
config = MCPSSEConfig(mcp_servers=[])
|
||||
config = MCPConfig(mcp_servers=[])
|
||||
config.validate_servers()
|
||||
|
||||
|
||||
def test_invalid_sse_url():
|
||||
"""Test SSE configuration with invalid URL format."""
|
||||
config = MCPSSEConfig(mcp_servers=['not_a_url'])
|
||||
config = MCPConfig(mcp_servers=['not_a_url'])
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config.validate_servers()
|
||||
assert 'Invalid URL' in str(exc_info.value)
|
||||
@ -25,7 +25,7 @@ def test_invalid_sse_url():
|
||||
|
||||
def test_duplicate_sse_urls():
|
||||
"""Test SSE configuration with duplicate server URLs."""
|
||||
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
|
||||
config = MCPConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config.validate_servers()
|
||||
assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value)
|
||||
@ -38,7 +38,7 @@ def test_from_toml_section_valid():
|
||||
}
|
||||
result = MCPConfig.from_toml_section(data)
|
||||
assert 'mcp' in result
|
||||
assert result['mcp'].sse.mcp_servers == ['http://server1:8080']
|
||||
assert result['mcp'].mcp_servers == ['http://server1:8080']
|
||||
|
||||
|
||||
def test_from_toml_section_invalid_sse():
|
||||
@ -53,7 +53,7 @@ def test_from_toml_section_invalid_sse():
|
||||
|
||||
def test_complex_urls():
|
||||
"""Test SSE configuration with complex URLs."""
|
||||
config = MCPSSEConfig(
|
||||
config = MCPConfig(
|
||||
mcp_servers=[
|
||||
'https://user:pass@server1:8080/path?query=1',
|
||||
'wss://server2:8443/ws',
|
||||
|
||||
76
tests/unit/test_mcp_create_clients_timeout.py
Normal file
76
tests/unit/test_mcp_create_clients_timeout.py
Normal file
@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.utils import create_mcp_clients
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mcp_clients_timeout_with_invalid_url():
|
||||
"""Test that create_mcp_clients properly times out when given an invalid URL."""
|
||||
# Use a non-existent domain that should cause a connection timeout
|
||||
invalid_url = 'http://non-existent-domain-that-will-timeout.invalid'
|
||||
|
||||
# Temporarily modify the default timeout for the MCPClient.connect_sse method
|
||||
original_connect_sse = MCPClient.connect_sse
|
||||
|
||||
# Create a wrapper that calls the original method but with a shorter timeout
|
||||
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
|
||||
return await original_connect_sse(self, server_url, timeout=0.5)
|
||||
|
||||
try:
|
||||
# Replace the method with our wrapper
|
||||
MCPClient.connect_sse = connect_sse_with_short_timeout
|
||||
|
||||
# Call create_mcp_clients with the invalid URL
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
clients = await create_mcp_clients([invalid_url])
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Verify that no clients were successfully connected
|
||||
assert len(clients) == 0
|
||||
|
||||
# Verify that the operation completed in a reasonable time (less than 5 seconds)
|
||||
# This ensures the timeout is working properly
|
||||
assert (
|
||||
end_time - start_time < 5.0
|
||||
), 'Operation took too long, timeout may not be working'
|
||||
finally:
|
||||
# Restore the original method
|
||||
MCPClient.connect_sse = original_connect_sse
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mcp_clients_with_unreachable_host():
|
||||
"""Test that create_mcp_clients handles unreachable hosts properly."""
|
||||
# Use a URL with a valid format but pointing to a non-routable IP address
|
||||
# This IP is in the TEST-NET-1 range (192.0.2.0/24) reserved for documentation and examples
|
||||
unreachable_url = 'http://192.0.2.1:8080'
|
||||
|
||||
# Temporarily modify the default timeout for the MCPClient.connect_sse method
|
||||
original_connect_sse = MCPClient.connect_sse
|
||||
|
||||
# Create a wrapper that calls the original method but with a shorter timeout
|
||||
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
|
||||
return await original_connect_sse(self, server_url, timeout=1.0)
|
||||
|
||||
try:
|
||||
# Replace the method with our wrapper
|
||||
MCPClient.connect_sse = connect_sse_with_short_timeout
|
||||
|
||||
# Call create_mcp_clients with the unreachable URL
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
clients = await create_mcp_clients([unreachable_url])
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Verify that no clients were successfully connected
|
||||
assert len(clients) == 0
|
||||
|
||||
# Verify that the operation completed in a reasonable time (less than 5 seconds)
|
||||
assert (
|
||||
end_time - start_time < 5.0
|
||||
), 'Operation took too long, timeout may not be working'
|
||||
finally:
|
||||
# Restore the original method
|
||||
MCPClient.connect_sse = original_connect_sse
|
||||
@ -3,7 +3,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config
|
||||
|
||||
|
||||
@ -24,10 +24,10 @@ async def test_sse_connection_timeout():
|
||||
# Mock the MCPClient constructor to return our mock
|
||||
with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client):
|
||||
# Create a list of server URLs to test
|
||||
sse_servers = ['http://server1:8080', 'http://server2:8080']
|
||||
servers = ['http://server1:8080', 'http://server2:8080']
|
||||
|
||||
# Call create_mcp_clients with the server URLs
|
||||
clients = await create_mcp_clients(sse_mcp_server=sse_servers)
|
||||
clients = await create_mcp_clients(mcp_servers=servers)
|
||||
|
||||
# Verify that no clients were successfully connected
|
||||
assert len(clients) == 0
|
||||
@ -44,10 +44,9 @@ async def test_fetch_mcp_tools_with_timeout():
|
||||
"""Test that fetch_mcp_tools_from_config handles timeouts gracefully."""
|
||||
# Create a mock MCPConfig
|
||||
mock_config = mock.MagicMock(spec=MCPConfig)
|
||||
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
|
||||
|
||||
# Configure the mock config
|
||||
mock_config.sse.mcp_servers = ['http://server1:8080']
|
||||
mock_config.mcp_servers = ['http://server1:8080']
|
||||
|
||||
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
|
||||
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
|
||||
@ -63,10 +62,9 @@ async def test_mixed_connection_results():
|
||||
"""Test that fetch_mcp_tools_from_config returns tools even when some connections fail."""
|
||||
# Create a mock MCPConfig
|
||||
mock_config = mock.MagicMock(spec=MCPConfig)
|
||||
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
|
||||
|
||||
# Configure the mock config
|
||||
mock_config.sse.mcp_servers = ['http://server1:8080', 'http://server2:8080']
|
||||
mock_config.mcp_servers = ['http://server1:8080', 'http://server2:8080']
|
||||
|
||||
# Create a successful client
|
||||
successful_client = mock.MagicMock(spec=MCPClient)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user