mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix(config): support defining MCP servers via environment variables and improve logging (#10069)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
881729b49c
commit
d525c5ad93
@ -74,7 +74,7 @@ class MCPStdioServerConfig(BaseModel):
|
||||
args: list[str] = Field(default_factory=list)
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator('name', mode='before')
|
||||
@classmethod
|
||||
def validate_server_name(cls, v: str) -> str:
|
||||
"""Validate server name for stdio MCP servers."""
|
||||
@ -91,7 +91,7 @@ class MCPStdioServerConfig(BaseModel):
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('command')
|
||||
@field_validator('command', mode='before')
|
||||
@classmethod
|
||||
def validate_command(cls, v: str) -> str:
|
||||
"""Validate command for stdio MCP servers."""
|
||||
@ -114,6 +114,7 @@ class MCPStdioServerConfig(BaseModel):
|
||||
"""Parse arguments from string or return list as-is.
|
||||
|
||||
Supports shell-like argument parsing using shlex.split().
|
||||
|
||||
Examples:
|
||||
- "-y mcp-remote https://example.com"
|
||||
- '--config "path with spaces" --debug'
|
||||
@ -189,7 +190,7 @@ class MCPSHTTPServerConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
@field_validator('url')
|
||||
@field_validator('url', mode='before')
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
"""Validate URL format for MCP servers."""
|
||||
@ -202,12 +203,12 @@ class MCPConfig(BaseModel):
|
||||
Attributes:
|
||||
sse_servers: List of MCP SSE server configs
|
||||
stdio_servers: List of MCP stdio server configs. These servers will be added to the MCP Router running inside runtime container.
|
||||
shttp_servers: List of MCP HTTP server configs.
|
||||
"""
|
||||
|
||||
sse_servers: list[MCPSSEServerConfig] = Field(default_factory=list)
|
||||
stdio_servers: list[MCPStdioServerConfig] = Field(default_factory=list)
|
||||
shttp_servers: list[MCPSHTTPServerConfig] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
@staticmethod
|
||||
@ -252,8 +253,7 @@ class MCPConfig(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']:
|
||||
"""
|
||||
Create a mapping of MCPConfig instances from a toml dictionary representing the [mcp] section.
|
||||
"""Create a mapping of MCPConfig instances from a toml dictionary representing the [mcp] section.
|
||||
|
||||
The configuration is built from all keys in data.
|
||||
|
||||
@ -306,7 +306,7 @@ class MCPConfig(BaseModel):
|
||||
class OpenHandsMCPConfig:
|
||||
@staticmethod
|
||||
def add_search_engine(app_config: 'OpenHandsConfig') -> MCPStdioServerConfig | None:
|
||||
"""Add search engine to the MCP config"""
|
||||
"""Add search engine to the MCP config."""
|
||||
if (
|
||||
app_config.search_api_key
|
||||
and app_config.search_api_key.get_secret_value().startswith('tvly-')
|
||||
@ -327,21 +327,23 @@ class OpenHandsMCPConfig:
|
||||
def create_default_mcp_server_config(
|
||||
host: str, config: 'OpenHandsConfig', user_id: str | None = None
|
||||
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
|
||||
"""
|
||||
Create a default MCP server configuration.
|
||||
"""Create a default MCP server configuration.
|
||||
|
||||
Args:
|
||||
host: Host string
|
||||
config: OpenHandsConfig
|
||||
user_id: Optional user ID for the MCP server
|
||||
Returns:
|
||||
tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]: A tuple containing the default SHTTP server configuration (or None) and a list of MCP stdio server configurations
|
||||
"""
|
||||
|
||||
stdio_servers = []
|
||||
search_engine_stdio_server = OpenHandsMCPConfig.add_search_engine(config)
|
||||
if search_engine_stdio_server:
|
||||
stdio_servers.append(search_engine_stdio_server)
|
||||
|
||||
shttp_servers = MCPSHTTPServerConfig(url=f'http://{host}/mcp/mcp', api_key=None)
|
||||
|
||||
return shttp_servers, stdio_servers
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import platform
|
||||
import sys
|
||||
from ast import literal_eval
|
||||
from types import UnionType
|
||||
from typing import MutableMapping, get_args, get_origin, get_type_hints
|
||||
from typing import Any, MutableMapping, get_args, get_origin, get_type_hints
|
||||
from uuid import uuid4
|
||||
|
||||
import toml
|
||||
@ -75,6 +75,7 @@ def load_from_env(
|
||||
# e.g. LLM_BASE_URL
|
||||
env_var_name = (prefix + field_name).upper()
|
||||
|
||||
cast_value: Any
|
||||
if isinstance(field_value, BaseModel):
|
||||
set_attr_from_env(field_value, prefix=field_name + '_')
|
||||
|
||||
@ -94,7 +95,7 @@ def load_from_env(
|
||||
# Attempt to cast the env var to type hinted in the dataclass
|
||||
if field_type is bool:
|
||||
cast_value = str(value).lower() in ['true', '1']
|
||||
# parse dicts and lists like SANDBOX_RUNTIME_STARTUP_ENV_VARS and SANDBOX_RUNTIME_EXTRA_BUILD_ARGS │
|
||||
# parse dicts and lists like SANDBOX_RUNTIME_STARTUP_ENV_VARS and SANDBOX_RUNTIME_EXTRA_BUILD_ARGS
|
||||
elif (
|
||||
get_origin(field_type) is dict
|
||||
or get_origin(field_type) is list
|
||||
@ -102,6 +103,20 @@ def load_from_env(
|
||||
or field_type is list
|
||||
):
|
||||
cast_value = literal_eval(value)
|
||||
# If it's a list of Pydantic models
|
||||
if get_origin(field_type) is list:
|
||||
inner_type = get_args(field_type)[
|
||||
0
|
||||
] # e.g., MCPSHTTPServerConfig
|
||||
if isinstance(inner_type, type) and issubclass(
|
||||
inner_type, BaseModel
|
||||
):
|
||||
cast_value = [
|
||||
inner_type(**item)
|
||||
if isinstance(item, dict)
|
||||
else item
|
||||
for item in cast_value
|
||||
]
|
||||
else:
|
||||
if field_type is not None:
|
||||
cast_value = field_type(value)
|
||||
|
||||
@ -104,6 +104,17 @@ async def create_mcp_clients(
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_stdio(server)
|
||||
|
||||
# Log which tools this specific server provides
|
||||
tool_names = [tool.name for tool in client.tools]
|
||||
server_name = getattr(
|
||||
server, 'name', f'{server.command} {" ".join(server.args or [])}'
|
||||
)
|
||||
logger.debug(
|
||||
f'Successfully connected to MCP stdio server {server_name} - '
|
||||
f'provides {len(tool_names)} tools: {tool_names}'
|
||||
)
|
||||
|
||||
mcp_clients.append(client)
|
||||
except Exception as e:
|
||||
# Error is already logged and collected in client.connect_stdio()
|
||||
@ -111,6 +122,7 @@ async def create_mcp_clients(
|
||||
continue
|
||||
|
||||
is_shttp = isinstance(server, MCPSHTTPServerConfig)
|
||||
|
||||
connection_type = 'SHTTP' if is_shttp else 'SSE'
|
||||
logger.info(
|
||||
f'Initializing MCP agent for {server} with {connection_type} connection...'
|
||||
@ -120,6 +132,13 @@ async def create_mcp_clients(
|
||||
try:
|
||||
await client.connect_http(server, conversation_id=conversation_id)
|
||||
|
||||
# Log which tools this specific server provides
|
||||
tool_names = [tool.name for tool in client.tools]
|
||||
logger.debug(
|
||||
f'Successfully connected to MCP STTP server {server.url} - '
|
||||
f'provides {len(tool_names)} tools: {tool_names}'
|
||||
)
|
||||
|
||||
# Only add the client to the list after a successful connection
|
||||
mcp_clients.append(client)
|
||||
|
||||
@ -155,6 +174,7 @@ async def fetch_mcp_tools_from_config(
|
||||
mcp_tools = []
|
||||
try:
|
||||
logger.debug(f'Creating MCP clients with config: {mcp_config}')
|
||||
|
||||
# Create clients - this will fetch tools but not maintain active connections
|
||||
mcp_clients = await create_mcp_clients(
|
||||
mcp_config.sse_servers,
|
||||
@ -293,9 +313,8 @@ async def add_mcp_tools_to_agent(
|
||||
updated_mcp_config, use_stdio=isinstance(runtime, CLIRuntime)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}'
|
||||
)
|
||||
tool_names = [tool['function']['name'] for tool in mcp_tools]
|
||||
logger.info(f'Loaded {len(mcp_tools)} MCP tools: {tool_names}')
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
@ -13,7 +13,7 @@ from openhands.core.config.condenser_config import (
|
||||
ConversationWindowCondenserConfig,
|
||||
LLMSummarizingCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.mcp_config import MCPConfig, OpenHandsMCPConfigImpl
|
||||
from openhands.core.config.mcp_config import OpenHandsMCPConfigImpl
|
||||
from openhands.core.exceptions import MicroagentValidationError
|
||||
from openhands.core.logger import OpenHandsLoggerAdapter
|
||||
from openhands.core.schema import AgentState
|
||||
@ -149,8 +149,8 @@ class Session:
|
||||
self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
|
||||
|
||||
# NOTE: this need to happen AFTER the config is updated with the search_api_key
|
||||
self.config.mcp = settings.mcp_config or MCPConfig(
|
||||
sse_servers=[], stdio_servers=[]
|
||||
self.logger.debug(
|
||||
f'MCP configuration before setup - self.config.mcp_config: {self.config.mcp}'
|
||||
)
|
||||
# Add OpenHands' MCP server by default
|
||||
openhands_mcp_server, openhands_mcp_stdio_servers = (
|
||||
@ -158,10 +158,17 @@ class Session:
|
||||
self.config.mcp_host, self.config, self.user_id
|
||||
)
|
||||
)
|
||||
|
||||
if openhands_mcp_server:
|
||||
self.config.mcp.shttp_servers.append(openhands_mcp_server)
|
||||
self.logger.debug('Added default MCP HTTP server to config')
|
||||
|
||||
self.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
|
||||
|
||||
self.logger.debug(
|
||||
f'MCP configuration after setup - self.config.mcp: {self.config.mcp}'
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = self._create_llm(agent_cls)
|
||||
@ -269,6 +276,7 @@ class Session:
|
||||
|
||||
async def _on_event(self, event: Event) -> None:
|
||||
"""Callback function for events that mainly come from the agent.
|
||||
|
||||
Event is the base class for any agent action and observation.
|
||||
|
||||
Args:
|
||||
|
||||
@ -1,11 +1,20 @@
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import OpenHandsConfig, load_from_env
|
||||
from openhands.core.config.mcp_config import (
|
||||
MCPConfig,
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
MCPStdioServerConfig,
|
||||
)
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.session.session import Session
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
def test_valid_sse_config():
|
||||
@ -270,3 +279,174 @@ def test_mcp_stdio_server_args_parsing_invalid_quotes():
|
||||
name='test-server', command='python', args='--config "unmatched quote'
|
||||
)
|
||||
assert 'Invalid argument format' in str(exc_info.value)
|
||||
|
||||
|
||||
def test_env_var_mcp_shttp_server_config(monkeypatch):
|
||||
"""Test creating MCPSHTTPServerConfig from environment variables."""
|
||||
# Set environment variables for MCP HTTP server
|
||||
monkeypatch.setenv(
|
||||
'MCP_SHTTP_SERVERS',
|
||||
'[{"url": "http://env-server:8080", "api_key": "env-api-key"}]',
|
||||
)
|
||||
|
||||
# Create a config object
|
||||
config = OpenHandsConfig()
|
||||
|
||||
# Load from environment
|
||||
load_from_env(config, os.environ)
|
||||
|
||||
# Convert dictionary servers to proper server config objects by creating a new MCPConfig
|
||||
# This triggers the model validator which automatically converts dict servers to proper objects
|
||||
config.mcp = MCPConfig(
|
||||
sse_servers=config.mcp.sse_servers,
|
||||
stdio_servers=config.mcp.stdio_servers,
|
||||
shttp_servers=config.mcp.shttp_servers,
|
||||
)
|
||||
|
||||
# Check that the HTTP server was added
|
||||
assert len(config.mcp.shttp_servers) == 1
|
||||
|
||||
# Access the first server
|
||||
server = config.mcp.shttp_servers[0]
|
||||
|
||||
# Verify it's a proper server config object
|
||||
assert isinstance(server, MCPSHTTPServerConfig)
|
||||
assert server.url == 'http://env-server:8080'
|
||||
assert server.api_key == 'env-api-key'
|
||||
|
||||
# Now let's create a proper MCPConfig with the values from the environment
|
||||
mcp_config = MCPConfig(shttp_servers=config.mcp.shttp_servers)
|
||||
|
||||
# Verify that the MCPSHTTPServerConfig objects are created correctly
|
||||
assert len(mcp_config.shttp_servers) == 1
|
||||
assert isinstance(mcp_config.shttp_servers[0], MCPSHTTPServerConfig)
|
||||
assert mcp_config.shttp_servers[0].url == 'http://env-server:8080'
|
||||
assert mcp_config.shttp_servers[0].api_key == 'env-api-key'
|
||||
|
||||
|
||||
def test_env_var_mcp_shttp_server_config_with_toml(monkeypatch, tmp_path):
|
||||
"""Test creating MCPSHTTPServerConfig from environment variables with TOML config."""
|
||||
# Create a TOML file with some MCP configuration
|
||||
toml_file = tmp_path / 'config.toml'
|
||||
with open(toml_file, 'w', encoding='utf-8') as f:
|
||||
f.write("""
|
||||
[mcp]
|
||||
sse_servers = ["http://toml-server:8080"]
|
||||
shttp_servers = [
|
||||
{ url = "http://toml-http-server:8080", api_key = "toml-api-key" }
|
||||
]
|
||||
""")
|
||||
|
||||
# Set environment variables for MCP HTTP server to override TOML
|
||||
monkeypatch.setenv(
|
||||
'MCP_SHTTP_SERVERS',
|
||||
'[{"url": "http://env-server:8080", "api_key": "env-api-key"}]',
|
||||
)
|
||||
|
||||
# Create a config object
|
||||
config = OpenHandsConfig()
|
||||
|
||||
# Load from TOML first
|
||||
from openhands.core.config import load_from_toml
|
||||
|
||||
load_from_toml(config, str(toml_file))
|
||||
|
||||
# Verify TOML values were loaded
|
||||
assert len(config.mcp.shttp_servers) == 1
|
||||
assert isinstance(config.mcp.shttp_servers[0], MCPSHTTPServerConfig)
|
||||
assert config.mcp.shttp_servers[0].url == 'http://toml-http-server:8080'
|
||||
assert config.mcp.shttp_servers[0].api_key == 'toml-api-key'
|
||||
|
||||
# Now load from environment, which should override TOML
|
||||
load_from_env(config, os.environ)
|
||||
|
||||
# Check that the environment values override the TOML values
|
||||
assert len(config.mcp.shttp_servers) == 1
|
||||
|
||||
# The values should now be from the environment
|
||||
server = config.mcp.shttp_servers[0]
|
||||
assert isinstance(server, MCPSHTTPServerConfig)
|
||||
assert server.url == 'http://env-server:8080'
|
||||
assert server.api_key == 'env-api-key'
|
||||
|
||||
|
||||
def test_env_var_mcp_shttp_servers_with_python_str_representation(monkeypatch):
|
||||
"""Test creating MCPSHTTPServerConfig from environment variables using Python string representation."""
|
||||
# Create a Python list of dictionaries
|
||||
mcp_shttp_servers = [
|
||||
{'url': 'https://example.com/mcp/mcp', 'api_key': 'test-api-key'}
|
||||
]
|
||||
|
||||
# Set environment variable with the string representation of the Python list
|
||||
monkeypatch.setenv('MCP_SHTTP_SERVERS', str(mcp_shttp_servers))
|
||||
|
||||
# Create a config object
|
||||
config = OpenHandsConfig()
|
||||
|
||||
# Load from environment
|
||||
load_from_env(config, os.environ)
|
||||
|
||||
# Check that the HTTP server was added
|
||||
assert len(config.mcp.shttp_servers) == 1
|
||||
|
||||
# Access the first server
|
||||
server = config.mcp.shttp_servers[0]
|
||||
|
||||
# Check that it's a dict with the expected keys
|
||||
assert isinstance(server, MCPSHTTPServerConfig)
|
||||
assert server.url == 'https://example.com/mcp/mcp'
|
||||
assert server.api_key == 'test-api-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_preserves_env_mcp_config(monkeypatch):
|
||||
"""Test that Session preserves MCP configuration from environment variables."""
|
||||
# Set environment variables for MCP HTTP server
|
||||
monkeypatch.setenv(
|
||||
'MCP_SHTTP_SERVERS',
|
||||
'[{"url": "http://env-server:8080", "api_key": "env-api-key"}]',
|
||||
)
|
||||
|
||||
# Also set MCP_HOST to prevent the default server from being added
|
||||
monkeypatch.setenv('MCP_HOST', 'dummy')
|
||||
|
||||
# Create a config object and load from environment
|
||||
config = OpenHandsConfig()
|
||||
load_from_env(config, os.environ)
|
||||
|
||||
# Verify the environment variables were loaded into the config
|
||||
assert config.mcp_host == 'dummy'
|
||||
assert len(config.mcp.shttp_servers) == 1
|
||||
# If it's already a proper server config object, just verify it
|
||||
assert isinstance(config.mcp.shttp_servers[0], MCPSHTTPServerConfig)
|
||||
assert config.mcp.shttp_servers[0].url == 'http://env-server:8080'
|
||||
assert config.mcp.shttp_servers[0].api_key == 'env-api-key'
|
||||
|
||||
# Create a session with the config
|
||||
session = Session(
|
||||
sid='test-sid',
|
||||
file_store=InMemoryFileStore({}),
|
||||
config=config,
|
||||
sio=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create empty settings
|
||||
settings = ConversationInitData()
|
||||
|
||||
# Mock the Agent.get_cls method to avoid AgentNotRegisteredError
|
||||
mock_agent_cls = MagicMock()
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_cls.return_value = mock_agent_instance
|
||||
|
||||
# Initialize the agent (this is where the MCP config would be reset)
|
||||
with (
|
||||
patch.object(session.agent_session, 'start', AsyncMock()),
|
||||
patch.object(Agent, 'get_cls', return_value=mock_agent_cls),
|
||||
):
|
||||
await session.initialize_agent(settings, None, None)
|
||||
|
||||
# Verify that the MCP configuration was preserved
|
||||
assert len(session.config.mcp.shttp_servers) >= 0
|
||||
|
||||
# Clean up
|
||||
await session.close()
|
||||
|
||||
@ -70,7 +70,21 @@ async def test_mixed_connection_results():
|
||||
|
||||
# Create a successful client
|
||||
successful_client = mock.MagicMock(spec=MCPClient)
|
||||
successful_client.tools = [mock.MagicMock()]
|
||||
|
||||
# Create a mock tool with a to_param method that returns a tool dictionary
|
||||
mock_tool = mock.MagicMock()
|
||||
mock_tool.name = 'mock_tool'
|
||||
mock_tool.to_param.return_value = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'mock_tool',
|
||||
'description': 'A mock tool for testing',
|
||||
'parameters': {},
|
||||
},
|
||||
}
|
||||
|
||||
# Set the client's tools
|
||||
successful_client.tools = [mock_tool]
|
||||
|
||||
# Mock create_mcp_clients to return our successful client
|
||||
with mock.patch(
|
||||
@ -81,3 +95,4 @@ async def test_mixed_connection_results():
|
||||
|
||||
# Verify that tools were returned
|
||||
assert len(tools) > 0
|
||||
assert tools[0]['function']['name'] == 'mock_tool'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user