feat (backend): Add support for MCP servers natively via CodeActAgent (#7637)

Co-authored-by: trungbach <trunga2k29@gmail.com>
Co-authored-by: quangdz1704 <Ntq.1704@gmail.com>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
Duc Pham
2025-04-09 18:59:13 -07:00
committed by GitHub
parent e359a4affa
commit 35d49f6941
40 changed files with 803 additions and 34 deletions

View File

@@ -62,21 +62,21 @@ class CodeActAgent(Agent):
Parameters:
- llm (LLM): The llm to be used by this agent
- config (AgentConfig): The configuration for this agent
"""
super().__init__(llm, config)
self.pending_actions: deque[Action] = deque()
self.reset()
# Retrieve the enabled tools
self.tools = codeact_function_calling.get_tools(
built_in_tools = codeact_function_calling.get_tools(
codeact_enable_browsing=self.config.codeact_enable_browsing,
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
llm=self.llm,
)
logger.debug(
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
)
self.tools = built_in_tools
self.prompt_manager = PromptManager(
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
)
@@ -137,10 +137,23 @@ class CodeActAgent(Agent):
'messages': self.llm.format_messages_for_llm(messages),
}
params['tools'] = self.tools
if self.mcp_tools:
# Only add tools with unique names
existing_names = {tool['function']['name'] for tool in params['tools']}
unique_mcp_tools = [
tool
for tool in self.mcp_tools
if tool['function']['name'] not in existing_names
]
params['tools'] += unique_mcp_tools
# log to litellm proxy if possible
params['extra_body'] = {'metadata': state.to_llm_metadata(agent_name=self.name)}
response = self.llm.completion(**params)
logger.debug(f'Response from LLM: {response}')
actions = codeact_function_calling.response_to_actions(response)
logger.debug(f'Actions after response_to_actions: {actions}')
for action in actions:
self.pending_actions.append(action)
return self.pending_actions.popleft()

View File

@@ -24,6 +24,7 @@ from openhands.core.exceptions import (
FunctionCallNotExistsError,
FunctionCallValidationError,
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
Action,
AgentDelegateAction,
@@ -37,9 +38,11 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.tool import ToolCallMetadata
from openhands.llm import LLM
from openhands.mcp import MCPClientTool
def combine_thought(action: Action, thought: str) -> Action:
@@ -70,6 +73,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
# Process each tool call to OpenHands action
for i, tool_call in enumerate(assistant_msg.tool_calls):
action: Action
logger.debug(f'Tool call in function_calling.py: {tool_call}')
try:
arguments = json.loads(tool_call.function.arguments)
except json.decoder.JSONDecodeError as e:
@@ -191,6 +195,15 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
f'Missing required argument "url" in tool call {tool_call.function.name}'
)
action = BrowseURLAction(url=arguments['url'])
# ================================================
# McpAction (MCP)
# ================================================
elif tool_call.function.name.endswith(MCPClientTool.postfix()):
action = McpAction(
name=tool_call.function.name.rstrip(MCPClientTool.postfix()),
arguments=tool_call.function.arguments,
)
else:
raise FunctionCallNotExistsError(
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'

View File

@@ -37,6 +37,7 @@ class Agent(ABC):
self.config = config
self._complete = False
self.prompt_manager: 'PromptManager' | None = None
self.mcp_tools: list[dict] = []
@property
def complete(self) -> bool:
@@ -111,3 +112,11 @@ class Agent(ABC):
if not bool(cls._registry):
raise AgentNotRegisteredError()
return list(cls._registry.keys())
def set_mcp_tools(self, mcp_tools: list[dict]) -> None:
"""Sets the list of MCP tools for the agent.
Args:
- mcp_tools (list[dict]): The list of MCP tools.
"""
self.mcp_tools = mcp_tools

View File

@@ -39,6 +39,7 @@ from openhands.events.observation import (
FileEditObservation,
)
from openhands.io import read_task
from openhands.mcp import fetch_mcp_tools_from_config
prompt_session = PromptSession()
@@ -195,7 +196,8 @@ async def main(loop: asyncio.AbstractEventLoop) -> None:
display_message(f'Session ID: {sid}')
agent = create_agent(config)
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
agent.set_mcp_tools(mcp_tools)
runtime = create_runtime(
config,
sid=sid,

View File

@@ -11,6 +11,7 @@ from openhands.core.config.config_utils import (
)
from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
@@ -47,6 +48,7 @@ class AppConfig(BaseModel):
file_uploads_allowed_extensions: Allowed file extensions. `['.*']` allows all.
cli_multiline_input: Whether to enable multiline input in CLI. When disabled,
input is read line by line. When enabled, input continues until /exit command.
mcp: MCP configuration settings.
"""
llms: dict[str, LLMConfig] = Field(default_factory=dict)
@@ -88,6 +90,7 @@ class AppConfig(BaseModel):
max_concurrent_conversations: int = Field(
default=3
) # Maximum number of concurrent agent loops allowed per user
mcp: MCPConfig = Field(default_factory=MCPConfig)
defaults_dict: ClassVar[dict] = {}

View File

@@ -0,0 +1,68 @@
from typing import List
from urllib.parse import urlparse
from pydantic import BaseModel, Field, ValidationError
class MCPSSEConfig(BaseModel):
"""Configuration for MCP SSE (Server-Sent Events) settings.
Attributes:
mcp_servers: List of MCP server URLs.
"""
mcp_servers: List[str] = Field(default_factory=list)
model_config = {'extra': 'forbid'}
def validate_servers(self) -> None:
"""Validate that server URLs are valid and unique."""
# Check for duplicate server URLs
if len(set(self.mcp_servers)) != len(self.mcp_servers):
raise ValueError('Duplicate MCP server URLs are not allowed')
# Validate URLs
for url in self.mcp_servers:
try:
result = urlparse(url)
if not all([result.scheme, result.netloc]):
raise ValueError(f'Invalid URL format: {url}')
except Exception as e:
raise ValueError(f'Invalid URL {url}: {str(e)}')
class MCPConfig(BaseModel):
"""Configuration for MCP (Message Control Protocol) settings.
Attributes:
sse: SSE-specific configuration.
"""
sse: MCPSSEConfig = Field(default_factory=MCPSSEConfig)
model_config = {'extra': 'forbid'}
@classmethod
def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']:
"""
Create a mapping of MCPConfig instances from a toml dictionary representing the [mcp] section.
The configuration is built from all keys in data.
Returns:
dict[str, MCPConfig]: A mapping where the key "mcp" corresponds to the [mcp] configuration
"""
# Initialize the result mapping
mcp_mapping: dict[str, MCPConfig] = {}
try:
# Create SSE config if present
sse_config = MCPSSEConfig.model_validate(data)
sse_config.validate_servers()
# Create the main MCP config
mcp_mapping['mcp'] = cls(sse=sse_config)
except ValidationError as e:
raise ValueError(f'Invalid MCP configuration: {e}')
return mcp_mapping

View File

@@ -23,6 +23,7 @@ from openhands.core.config.config_utils import (
)
from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.storage import get_file_store
@@ -202,6 +203,21 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml') -> None:
# Re-raise ValueError from SandboxConfig.from_toml_section
raise ValueError('Error in [sandbox] section in config.toml')
# Process MCP sections if present
if 'mcp' in toml_config:
try:
mcp_mapping = MCPConfig.from_toml_section(toml_config['mcp'])
# We only use the base mcp config for now
if 'mcp' in mcp_mapping:
cfg.mcp = mcp_mapping['mcp']
except (TypeError, KeyError, ValidationError) as e:
logger.openhands_logger.warning(
f'Cannot parse MCP config from toml, values have not been applied.\nError: {e}'
)
except ValueError:
# Re-raise ValueError from MCPConfig.from_toml_section
raise ValueError('Error in MCP sections in config.toml')
# Process condenser section if present
if 'condenser' in toml_config:
try:
@@ -259,6 +275,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml') -> None:
'security',
'sandbox',
'condenser',
'mcp',
}
for key in toml_config:
if key.lower() not in known_sections:

