mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev> Co-authored-by: Robert Brennan <accounts@rbren.io>
127 lines
4.6 KiB
Python
127 lines
4.6 KiB
Python
import asyncio
|
|
from contextlib import AsyncExitStack
|
|
from typing import Optional
|
|
|
|
from mcp import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from pydantic import BaseModel, Field
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.mcp.tool import MCPClientTool
|
|
|
|
|
|
class MCPClient(BaseModel):
|
|
"""
|
|
A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
|
|
"""
|
|
|
|
session: Optional[ClientSession] = None
|
|
exit_stack: AsyncExitStack = AsyncExitStack()
|
|
description: str = 'MCP client tools for server interaction'
|
|
tools: list[MCPClientTool] = Field(default_factory=list)
|
|
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
async def connect_sse(
|
|
self,
|
|
server_url: str,
|
|
api_key: str | None = None,
|
|
conversation_id: str | None = None,
|
|
timeout: float = 30.0,
|
|
) -> None:
|
|
"""Connect to an MCP server using SSE transport.
|
|
|
|
Args:
|
|
server_url: The URL of the SSE server to connect to.
|
|
timeout: Connection timeout in seconds. Default is 30 seconds.
|
|
"""
|
|
if not server_url:
|
|
raise ValueError('Server URL is required.')
|
|
if self.session:
|
|
await self.disconnect()
|
|
|
|
try:
|
|
# Use asyncio.wait_for to enforce the timeout
|
|
async def connect_with_timeout():
|
|
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
|
|
|
if conversation_id:
|
|
headers['X-OpenHands-Conversation-ID'] = conversation_id
|
|
|
|
streams_context = sse_client(
|
|
url=server_url,
|
|
headers=headers if headers else None,
|
|
timeout=timeout,
|
|
)
|
|
streams = await self.exit_stack.enter_async_context(streams_context)
|
|
self.session = await self.exit_stack.enter_async_context(
|
|
ClientSession(*streams)
|
|
)
|
|
await self._initialize_and_list_tools()
|
|
|
|
# Apply timeout to the entire connection process
|
|
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
|
|
except asyncio.TimeoutError:
|
|
logger.error(
|
|
f'Connection to {server_url} timed out after {timeout} seconds'
|
|
)
|
|
await self.disconnect() # Clean up resources
|
|
raise # Re-raise the TimeoutError
|
|
except Exception as e:
|
|
logger.error(f'Error connecting to {server_url}: {str(e)}')
|
|
await self.disconnect() # Clean up resources
|
|
raise
|
|
|
|
async def _initialize_and_list_tools(self) -> None:
|
|
"""Initialize session and populate tool map."""
|
|
if not self.session:
|
|
raise RuntimeError('Session not initialized.')
|
|
|
|
await self.session.initialize()
|
|
response = await self.session.list_tools()
|
|
|
|
# Clear existing tools
|
|
self.tools = []
|
|
|
|
# Create proper tool objects for each server tool
|
|
for tool in response.tools:
|
|
server_tool = MCPClientTool(
|
|
name=tool.name,
|
|
description=tool.description,
|
|
inputSchema=tool.inputSchema,
|
|
session=self.session,
|
|
)
|
|
self.tool_map[tool.name] = server_tool
|
|
self.tools.append(server_tool)
|
|
|
|
logger.info(
|
|
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
|
|
)
|
|
|
|
async def call_tool(self, tool_name: str, args: dict):
|
|
"""Call a tool on the MCP server."""
|
|
if tool_name not in self.tool_map:
|
|
raise ValueError(f'Tool {tool_name} not found.')
|
|
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
|
|
if not self.session:
|
|
raise RuntimeError('Client session is not available.')
|
|
return await self.session.call_tool(name=tool_name, arguments=args)
|
|
|
|
async def disconnect(self) -> None:
|
|
"""Disconnect from the MCP server and clean up resources."""
|
|
if self.session:
|
|
try:
|
|
# Close the session first
|
|
if hasattr(self.session, 'close'):
|
|
await self.session.close()
|
|
# Then close the exit stack
|
|
await self.exit_stack.aclose()
|
|
except Exception as e:
|
|
logger.error(f'Error during disconnect: {str(e)}')
|
|
finally:
|
|
self.session = None
|
|
self.tools = []
|
|
logger.info('Disconnected from MCP server')
|