diff --git a/openhands/core/config/mcp_config.py b/openhands/core/config/mcp_config.py index a5fe534925..30c0b29b2f 100644 --- a/openhands/core/config/mcp_config.py +++ b/openhands/core/config/mcp_config.py @@ -1,6 +1,6 @@ from urllib.parse import urlparse -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field, ValidationError, model_validator class MCPSSEServerConfig(BaseModel): @@ -44,6 +44,24 @@ class MCPConfig(BaseModel): model_config = {'extra': 'forbid'} + @staticmethod + def _normalize_sse_servers(servers_data: list[dict | str]) -> list[dict]: + """Helper method to normalize SSE server configurations.""" + normalized = [] + for server in servers_data: + if isinstance(server, str): + normalized.append({'url': server}) + else: + normalized.append(server) + return normalized + + @model_validator(mode='before') + def convert_string_urls(cls, data): + """Convert string URLs to MCPSSEServerConfig objects.""" + if isinstance(data, dict) and 'sse_servers' in data: + data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers']) + return data + def validate_servers(self) -> None: """Validate that server URLs are valid and unique.""" urls = [server.url for server in self.sse_servers] @@ -77,13 +95,10 @@ class MCPConfig(BaseModel): try: # Convert all entries in sse_servers to MCPSSEServerConfig objects if 'sse_servers' in data: + data['sse_servers'] = cls._normalize_sse_servers(data['sse_servers']) servers = [] for server in data['sse_servers']: - if isinstance(server, dict): - servers.append(MCPSSEServerConfig(**server)) - else: - # Convert string URLs to MCPSSEServerConfig objects with no API key - servers.append(MCPSSEServerConfig(url=server)) + servers.append(MCPSSEServerConfig(**server)) data['sse_servers'] = servers # Convert all entries in stdio_servers to MCPStdioServerConfig objects @@ -96,6 +111,7 @@ class MCPConfig(BaseModel): # Create SSE config if present mcp_config = MCPConfig.model_validate(data) mcp_config.validate_servers() + # Create the main MCP config mcp_mapping['mcp'] = cls( sse_servers=mcp_config.sse_servers, @@ -103,5 +119,4 @@ class MCPConfig(BaseModel): ) except ValidationError as e: raise ValueError(f'Invalid MCP configuration: {e}') - return mcp_mapping