mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Co-authored-by: Tejas Goyal <tejas@Tejass-MacBook-Pro.local>
This commit is contained in:
@@ -189,8 +189,17 @@ class MCPStdioServerConfig(BaseModel):
|
||||
|
||||
|
||||
class MCPSHTTPServerConfig(BaseModel):
|
||||
"""Configuration for a MCP server that uses SHTTP.
|
||||
|
||||
Attributes:
|
||||
url: The server URL
|
||||
api_key: Optional API key for authentication
|
||||
timeout: Timeout in seconds for tool calls (default: 60s)
|
||||
"""
|
||||
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
timeout: int | None = 60
|
||||
|
||||
@field_validator('url', mode='before')
|
||||
@classmethod
|
||||
@@ -198,6 +207,17 @@ class MCPSHTTPServerConfig(BaseModel):
|
||||
"""Validate URL format for MCP servers."""
|
||||
return _validate_mcp_url(v)
|
||||
|
||||
@field_validator('timeout')
|
||||
@classmethod
|
||||
def validate_timeout(cls, v: int | None) -> int | None:
|
||||
"""Validate timeout value for MCP tool calls."""
|
||||
if v is not None:
|
||||
if v <= 0:
|
||||
raise ValueError('Timeout must be positive')
|
||||
if v > 3600: # 1 hour max
|
||||
raise ValueError('Timeout cannot exceed 3600 seconds')
|
||||
return v
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP (Message Control Protocol) settings.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastmcp import Client
|
||||
@@ -29,6 +30,7 @@ class MCPClient(BaseModel):
|
||||
description: str = 'MCP client tools for server interaction'
|
||||
tools: list[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
server_timeout: Optional[float] = None # Timeout from server config for tool calls
|
||||
|
||||
async def _initialize_and_list_tools(self) -> None:
|
||||
"""Initialize session and populate tool map."""
|
||||
@@ -60,7 +62,7 @@ class MCPClient(BaseModel):
|
||||
conversation_id: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
"""Connect to MCP server using SHTTP or SSE transport"""
|
||||
"""Connect to MCP server using SHTTP or SSE transport."""
|
||||
server_url = server.url
|
||||
api_key = server.api_key
|
||||
|
||||
@@ -123,7 +125,7 @@ class MCPClient(BaseModel):
|
||||
raise
|
||||
|
||||
async def connect_stdio(self, server: MCPStdioServerConfig, timeout: float = 30.0):
|
||||
"""Connect to MCP server using stdio transport"""
|
||||
"""Connect to MCP server using stdio transport."""
|
||||
try:
|
||||
transport = StdioTransport(
|
||||
command=server.command, args=server.args or [], env=server.env
|
||||
@@ -145,7 +147,20 @@ class MCPClient(BaseModel):
|
||||
raise
|
||||
|
||||
async def call_tool(self, tool_name: str, args: dict) -> CallToolResult:
|
||||
"""Call a tool on the MCP server."""
|
||||
"""Call a tool on the MCP server with timeout from server configuration.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
args: Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
CallToolResult from the MCP server
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If the tool call times out
|
||||
ValueError: If the tool is not found
|
||||
RuntimeError: If the client session is not available
|
||||
"""
|
||||
if tool_name not in self.tool_map:
|
||||
raise ValueError(f'Tool {tool_name} not found.')
|
||||
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
|
||||
@@ -153,4 +168,11 @@ class MCPClient(BaseModel):
|
||||
raise RuntimeError('Client session is not available.')
|
||||
|
||||
async with self.client:
|
||||
return await self.client.call_tool_mcp(name=tool_name, arguments=args)
|
||||
# Use server timeout if configured
|
||||
if self.server_timeout is not None:
|
||||
return await asyncio.wait_for(
|
||||
self.client.call_tool_mcp(name=tool_name, arguments=args),
|
||||
timeout=self.server_timeout,
|
||||
)
|
||||
else:
|
||||
return await self.client.call_tool_mcp(name=tool_name, arguments=args)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -128,6 +129,11 @@ async def create_mcp_clients(
|
||||
)
|
||||
client = MCPClient()
|
||||
|
||||
# Set server timeout for SHTTP servers
|
||||
if isinstance(server, MCPSHTTPServerConfig) and server.timeout is not None:
|
||||
client.server_timeout = float(server.timeout)
|
||||
logger.debug(f'Set SHTTP server timeout to {server.timeout}s')
|
||||
|
||||
try:
|
||||
await client.connect_http(server, conversation_id=conversation_id)
|
||||
|
||||
@@ -253,6 +259,22 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
|
||||
name=action.name,
|
||||
arguments=action.arguments,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Handle timeout errors specifically
|
||||
timeout_val = getattr(matching_client, 'server_timeout', 'unknown')
|
||||
logger.error(f'MCP tool {action.name} timed out after {timeout_val}s')
|
||||
error_content = json.dumps(
|
||||
{
|
||||
'isError': True,
|
||||
'error': f'Tool "{action.name}" timed out after {timeout_val} seconds',
|
||||
'content': [],
|
||||
}
|
||||
)
|
||||
return MCPObservation(
|
||||
content=error_content,
|
||||
name=action.name,
|
||||
arguments=action.arguments,
|
||||
)
|
||||
except McpError as e:
|
||||
# Handle MCP errors by returning an error observation instead of raising
|
||||
logger.error(f'MCP error when calling tool {action.name}: {e}')
|
||||
|
||||
Reference in New Issue
Block a user