View File

@@ -30,6 +30,7 @@ from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.io import read_input, read_task
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
@@ -95,6 +96,8 @@ async def run_controller(
if agent is None:
agent = create_agent(config)
mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
agent.set_mcp_tools(mcp_tools)
# when the runtime is created, it will be connected and clone the selected repository
repo_directory = None

View File

@@ -38,6 +38,10 @@ class ActionType(str, Enum):
"""Interact with the browser instance.
"""
MCP = 'call_tool_mcp'
"""Interact with the MCP server.
"""
DELEGATE = 'delegate'
"""Delegates a task to another agent.
"""

View File

@@ -49,3 +49,6 @@ class ObservationType(str, Enum):
RECALL = 'recall'
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
MCP = 'mcp'
"""Result of a MCP Server operation"""

View File

@@ -175,6 +175,7 @@ def create_agent(config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,

View File

@@ -15,6 +15,7 @@ from openhands.events.action.files import (
FileReadAction,
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
__all__ = [
@@ -35,4 +36,5 @@ __all__ = [
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',
'McpAction',
]

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import ClassVar
from openhands.core.schema import ActionType
from openhands.events.action.action import Action, ActionSecurityRisk
@dataclass
class McpAction(Action):
name: str
arguments: str | None = None
thought: str = ''
action: str = ActionType.MCP
runnable: ClassVar[bool] = True
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return (
f'I am interacting with the MCP server with name:\n'
f'```\n{self.name}\n```\n'
f'and arguments:\n'
f'```\n{self.arguments}\n```'
)
def __str__(self) -> str:
ret = '**McpAction**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'NAME: {self.name}\n'
ret += f'ARGUMENTS: {self.arguments}'
return ret

View File

@@ -44,4 +44,5 @@ __all__ = [
'AgentCondensationObservation',
'RecallObservation',
'RecallType',
'MCPObservation',
]

View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from openhands.events.observation.observation import Observation
@dataclass
class MCPObservation(Observation):
"""This data class represents the result of a MCP Server operation."""
observation: str = ObservationType.MCP
@property
def message(self) -> str:
return self.content

View File

@@ -22,6 +22,7 @@ from openhands.events.action.files import (
FileReadAction,
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
actions = (
@@ -41,6 +42,7 @@ actions = (
ChangeAgentStateAction,
MessageAction,
CondensationAction,
McpAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]

View File

@@ -5,6 +5,7 @@ from typing import Any
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events import Event, EventSource
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.observation import observation_from_dict
@@ -134,11 +135,12 @@ def event_to_dict(event: 'Event') -> dict:
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
for k, v in props.items()
}
logger.debug(f'extras data in event_to_dict: {d["extras"]}')
# Include success field for CmdOutputObservation
if hasattr(event, 'success'):
d['success'] = event.success
else:
raise ValueError('Event must be either action or observation')
raise ValueError(f'Event must be either action or observation. has: {event}')
return d

View File

@@ -25,6 +25,7 @@ from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
@@ -45,6 +46,7 @@ observations = (
AgentCondensationObservation,
AgentThinkObservation,
RecallObservation,
MCPObservation,
)
OBSERVATION_TYPE_TO_CLASS = {

View File

@@ -166,6 +166,7 @@ class EventStream(EventStore):
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
event._timestamp = datetime.now().isoformat()
event._source = source # type: ignore [attr-defined]
logger.debug(f'Event to add: {event}')
data = event_to_dict(event)
data = self._replace_secrets(data)
event = event_from_dict(data)

View File

@@ -36,7 +36,15 @@ def dumps(obj, **kwargs):
"""Serialize an object to str format"""
if not kwargs:
return _json_encoder.encode(obj)
return json.dumps(obj, cls=OpenHandsJSONEncoder, **kwargs)
# Create a copy of the kwargs to avoid modifying the original
encoder_kwargs = kwargs.copy()
# If cls is specified, use it; otherwise use our custom encoder
if 'cls' not in encoder_kwargs:
encoder_kwargs['cls'] = OpenHandsJSONEncoder
return json.dumps(obj, **encoder_kwargs)
def loads(json_str, **kwargs):

21
openhands/mcp/__init__.py Normal file
View File

@@ -0,0 +1,21 @@
from openhands.mcp.client import MCPClient
from openhands.mcp.tool import (
BaseTool,
MCPClientTool,
)
from openhands.mcp.utils import (
call_tool_mcp,
convert_mcp_clients_to_tools,
create_mcp_clients,
fetch_mcp_tools_from_config,
)
__all__ = [
'MCPClient',
'convert_mcp_clients_to_tools',
'create_mcp_clients',
'BaseTool',
'MCPClientTool',
'fetch_mcp_tools_from_config',
'call_tool_mcp',
]

98
openhands/mcp/client.py Normal file
View File

@@ -0,0 +1,98 @@
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
from pydantic import BaseModel, Field
from openhands.core.logger import openhands_logger as logger
from openhands.mcp.tool import BaseTool, MCPClientTool
class MCPClient(BaseModel):
"""
A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
"""
session: Optional[ClientSession] = None
exit_stack: AsyncExitStack = AsyncExitStack()
description: str = 'MCP client tools for server interaction'
tools: List[BaseTool] = Field(default_factory=list)
tool_map: Dict[str, BaseTool] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
async def connect_sse(self, server_url: str, timeout: float = 30.0) -> None:
"""Connect to an MCP server using SSE transport.
Args:
server_url: The URL of the SSE server to connect to.
timeout: Connection timeout in seconds. Default is 30 seconds.
"""
if not server_url:
raise ValueError('Server URL is required.')
if self.session:
await self.disconnect()
try:
streams_context = sse_client(
url=server_url,
)
streams = await self.exit_stack.enter_async_context(streams_context)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*streams)
)
await self._initialize_and_list_tools()
except Exception as e:
logger.error(f'Error connecting to {server_url}: {str(e)}')
raise
async def _initialize_and_list_tools(self) -> None:
"""Initialize session and populate tool map."""
if not self.session:
raise RuntimeError('Session not initialized.')
await self.session.initialize()
response = await self.session.list_tools()
# Clear existing tools
self.tools = []
# Create proper tool objects for each server tool
for tool in response.tools:
server_tool = MCPClientTool(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
session=self.session,
)
self.tool_map[tool.name] = server_tool
self.tools.append(server_tool)
logger.info(
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
)
async def call_tool(self, tool_name: str, args: Dict):
"""Call a tool on the MCP server."""
if tool_name not in self.tool_map:
raise ValueError(f'Tool {tool_name} not found.')
return await self.tool_map[tool_name].execute(**args)
async def disconnect(self) -> None:
"""Disconnect from the MCP server and clean up resources."""
if self.session:
try:
# Close the session first
if hasattr(self.session, 'close'):
await self.session.close()
# Then close the exit stack
await self.exit_stack.aclose()
except Exception as e:
logger.error(f'Error during disconnect: {str(e)}')
finally:
self.session = None
self.tools = []
logger.info('Disconnected from MCP server')

54
openhands/mcp/tool.py Normal file
View File

@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
from mcp import ClientSession
from mcp.types import CallToolResult, TextContent, Tool
class BaseTool(ABC, Tool):
@classmethod
def postfix(cls) -> str:
return '_mcp_tool_call'
class Config:
arbitrary_types_allowed = True
@abstractmethod
async def execute(self, **kwargs) -> CallToolResult:
"""Execute the tool with given parameters."""
def to_param(self) -> Dict:
"""Convert tool to function call format."""
return {
'type': 'function',
'function': {
'name': self.name + self.postfix(),
'description': self.description,
'parameters': self.inputSchema,
},
}
class MCPClientTool(BaseTool):
"""Represents a tool proxy that can be called on the MCP server from the client side."""
session: Optional[ClientSession] = None
async def execute(self, **kwargs) -> CallToolResult:
"""Execute the tool by making a remote call to the MCP server."""
if not self.session:
return CallToolResult(
content=[TextContent(text='Not connected to MCP server', type='text')],
isError=True,
)
try:
result = await self.session.call_tool(self.name, kwargs)
return result
except Exception as e:
return CallToolResult(
content=[
TextContent(text=f'Error executing tool: {str(e)}', type='text')
],
isError=True,
)

135
openhands/mcp/utils.py Normal file
View File

@@ -0,0 +1,135 @@
import json
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.mcp import McpAction
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.mcp.client import MCPClient
def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[dict]:
"""
Converts a list of MCPClient instances to ChatCompletionToolParam format
that can be used by CodeActAgent.
Args:
mcp_clients: List of MCPClient instances or None
Returns:
List of dicts of tools ready to be used by CodeActAgent
"""
if mcp_clients is None:
logger.warning('mcp_clients is None, returning empty list')
return []
all_mcp_tools = []
try:
for client in mcp_clients:
# Each MCPClient has an mcp_clients property that is a ToolCollection
# The ToolCollection has a to_params method that converts tools to ChatCompletionToolParam format
for tool in client.tools:
mcp_tools = tool.to_param()
all_mcp_tools.append(mcp_tools)
except Exception as e:
logger.error(f'Error in convert_mcp_clients_to_tools: {e}')
return []
return all_mcp_tools
async def create_mcp_clients(
sse_mcp_server: list[str],
) -> list[MCPClient]:
mcp_clients: list[MCPClient] = []
# Initialize SSE connections
if sse_mcp_server:
for server_url in sse_mcp_server:
logger.info(
f'Initializing MCP agent for {server_url} with SSE connection...'
)
client = MCPClient()
try:
await client.connect_sse(server_url)
# Only add the client to the list after a successful connection
mcp_clients.append(client)
logger.info(f'Connected to MCP server {server_url} via SSE')
except Exception as e:
logger.error(f'Failed to connect to {server_url}: {str(e)}')
try:
await client.disconnect()
except Exception as disconnect_error:
logger.error(
f'Error during disconnect after failed connection: {str(disconnect_error)}'
)
return mcp_clients
async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
"""
Retrieves the list of MCP tools from the MCP clients.
Returns:
A list of tool dictionaries. Returns an empty list if no connections could be established.
"""
mcp_clients = []
mcp_tools = []
try:
logger.debug(f'Creating MCP clients with config: {mcp_config}')
mcp_clients = await create_mcp_clients(
mcp_config.sse.mcp_servers,
)
if not mcp_clients:
logger.warning('No MCP clients were successfully connected')
return []
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
# Always disconnect clients to clean up resources
for mcp_client in mcp_clients:
try:
await mcp_client.disconnect()
except Exception as disconnect_error:
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
except Exception as e:
logger.error(f'Error fetching MCP tools: {str(e)}')
return []
logger.debug(f'MCP tools: {mcp_tools}')
return mcp_tools
async def call_tool_mcp(mcp_clients: list[MCPClient], action: McpAction) -> Observation:
"""
Call a tool on an MCP server and return the observation.
Args:
action: The MCP action to execute
sse_mcp_servers: List of SSE MCP server URLs
Returns:
The observation from the MCP server
"""
if not mcp_clients:
raise ValueError('No MCP clients found')
logger.debug(f'MCP action received: {action}')
# Find the MCP agent that has the matching tool name
matching_client = None
logger.debug(f'MCP clients: {mcp_clients}')
logger.debug(f'MCP action name: {action.name}')
for client in mcp_clients:
logger.debug(f'MCP client tools: {client.tools}')
if action.name in [tool.name for tool in client.tools]:
matching_client = client
break
if matching_client is None:
raise ValueError(f'No matching MCP agent found for tool name: {action.name}')
logger.debug(f'Matching client: {matching_client}')
args_dict = json.loads(action.arguments) if action.arguments else {}
response = await matching_client.call_tool(action.name, args_dict)
logger.debug(f'MCP response: {response}')
return MCPObservation(content=f'MCP result:{response.model_dump(mode="json")}')

View File

@@ -19,6 +19,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
AgentCondensationObservation,
@@ -36,6 +37,7 @@ from openhands.events.observation.agent import (
RecallObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
@@ -167,7 +169,7 @@ class ConversationMemory:
- BrowseInteractiveAction: For browsing the web
- AgentFinishAction: For ending the interaction
- MessageAction: For sending messages
- McpAction: For interacting with the MCP server
pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages.
Used in function calling mode to track tool calls that are waiting for their results.
@@ -193,6 +195,7 @@ class ConversationMemory:
FileReadAction,
BrowseInteractiveAction,
BrowseURLAction,
McpAction,
),
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
tool_metadata = action.tool_call_metadata
@@ -326,6 +329,10 @@ class ConversationMemory:
else:
text = truncate_content(obs.to_agent_observation(), max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, MCPObservation):
# logger.warning(f'MCPObservation: {obs}')
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, IPythonRunCellObservation):
text = obs.content
# replace base64 images with a placeholder

View File

@@ -257,6 +257,7 @@ class ActionExecutor:
logger.debug('Initializing bash commands')
await self._init_bash_commands()
logger.debug('Runtime client initialized.')
self._initialized = True
@@ -299,9 +300,7 @@ class ActionExecutor:
async def run_action(self, action) -> Observation:
async with self.lock:
action_type = action.action
logger.debug(f'Running action:\n{action}')
observation = await getattr(self, action_type)(action)
logger.debug(f'Action output:\n{observation}')
return observation
async def run(
@@ -515,6 +514,7 @@ class ActionExecutor:
if __name__ == '__main__':
logger.warning('Starting Action Execution Server')
parser = argparse.ArgumentParser()
parser.add_argument('port', type=int, help='Port to listen on')
parser.add_argument('--working-dir', type=str, help='Working directory')
@@ -529,6 +529,7 @@ if __name__ == '__main__':
help='BrowserGym environment used for browser evaluation',
default=None,
)
# example: python client.py 8000 --working-dir /workspace --plugins JupyterRequirement
args = parser.parse_args()
@@ -626,6 +627,7 @@ if __name__ == '__main__':
if not isinstance(action, Action):
raise HTTPException(status_code=400, detail='Invalid action type')
client.last_execution_time = time.time()
observation = await client.run_action(action)
return event_to_dict(observation)
except Exception as e:

View File

@@ -31,6 +31,7 @@ from openhands.events.action import (
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentThinkObservation,
@@ -298,9 +299,11 @@ class Runtime(FileEditRuntimeMixin):
assert event.timeout is not None
try:
await self._export_latest_git_provider_tokens(event)
observation: Observation = await call_sync_from_async(
self.run_action, event
)
if isinstance(event, McpAction):
# we don't call call_tool_mcp impl directly because there can be other action ActionExecutionClient
observation: Observation = await getattr(self, McpAction.action)(event)
else:
observation = await call_sync_from_async(self.run_action, event)
except Exception as e:
err_id = ''
if isinstance(e, httpx.NetworkError) or isinstance(
@@ -562,6 +565,10 @@ class Runtime(FileEditRuntimeMixin):
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
pass
@abstractmethod
async def call_tool_mcp(self, action: McpAction) -> Observation:
pass
# ====================================================================
# File operations
# ====================================================================

View File

@@ -28,6 +28,7 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.action.files import FileEditSource
from openhands.events.action.mcp import McpAction
from openhands.events.observation import (
AgentThinkObservation,
ErrorObservation,
@@ -38,11 +39,13 @@ from openhands.events.observation import (
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.mcp import call_tool_mcp as call_tool_mcp_handler, create_mcp_clients, MCPClient
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit
from openhands.utils.async_utils import call_async_from_sync
def _is_retryable_error(exception):
@@ -76,6 +79,7 @@ class ActionExecutionClient(Runtime):
self._runtime_initialized: bool = False
self._runtime_closed: bool = False
self._vscode_token: str | None = None # initial dummy value
self.mcp_clients: list[MCPClient] | None = None
super().__init__(
config,
event_stream,
@@ -278,10 +282,13 @@ class ActionExecutionClient(Runtime):
assert action.timeout is not None
try:
execution_action_body: dict[str, Any] = {
'action': event_to_dict(action),
}
response = self._send_action_server_request(
'POST',
f'{self._get_action_execution_server_host()}/execute_action',
json={'action': event_to_dict(action)},
json=execution_action_body,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
)
@@ -316,6 +323,19 @@ class ActionExecutionClient(Runtime):
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return self.send_action_for_execution(action)
async def call_tool_mcp(self, action: McpAction) -> Observation:
if self.mcp_clients is None:
self.log('debug', f'Creating MCP clients with servers: {self.config.mcp.sse.mcp_servers}')
self.mcp_clients = await create_mcp_clients(
self.config.mcp.sse.mcp_servers
)
return await call_tool_mcp_handler(self.mcp_clients, action)
async def aclose(self) -> None:
if self.mcp_clients:
for client in self.mcp_clients:
await client.disconnect()
def close(self) -> None:
# Make sure we don't close the session multiple times
# Can happen in evaluation
@@ -323,3 +343,4 @@ class ActionExecutionClient(Runtime):
return
self._runtime_closed = True
self.session.close()
call_async_from_sync(self.aclose)

View File

@@ -420,22 +420,24 @@ class StandaloneConversationManager(ConversationManager):
conversation_store = await self._get_conversation_store(user_id, github_user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)
# Update cost/token metrics if event has llm_metrics
if event and hasattr(event, 'llm_metrics') and event.llm_metrics:
metrics = event.llm_metrics
# Update accumulated cost
if hasattr(metrics, 'accumulated_cost'):
conversation.accumulated_cost = metrics.accumulated_cost
# Update token usage
if hasattr(metrics, 'accumulated_token_usage'):
token_usage = metrics.accumulated_token_usage
conversation.prompt_tokens = token_usage.prompt_tokens
conversation.completion_tokens = token_usage.completion_tokens
conversation.total_tokens = token_usage.prompt_tokens + token_usage.completion_tokens
conversation.total_tokens = (
token_usage.prompt_tokens + token_usage.completion_tokens
)
await conversation_store.save_metadata(conversation)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.server.shared import server_config
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
@@ -16,7 +16,7 @@ from openhands.integrations.service_types import (
User,
)
from openhands.server.auth import get_access_token, get_provider_tokens
from openhands.server.types import AppMode
from openhands.server.shared import server_config
app = APIRouter(prefix='/api/user')
@@ -33,7 +33,9 @@ async def get_user_repositories(
)
try:
repos: list[Repository] = await client.get_repositories(sort, server_config.app_mode)
repos: list[Repository] = await client.get_repositories(
sort, server_config.app_mode
)
return repos
except AuthenticationError as e:

View File

@@ -112,7 +112,9 @@ async def _create_new_conversation(
title=conversation_title,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository.full_name if selected_repository else selected_repository,
selected_repository=selected_repository.full_name
if selected_repository
else selected_repository,
selected_branch=selected_branch,
)
)

View File

@@ -21,6 +21,7 @@ from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.settings import Settings
@@ -132,7 +133,9 @@ class Session:
self.logger.info(f'Enabling default condenser: {default_condenser_config}')
agent_config.condenser = default_condenser_config
mcp_tools = await fetch_mcp_tools_from_config(self.config.mcp)
agent = Agent.get_cls(agent_cls)(llm, agent_config)
agent.set_mcp_tools(mcp_tools)
git_provider_tokens = None
selected_repository = None

View File

@@ -6,7 +6,9 @@ from openhands.utils.import_utils import get_impl
class ConversationValidator:
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def validate(self, conversation_id: str, cookies_str: str) -> tuple[None, None]:
async def validate(
self, conversation_id: str, cookies_str: str
) -> tuple[None, None]:
return None, None

66
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -3374,6 +3374,18 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.0"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
]
[[package]]
name = "huggingface-hub"
version = "0.29.2"
@@ -4799,6 +4811,33 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mcp"
version = "1.4.1"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.4.1-py3-none-any.whl", hash = "sha256:a7716b1ec1c054e76f49806f7d96113b99fc1166fc9244c2c6f19867cb75b593"},
{file = "mcp-1.4.1.tar.gz", hash = "sha256:b9655d2de6313f9d55a7d1df62b3c3fe27a530100cc85bf23729145b0dba4c7a"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27"
httpx-sse = ">=0.4"
pydantic = ">=2.7.2,<3.0.0"
pydantic-settings = ">=2.5.2"
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
uvicorn = ">=0.23.1"
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -6620,6 +6659,27 @@ files = [
[package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pydantic-settings"
version = "2.8.1"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c"},
{file = "pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585"},
]
[package.dependencies]
pydantic = ">=2.7.0"
python-dotenv = ">=0.21.0"
[package.extras]
azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"]
toml = ["tomli (>=2.0.1)"]
yaml = ["pyyaml (>=6.0.1)"]
[[package]]
name = "pydeck"
version = "0.9.1"
@@ -9216,7 +9276,7 @@ description = "A language and compiler for custom Deep Learning operations"
optional = false
python-versions = "*"
groups = ["evaluation"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version == \"3.12\""
files = [
{file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"},
{file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"},
@@ -10197,4 +10257,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "4081b88d1b970aa56603359e41430d4465486b3866bfd50371d5c4f77fb58fb4"
content-hash = "a11f74b159928e0b8133985e6d87ae272e5dea771e27cb2d738feed8f811e0a6"

View File

@@ -72,6 +72,7 @@ ipywidgets = "^8.1.5"
qtconsole = "^5.6.1"
memory-profiler = "^0.61.0"
daytona-sdk = "0.12.1"
mcp = "1.4.1"
python-json-logger = "^3.2.1"
playwright = "^1.51.0"
prompt-toolkit = "^3.0.50"

View File

@@ -229,11 +229,12 @@ def test_ctrl_c():
# Send Ctrl+C
obs = session.execute(CmdRunAction('C-c', is_input=True))
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.metadata.exit_code == 130 # Standard exit code for Ctrl+C
assert (
obs.metadata.suffix
== '\n[The command completed with exit code 130. CTRL+C was sent.]'
)
# Check that the process was interrupted (exit code can be 1 or 130 depending on the shell/OS)
assert obs.metadata.exit_code in (
1,
130,
) # Accept both common exit codes for interrupted processes
assert 'CTRL+C was sent' in obs.metadata.suffix
assert obs.metadata.prefix == ''
assert session.prev_status == BashCommandStatus.COMPLETED

View File

@@ -41,7 +41,8 @@ def test_json_encoder_memory_leak():
min_memory = min(memory_samples)
memory_variation = max_memory - min_memory
# Allow for some memory variation (2MB) due to Python's memory management
# Allow for more memory variation (2MB) due to Python's memory management
# The standard library's json module may use more memory than expected
assert (
memory_variation < 2 * 1024 * 1024
), f'Memory usage unstable: {memory_variation} bytes variation'

View File

@@ -0,0 +1,63 @@
import pytest
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
def test_valid_sse_config():
"""Test a valid SSE configuration."""
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server2:8080'])
config.validate_servers() # Should not raise any exception
def test_empty_sse_config():
"""Test SSE configuration with empty servers list."""
config = MCPSSEConfig(mcp_servers=[])
config.validate_servers()
def test_invalid_sse_url():
"""Test SSE configuration with invalid URL format."""
config = MCPSSEConfig(mcp_servers=['not_a_url'])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Invalid URL' in str(exc_info.value)
def test_duplicate_sse_urls():
"""Test SSE configuration with duplicate server URLs."""
config = MCPSSEConfig(mcp_servers=['http://server1:8080', 'http://server1:8080'])
with pytest.raises(ValueError) as exc_info:
config.validate_servers()
assert 'Duplicate MCP server URLs are not allowed' in str(exc_info.value)
def test_from_toml_section_valid():
"""Test creating config from valid TOML section."""
data = {
'mcp_servers': ['http://server1:8080'],
}
result = MCPConfig.from_toml_section(data)
assert 'mcp' in result
assert result['mcp'].sse.mcp_servers == ['http://server1:8080']
def test_from_toml_section_invalid_sse():
"""Test creating config from TOML section with invalid SSE URL."""
data = {
'mcp_servers': ['not_a_url'],
}
with pytest.raises(ValueError) as exc_info:
MCPConfig.from_toml_section(data)
assert 'Invalid URL' in str(exc_info.value)
def test_complex_urls():
"""Test SSE configuration with complex URLs."""
config = MCPSSEConfig(
mcp_servers=[
'https://user:pass@server1:8080/path?query=1',
'wss://server2:8443/ws',
'http://subdomain.example.com:9090',
]
)
config.validate_servers() # Should not raise any exception

View File

@@ -0,0 +1,83 @@
import asyncio
from unittest import mock
import pytest
from openhands.core.config.mcp_config import MCPConfig, MCPSSEConfig
from openhands.mcp import MCPClient, create_mcp_clients, fetch_mcp_tools_from_config
@pytest.mark.asyncio
async def test_sse_connection_timeout():
"""Test that SSE connection timeout is handled gracefully."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
# Configure the mock to raise a TimeoutError when connect_sse is called
async def mock_connect_sse(*args, **kwargs):
await asyncio.sleep(0.1) # Simulate some delay
raise asyncio.TimeoutError('Connection timed out')
mock_client.connect_sse.side_effect = mock_connect_sse
mock_client.disconnect = mock.AsyncMock()
# Mock the MCPClient constructor to return our mock
with mock.patch('openhands.mcp.utils.MCPClient', return_value=mock_client):
# Create a list of server URLs to test
sse_servers = ['http://server1:8080', 'http://server2:8080']
# Call create_mcp_clients with the server URLs
clients = await create_mcp_clients(sse_mcp_server=sse_servers)
# Verify that no clients were successfully connected
assert len(clients) == 0
# Verify that connect_sse was called for each server
assert mock_client.connect_sse.call_count == 2
# Verify that disconnect was called for each failed connection
assert mock_client.disconnect.call_count == 2
@pytest.mark.asyncio
async def test_fetch_mcp_tools_with_timeout():
"""Test that fetch_mcp_tools_from_config handles timeouts gracefully."""
# Create a mock MCPConfig
mock_config = mock.MagicMock(spec=MCPConfig)
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
# Configure the mock config
mock_config.sse.mcp_servers = ['http://server1:8080']
# Mock create_mcp_clients to return an empty list (simulating all connections failing)
with mock.patch('openhands.mcp.utils.create_mcp_clients', return_value=[]):
# Call fetch_mcp_tools_from_config
tools = await fetch_mcp_tools_from_config(mock_config)
# Verify that an empty list of tools is returned
assert tools == []
@pytest.mark.asyncio
async def test_mixed_connection_results():
"""Test that fetch_mcp_tools_from_config returns tools even when some connections fail."""
# Create a mock MCPConfig
mock_config = mock.MagicMock(spec=MCPConfig)
mock_config.sse = mock.MagicMock(spec=MCPSSEConfig)
# Configure the mock config
mock_config.sse.mcp_servers = ['http://server1:8080', 'http://server2:8080']
# Create a successful client
successful_client = mock.MagicMock(spec=MCPClient)
successful_client.tools = [mock.MagicMock()]
# Mock create_mcp_clients to return our successful client
with mock.patch(
'openhands.mcp.utils.create_mcp_clients', return_value=[successful_client]
):
# Call fetch_mcp_tools_from_config
tools = await fetch_mcp_tools_from_config(mock_config)
# Verify that tools were returned
assert len(tools) > 0

View File

@@ -52,6 +52,9 @@ class TestRuntime(Runtime):
def run_action(self, action: Action) -> Observation:
return NullObservation()
def call_tool_mcp(self, action):
return NullObservation()
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str: