diff --git a/enterprise/integrations/github/github_view.py b/enterprise/integrations/github/github_view.py index 90027f3804..733cec6c2a 100644 --- a/enterprise/integrations/github/github_view.py +++ b/enterprise/integrations/github/github_view.py @@ -97,6 +97,9 @@ class GithubUserContext(UserContext): user_secrets = await self.secrets_store.load() return dict(user_secrets.custom_secrets) if user_secrets else {} + async def get_mcp_api_key(self) -> str | None: + raise NotImplementedError() + async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: """Get the user's proactive conversation setting. diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index eafb7c5b74..d1390ae51b 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -203,6 +203,15 @@ class SaasUserAuth(UserAuth): self.settings_store = settings_store return settings_store + async def get_mcp_api_key(self) -> str: + api_key_store = ApiKeyStore.get_instance() + mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id) + if not mcp_api_key: + mcp_api_key = api_key_store.create_api_key( + self.user_id, 'MCP_API_KEY', None + ) + return mcp_api_key + @classmethod async def get_instance(cls, request: Request) -> UserAuth: logger.debug('saas_user_auth_get_instance') diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index d1057c34d6..fff762a7dc 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -4,12 +4,12 @@ from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta from time import time -from typing import AsyncGenerator, Sequence +from typing import Any, AsyncGenerator, Sequence from uuid import UUID, uuid4 import httpx from fastapi import Request -from pydantic import Field, TypeAdapter +from pydantic import Field, SecretStr, TypeAdapter from openhands.agent_server.models import ( ConversationInfo, @@ -63,19 +63,27 @@ from openhands.app_server.sandbox.sandbox_spec_service import SandboxSpecService from openhands.app_server.services.injector import InjectorState from openhands.app_server.services.jwt_service import JwtService from openhands.app_server.user.user_context import UserContext +from openhands.app_server.user.user_models import UserInfo from openhands.app_server.utils.docker_utils import ( replace_localhost_hostname_for_docker, ) from openhands.experiments.experiment_manager import ExperimentManagerImpl from openhands.integrations.provider import ProviderType -from openhands.sdk import AgentContext, LocalWorkspace +from openhands.sdk import Agent, AgentContext, LocalWorkspace from openhands.sdk.conversation.secret_source import LookupSecret, StaticSecret from openhands.sdk.llm import LLM from openhands.sdk.security.confirmation_policy import AlwaysConfirm from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace from openhands.server.types import AppMode -from openhands.tools.preset.default import get_default_agent -from openhands.tools.preset.planning import get_planning_agent +from openhands.tools.preset.default import ( + get_default_condenser, + get_default_tools, +) +from openhands.tools.preset.planning import ( + format_plan_structure, + get_planning_condenser, + get_planning_tools, +) _conversation_info_type_adapter = TypeAdapter(list[ConversationInfo | None]) _logger = logging.getLogger(__name__) @@ -99,6 +107,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): access_token_hard_timeout: timedelta | None app_mode: str | None = None keycloak_auth_cookie: str | None = None + tavily_api_key: str | None = None async def search_app_conversations( self, @@ -519,6 +528,223 @@ class LiveStatusAppConversationService(AppConversationServiceBase): if not request.llm_model and parent_info.llm_model: request.llm_model = parent_info.llm_model + async def _setup_secrets_for_git_provider( + self, git_provider: ProviderType | None, user: UserInfo + ) -> dict: + """Set up secrets for git provider authentication. + + Args: + git_provider: The git provider type (GitHub, GitLab, etc.) + user: User information containing authentication details + + Returns: + Dictionary of secrets for the conversation + """ + secrets = await self.user_context.get_secrets() + + if not git_provider: + return secrets + + secret_name = f'{git_provider.name}_TOKEN' + + if self.web_url: + # Create an access token for web-based authentication + access_token = self.jwt_service.create_jws_token( + payload={ + 'user_id': user.id, + 'provider_type': git_provider.value, + }, + expires_in=self.access_token_hard_timeout, + ) + headers = {'X-Access-Token': access_token} + + # Include keycloak_auth cookie in headers if app_mode is SaaS + if self.app_mode == 'saas' and self.keycloak_auth_cookie: + headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}' + + secrets[secret_name] = LookupSecret( + url=self.web_url + '/api/v1/webhooks/secrets', + headers=headers, + ) + else: + # Use static token for environments without web URL access + static_token = await self.user_context.get_latest_token(git_provider) + if static_token: + secrets[secret_name] = StaticSecret(value=static_token) + + return secrets + + async def _configure_llm_and_mcp( + self, user: UserInfo, llm_model: str | None + ) -> tuple[LLM, dict]: + """Configure LLM and MCP (Model Context Protocol) settings. + + Args: + user: User information containing LLM preferences + llm_model: Optional specific model to use, falls back to user default + + Returns: + Tuple of (configured LLM instance, MCP config dictionary) + """ + # Configure LLM + model = llm_model or user.llm_model + llm = LLM( + model=model, + base_url=user.llm_base_url, + api_key=user.llm_api_key, + usage_id='agent', + ) + + # Configure MCP + mcp_config: dict[str, Any] = {} + if self.web_url: + mcp_url = f'{self.web_url}/mcp/mcp' + mcp_config = { + 'default': { + 'url': mcp_url, + } + } + + # Add API key if available + mcp_api_key = await self.user_context.get_mcp_api_key() + if mcp_api_key: + mcp_config['default']['headers'] = { + 'X-Session-API-Key': mcp_api_key, + } + + # Get the actual API key values, prioritizing user's key over service key + user_search_key = None + if user.search_api_key: + key_value = user.search_api_key.get_secret_value() + if key_value and key_value.strip(): + user_search_key = key_value + + service_tavily_key = None + if self.tavily_api_key: + # tavily_api_key is already a string (extracted in the factory method) + if self.tavily_api_key.strip(): + service_tavily_key = self.tavily_api_key + + tavily_api_key = user_search_key or service_tavily_key + + if tavily_api_key: + _logger.info('Adding search engine to MCP config') + mcp_config['tavily'] = { + 'url': f'https://mcp.tavily.com/mcp/?tavilyApiKey={tavily_api_key}' + } + else: + _logger.info('No search engine API key found, skipping search engine') + + return llm, mcp_config + + def _create_agent_with_context( + self, + llm: LLM, + agent_type: AgentType, + system_message_suffix: str | None, + mcp_config: dict, + ) -> Agent: + """Create an agent with appropriate tools and context based on agent type. + + Args: + llm: Configured LLM instance + agent_type: Type of agent to create (PLAN or DEFAULT) + system_message_suffix: Optional suffix for system messages + mcp_config: MCP configuration dictionary + + Returns: + Configured Agent instance with context + """ + # Create agent based on type + if agent_type == AgentType.PLAN: + agent = Agent( + llm=llm, + tools=get_planning_tools(), + system_prompt_filename='system_prompt_planning.j2', + system_prompt_kwargs={'plan_structure': format_plan_structure()}, + condenser=get_planning_condenser( + llm=llm.model_copy(update={'usage_id': 'planning_condenser'}) + ), + security_analyzer=None, + mcp_config=mcp_config, + ) + else: + agent = Agent( + llm=llm, + tools=get_default_tools(enable_browser=True), + system_prompt_kwargs={'cli_mode': False}, + condenser=get_default_condenser( + llm=llm.model_copy(update={'usage_id': 'condenser'}) + ), + mcp_config=mcp_config, + ) + + # Add agent context + agent_context = AgentContext(system_message_suffix=system_message_suffix) + agent = agent.model_copy(update={'agent_context': agent_context}) + + return agent + + async def _finalize_conversation_request( + self, + agent: Agent, + conversation_id: UUID | None, + user: UserInfo, + workspace: LocalWorkspace, + initial_message: SendMessageRequest | None, + secrets: dict, + sandbox: SandboxInfo, + remote_workspace: AsyncRemoteWorkspace | None, + selected_repository: str | None, + working_dir: str, + ) -> StartConversationRequest: + """Finalize the conversation request with experiment variants and skills. + + Args: + agent: The configured agent + conversation_id: Optional conversation ID, generates new one if None + user: User information + workspace: Local workspace instance + initial_message: Optional initial message for the conversation + secrets: Dictionary of secrets for authentication + sandbox: Sandbox information + remote_workspace: Optional remote workspace for skills loading + selected_repository: Optional repository name + working_dir: Working directory path + + Returns: + Complete StartConversationRequest ready for use + """ + # Generate conversation ID if not provided + conversation_id = conversation_id or uuid4() + + # Apply experiment variants + agent = ExperimentManagerImpl.run_agent_variant_tests__v1( + user.id, conversation_id, agent + ) + + # Load and merge skills if remote workspace is available + if remote_workspace: + try: + agent = await self._load_skills_and_update_agent( + sandbox, agent, remote_workspace, selected_repository, working_dir + ) + except Exception as e: + _logger.warning(f'Failed to load skills: {e}', exc_info=True) + # Continue without skills - don't fail conversation startup + + # Create and return the final request + return StartConversationRequest( + conversation_id=conversation_id, + agent=agent, + workspace=workspace, + confirmation_policy=( + AlwaysConfirm() if user.confirmation_mode else NeverConfirm() + ), + initial_message=initial_message, + secrets=secrets, + ) + async def _build_start_conversation_request_for_user( self, sandbox: SandboxInfo, @@ -532,87 +758,41 @@ class LiveStatusAppConversationService(AppConversationServiceBase): remote_workspace: AsyncRemoteWorkspace | None = None, selected_repository: str | None = None, ) -> StartConversationRequest: + """Build a complete conversation request for a user. + + This method orchestrates the creation of a conversation request by: + 1. Setting up git provider secrets + 2. Configuring LLM and MCP settings + 3. Creating an agent with appropriate context + 4. Finalizing the request with skills and experiment variants + """ user = await self.user_context.get_user_info() - - # Set up a secret for the git token - secrets = await self.user_context.get_secrets() - if git_provider: - secret_name = f'{git_provider.name}_TOKEN' - if self.web_url: - # If there is a web url, then we create an access token to access it. - # For security reasons, we are explicit here - only this user, and - # only this provider, with a timeout - access_token = self.jwt_service.create_jws_token( - payload={ - 'user_id': user.id, - 'provider_type': git_provider.value, - }, - expires_in=self.access_token_hard_timeout, - ) - headers = {'X-Access-Token': access_token} - - # Include keycloak_auth cookie in headers if app_mode is SaaS - if self.app_mode == 'saas' and self.keycloak_auth_cookie: - headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}' - - secrets[secret_name] = LookupSecret( - url=self.web_url + '/api/v1/webhooks/secrets', - headers=headers, - ) - else: - # If there is no URL specified where the sandbox can access the app server - # then we supply a static secret with the most recent value. Depending - # on the type, this may eventually expire. - static_token = await self.user_context.get_latest_token(git_provider) - if static_token: - secrets[secret_name] = StaticSecret(value=static_token) - workspace = LocalWorkspace(working_dir=working_dir) - # Use provided llm_model if available, otherwise fall back to user's default - model = llm_model or user.llm_model - llm = LLM( - model=model, - base_url=user.llm_base_url, - api_key=user.llm_api_key, - usage_id='agent', - ) - # The agent gets passed initial instructions - # Select agent based on agent_type - if agent_type == AgentType.PLAN: - agent = get_planning_agent(llm=llm) - else: - agent = get_default_agent(llm=llm) + # Set up secrets for git provider + secrets = await self._setup_secrets_for_git_provider(git_provider, user) - agent_context = AgentContext(system_message_suffix=system_message_suffix) - agent = agent.model_copy(update={'agent_context': agent_context}) + # Configure LLM and MCP + llm, mcp_config = await self._configure_llm_and_mcp(user, llm_model) - conversation_id = conversation_id or uuid4() - agent = ExperimentManagerImpl.run_agent_variant_tests__v1( - user.id, conversation_id, agent + # Create agent with context + agent = self._create_agent_with_context( + llm, agent_type, system_message_suffix, mcp_config ) - # Load and merge all skills if remote_workspace is available - if remote_workspace: - try: - agent = await self._load_skills_and_update_agent( - sandbox, agent, remote_workspace, selected_repository, working_dir - ) - except Exception as e: - _logger.warning(f'Failed to load skills: {e}', exc_info=True) - # Continue without skills - don't fail conversation startup - - start_conversation_request = StartConversationRequest( - conversation_id=conversation_id, - agent=agent, - workspace=workspace, - confirmation_policy=( - AlwaysConfirm() if user.confirmation_mode else NeverConfirm() - ), - initial_message=initial_message, - secrets=secrets, + # Finalize and return the conversation request + return await self._finalize_conversation_request( + agent, + conversation_id, + user, + workspace, + initial_message, + secrets, + sandbox, + remote_workspace, + selected_repository, + working_dir, ) - return start_conversation_request async def update_agent_server_conversation_title( self, @@ -817,6 +997,10 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): 'be retrieved by a sandboxed conversation.' ), ) + tavily_api_key: SecretStr | None = Field( + default=None, + description='The Tavily Search API key to add to MCP integration', + ) async def inject( self, state: InjectorState, request: Request | None = None @@ -874,6 +1058,14 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): # If server_config is not available (e.g., in tests), continue without it pass + # We supply the global tavily key only if the app mode is not SAAS, where + # currently the search endpoints are patched into the app server instead + # so the tavily key does not need to be shared + if self.tavily_api_key and app_mode != AppMode.SAAS: + tavily_api_key = self.tavily_api_key.get_secret_value() + else: + tavily_api_key = None + yield LiveStatusAppConversationService( init_git_in_empty_workspace=self.init_git_in_empty_workspace, user_context=user_context, @@ -890,4 +1082,5 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): access_token_hard_timeout=access_token_hard_timeout, app_mode=app_mode, keycloak_auth_cookie=keycloak_auth_cookie, + tavily_api_key=tavily_api_key, ) diff --git a/openhands/app_server/config.py b/openhands/app_server/config.py index b58f3dcda1..b44608a887 100644 --- a/openhands/app_server/config.py +++ b/openhands/app_server/config.py @@ -6,7 +6,7 @@ from typing import AsyncContextManager import httpx from fastapi import Depends, Request -from pydantic import Field +from pydantic import Field, SecretStr from sqlalchemy.ext.asyncio import AsyncSession # Import the event_callback module to ensure all processors are registered @@ -185,7 +185,13 @@ def config_from_env() -> AppServerConfig: ) if config.app_conversation is None: - config.app_conversation = LiveStatusAppConversationServiceInjector() + tavily_api_key = None + tavily_api_key_str = os.getenv('TAVILY_API_KEY') or os.getenv('SEARCH_API_KEY') + if tavily_api_key_str: + tavily_api_key = SecretStr(tavily_api_key_str) + config.app_conversation = LiveStatusAppConversationServiceInjector( + tavily_api_key=tavily_api_key + ) if config.user is None: config.user = AuthUserContextInjector() diff --git a/openhands/app_server/user/auth_user_context.py b/openhands/app_server/user/auth_user_context.py index 53612364f5..e0b8fdd35f 100644 --- a/openhands/app_server/user/auth_user_context.py +++ b/openhands/app_server/user/auth_user_context.py @@ -78,6 +78,10 @@ class AuthUserContext(UserContext): return results + async def get_mcp_api_key(self) -> str | None: + mcp_api_key = await self.user_auth.get_mcp_api_key() + return mcp_api_key + USER_ID_ATTR = 'user_id' diff --git a/openhands/app_server/user/specifiy_user_context.py b/openhands/app_server/user/specifiy_user_context.py index 0855b447bf..d940061466 100644 --- a/openhands/app_server/user/specifiy_user_context.py +++ b/openhands/app_server/user/specifiy_user_context.py @@ -30,6 +30,9 @@ class SpecifyUserContext(UserContext): async def get_secrets(self) -> dict[str, SecretSource]: raise NotImplementedError() + async def get_mcp_api_key(self) -> str | None: + raise NotImplementedError() + USER_CONTEXT_ATTR = 'user_context' ADMIN = SpecifyUserContext(user_id=None) diff --git a/openhands/app_server/user/user_context.py b/openhands/app_server/user/user_context.py index 75fe957160..0971b71570 100644 --- a/openhands/app_server/user/user_context.py +++ b/openhands/app_server/user/user_context.py @@ -34,6 +34,10 @@ class UserContext(ABC): async def get_secrets(self) -> dict[str, SecretSource]: """Get custom secrets and github provider secrets for the conversation.""" + @abstractmethod + async def get_mcp_api_key(self) -> str | None: + """Get an MCP API Key.""" + class UserContextInjector(DiscriminatedUnionMixin, Injector[UserContext], ABC): """Injector for user contexts.""" diff --git a/openhands/server/user_auth/default_user_auth.py b/openhands/server/user_auth/default_user_auth.py index 2e0a7b5af9..8bc79af156 100644 --- a/openhands/server/user_auth/default_user_auth.py +++ b/openhands/server/user_auth/default_user_auth.py @@ -88,6 +88,9 @@ class DefaultUserAuth(UserAuth): return None return user_secrets.provider_tokens + async def get_mcp_api_key(self) -> str | None: + return None + @classmethod async def get_instance(cls, request: Request) -> UserAuth: user_auth = DefaultUserAuth() diff --git a/openhands/server/user_auth/user_auth.py b/openhands/server/user_auth/user_auth.py index e370d32474..c61c9ceb8b 100644 --- a/openhands/server/user_auth/user_auth.py +++ b/openhands/server/user_auth/user_auth.py @@ -75,6 +75,10 @@ class UserAuth(ABC): def get_auth_type(self) -> AuthType | None: return None + @abstractmethod + async def get_mcp_api_key(self) -> str | None: + """Get an mcp api key for the user""" + @classmethod @abstractmethod async def get_instance(cls, request: Request) -> UserAuth: diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py new file mode 100644 index 0000000000..0f61c161b2 --- /dev/null +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -0,0 +1,698 @@ +"""Unit tests for the methods in LiveStatusAppConversationService.""" + +from unittest.mock import AsyncMock, Mock, patch +from uuid import UUID, uuid4 + +import pytest + +from openhands.agent_server.models import SendMessageRequest, StartConversationRequest +from openhands.app_server.app_conversation.app_conversation_models import AgentType +from openhands.app_server.app_conversation.live_status_app_conversation_service import ( + LiveStatusAppConversationService, +) +from openhands.app_server.sandbox.sandbox_models import SandboxInfo, SandboxStatus +from openhands.app_server.user.user_context import UserContext +from openhands.integrations.provider import ProviderType +from openhands.sdk import Agent +from openhands.sdk.conversation.secret_source import LookupSecret, StaticSecret +from openhands.sdk.llm import LLM +from openhands.sdk.workspace import LocalWorkspace +from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace +from openhands.server.types import AppMode + + +class TestLiveStatusAppConversationService: + """Test cases for the methods in LiveStatusAppConversationService.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create mock dependencies + self.mock_user_context = Mock(spec=UserContext) + self.mock_jwt_service = Mock() + self.mock_sandbox_service = Mock() + self.mock_sandbox_spec_service = Mock() + self.mock_app_conversation_info_service = Mock() + self.mock_app_conversation_start_task_service = Mock() + self.mock_event_callback_service = Mock() + self.mock_httpx_client = Mock() + + # Create service instance + self.service = LiveStatusAppConversationService( + init_git_in_empty_workspace=True, + user_context=self.mock_user_context, + app_conversation_info_service=self.mock_app_conversation_info_service, + app_conversation_start_task_service=self.mock_app_conversation_start_task_service, + event_callback_service=self.mock_event_callback_service, + sandbox_service=self.mock_sandbox_service, + sandbox_spec_service=self.mock_sandbox_spec_service, + jwt_service=self.mock_jwt_service, + sandbox_startup_timeout=30, + sandbox_startup_poll_frequency=1, + httpx_client=self.mock_httpx_client, + web_url='https://test.example.com', + access_token_hard_timeout=None, + app_mode='test', + keycloak_auth_cookie=None, + ) + + # Mock user info + self.mock_user = Mock() + self.mock_user.id = 'test_user_123' + self.mock_user.llm_model = 'gpt-4' + self.mock_user.llm_base_url = 'https://api.openai.com/v1' + self.mock_user.llm_api_key = 'test_api_key' + self.mock_user.confirmation_mode = False + self.mock_user.search_api_key = None # Default to None + + # Mock sandbox + self.mock_sandbox = Mock(spec=SandboxInfo) + self.mock_sandbox.id = uuid4() + self.mock_sandbox.status = SandboxStatus.RUNNING + + @pytest.mark.asyncio + async def test_setup_secrets_for_git_provider_no_provider(self): + """Test _setup_secrets_for_git_provider with no git provider.""" + # Arrange + base_secrets = {'existing': 'secret'} + self.mock_user_context.get_secrets.return_value = base_secrets + + # Act + result = await self.service._setup_secrets_for_git_provider( + None, self.mock_user + ) + + # Assert + assert result == base_secrets + self.mock_user_context.get_secrets.assert_called_once() + + @pytest.mark.asyncio + async def test_setup_secrets_for_git_provider_with_web_url(self): + """Test _setup_secrets_for_git_provider with web URL (creates access token).""" + # Arrange + base_secrets = {} + self.mock_user_context.get_secrets.return_value = base_secrets + self.mock_jwt_service.create_jws_token.return_value = 'test_access_token' + git_provider = ProviderType.GITHUB + + # Act + result = await self.service._setup_secrets_for_git_provider( + git_provider, self.mock_user + ) + + # Assert + assert 'GITHUB_TOKEN' in result + assert isinstance(result['GITHUB_TOKEN'], LookupSecret) + assert ( + result['GITHUB_TOKEN'].url + == 'https://test.example.com/api/v1/webhooks/secrets' + ) + assert result['GITHUB_TOKEN'].headers['X-Access-Token'] == 'test_access_token' + + self.mock_jwt_service.create_jws_token.assert_called_once_with( + payload={ + 'user_id': self.mock_user.id, + 'provider_type': git_provider.value, + }, + expires_in=None, + ) + + @pytest.mark.asyncio + async def test_setup_secrets_for_git_provider_with_saas_mode(self): + """Test _setup_secrets_for_git_provider with SaaS mode (includes keycloak cookie).""" + # Arrange + self.service.app_mode = 'saas' + self.service.keycloak_auth_cookie = 'test_cookie' + base_secrets = {} + self.mock_user_context.get_secrets.return_value = base_secrets + self.mock_jwt_service.create_jws_token.return_value = 'test_access_token' + git_provider = ProviderType.GITLAB + + # Act + result = await self.service._setup_secrets_for_git_provider( + git_provider, self.mock_user + ) + + # Assert + assert 'GITLAB_TOKEN' in result + lookup_secret = result['GITLAB_TOKEN'] + assert isinstance(lookup_secret, LookupSecret) + assert 'Cookie' in lookup_secret.headers + assert lookup_secret.headers['Cookie'] == 'keycloak_auth=test_cookie' + + @pytest.mark.asyncio + async def test_setup_secrets_for_git_provider_without_web_url(self): + """Test _setup_secrets_for_git_provider without web URL (uses static token).""" + # Arrange + self.service.web_url = None + base_secrets = {} + self.mock_user_context.get_secrets.return_value = base_secrets + self.mock_user_context.get_latest_token.return_value = 'static_token_value' + git_provider = ProviderType.GITHUB + + # Act + result = await self.service._setup_secrets_for_git_provider( + git_provider, self.mock_user + ) + + # Assert + assert 'GITHUB_TOKEN' in result + assert isinstance(result['GITHUB_TOKEN'], StaticSecret) + assert result['GITHUB_TOKEN'].value.get_secret_value() == 'static_token_value' + self.mock_user_context.get_latest_token.assert_called_once_with(git_provider) + + @pytest.mark.asyncio + async def test_setup_secrets_for_git_provider_no_static_token(self): + """Test _setup_secrets_for_git_provider when no static token is available.""" + # Arrange + self.service.web_url = None + base_secrets = {} + self.mock_user_context.get_secrets.return_value = base_secrets + self.mock_user_context.get_latest_token.return_value = None + git_provider = ProviderType.GITHUB + + # Act + result = await self.service._setup_secrets_for_git_provider( + git_provider, self.mock_user + ) + + # Assert + assert 'GITHUB_TOKEN' not in result + assert result == base_secrets + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_with_custom_model(self): + """Test _configure_llm_and_mcp with custom LLM model.""" + # Arrange + custom_model = 'gpt-3.5-turbo' + self.mock_user_context.get_mcp_api_key.return_value = 'mcp_api_key' + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, custom_model + ) + + # Assert + assert isinstance(llm, LLM) + assert llm.model == custom_model + assert llm.base_url == self.mock_user.llm_base_url + assert llm.api_key.get_secret_value() == self.mock_user.llm_api_key + assert llm.usage_id == 'agent' + + assert 'default' in mcp_config + assert mcp_config['default']['url'] == 'https://test.example.com/mcp/mcp' + assert mcp_config['default']['headers']['X-Session-API-Key'] == 'mcp_api_key' + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_with_user_default_model(self): + """Test _configure_llm_and_mcp using user's default model.""" + # Arrange + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert llm.model == self.mock_user.llm_model + assert 'default' in mcp_config + assert 'headers' not in mcp_config['default'] + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_without_web_url(self): + """Test _configure_llm_and_mcp without web URL (no MCP config).""" + # Arrange + self.service.web_url = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert mcp_config == {} + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_tavily_with_user_search_api_key(self): + """Test _configure_llm_and_mcp adds tavily when user has search_api_key.""" + # Arrange + from pydantic import SecretStr + + self.mock_user.search_api_key = SecretStr('user_search_key') + self.mock_user_context.get_mcp_api_key.return_value = 'mcp_api_key' + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'default' in mcp_config + assert 'tavily' in mcp_config + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=user_search_key' + ) + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_tavily_with_env_tavily_key(self): + """Test _configure_llm_and_mcp adds tavily when service has tavily_api_key.""" + # Arrange + self.service.tavily_api_key = 'env_tavily_key' + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'default' in mcp_config + assert 'tavily' in mcp_config + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=env_tavily_key' + ) + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_tavily_user_key_takes_precedence(self): + """Test _configure_llm_and_mcp user search_api_key takes precedence over env key.""" + # Arrange + from pydantic import SecretStr + + self.mock_user.search_api_key = SecretStr('user_search_key') + self.service.tavily_api_key = 'env_tavily_key' + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'tavily' in mcp_config + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=user_search_key' + ) + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_no_tavily_without_keys(self): + """Test _configure_llm_and_mcp does not add tavily when no keys are available.""" + # Arrange + self.mock_user.search_api_key = None + self.service.tavily_api_key = None + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'default' in mcp_config + assert 'tavily' not in mcp_config + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_saas_mode_no_tavily_without_user_key(self): + """Test _configure_llm_and_mcp does not add tavily in SAAS mode without user search_api_key. + + In SAAS mode, the global tavily_api_key should not be passed to the service instance, + so tavily should only be added if the user has their own search_api_key. + """ + # Arrange - simulate SAAS mode where no global tavily key is available + self.service.app_mode = AppMode.SAAS.value + self.service.tavily_api_key = None # In SAAS mode, this should be None + self.mock_user.search_api_key = None + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'default' in mcp_config + assert 'tavily' not in mcp_config + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_saas_mode_with_user_search_key(self): + """Test _configure_llm_and_mcp adds tavily in SAAS mode when user has search_api_key. + + Even in SAAS mode, if the user has their own search_api_key, tavily should be added. + """ + # Arrange - simulate SAAS mode with user having their own search key + from pydantic import SecretStr + + self.service.app_mode = AppMode.SAAS.value + self.service.tavily_api_key = None # In SAAS mode, this should be None + self.mock_user.search_api_key = SecretStr('user_search_key') + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'default' in mcp_config + assert 'tavily' in mcp_config + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=user_search_key' + ) + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_tavily_with_empty_user_search_key(self): + """Test _configure_llm_and_mcp handles empty user search_api_key correctly.""" + # Arrange + from pydantic import SecretStr + + self.mock_user.search_api_key = SecretStr('') # Empty string + self.service.tavily_api_key = 'env_tavily_key' + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'tavily' in mcp_config + # Should fall back to env key since user key is empty + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=env_tavily_key' + ) + + @pytest.mark.asyncio + async def test_configure_llm_and_mcp_tavily_with_whitespace_user_search_key(self): + """Test _configure_llm_and_mcp handles whitespace-only user search_api_key correctly.""" + # Arrange + from pydantic import SecretStr + + self.mock_user.search_api_key = SecretStr(' ') # Whitespace only + self.service.tavily_api_key = 'env_tavily_key' + self.mock_user_context.get_mcp_api_key.return_value = None + + # Act + llm, mcp_config = await self.service._configure_llm_and_mcp( + self.mock_user, None + ) + + # Assert + assert isinstance(llm, LLM) + assert 'tavily' in mcp_config + # Should fall back to env key since user key is whitespace only + assert ( + mcp_config['tavily']['url'] + == 'https://mcp.tavily.com/mcp/?tavilyApiKey=env_tavily_key' + ) + + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.get_planning_tools' + ) + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.get_planning_condenser' + ) + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.format_plan_structure' + ) + def test_create_agent_with_context_planning_agent( + self, mock_format_plan, mock_get_condenser, mock_get_tools + ): + """Test _create_agent_with_context for planning agent type.""" + # Arrange + mock_llm = Mock(spec=LLM) + mock_llm.model_copy.return_value = mock_llm + mock_get_tools.return_value = [] + mock_get_condenser.return_value = Mock() + mock_format_plan.return_value = 'test_plan_structure' + mcp_config = {'default': {'url': 'test'}} + system_message_suffix = 'Test suffix' + + # Act + with patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.Agent' + ) as mock_agent_class: + mock_agent_instance = Mock() + mock_agent_instance.model_copy.return_value = mock_agent_instance + mock_agent_class.return_value = mock_agent_instance + + self.service._create_agent_with_context( + mock_llm, AgentType.PLAN, system_message_suffix, mcp_config + ) + + # Assert + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs['llm'] == mock_llm + assert call_kwargs['system_prompt_filename'] == 'system_prompt_planning.j2' + assert ( + call_kwargs['system_prompt_kwargs']['plan_structure'] + == 'test_plan_structure' + ) + assert call_kwargs['mcp_config'] == mcp_config + assert call_kwargs['security_analyzer'] is None + + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.get_default_tools' + ) + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.get_default_condenser' + ) + def test_create_agent_with_context_default_agent( + self, mock_get_condenser, mock_get_tools + ): + """Test _create_agent_with_context for default agent type.""" + # Arrange + mock_llm = Mock(spec=LLM) + mock_llm.model_copy.return_value = mock_llm + mock_get_tools.return_value = [] + mock_get_condenser.return_value = Mock() + mcp_config = {'default': {'url': 'test'}} + + # Act + with patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.Agent' + ) as mock_agent_class: + mock_agent_instance = Mock() + mock_agent_instance.model_copy.return_value = mock_agent_instance + mock_agent_class.return_value = mock_agent_instance + + self.service._create_agent_with_context( + mock_llm, AgentType.DEFAULT, None, mcp_config + ) + + # Assert + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs['llm'] == mock_llm + assert call_kwargs['system_prompt_kwargs']['cli_mode'] is False + assert call_kwargs['mcp_config'] == mcp_config + mock_get_tools.assert_called_once_with(enable_browser=True) + + @pytest.mark.asyncio + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl' + ) + async def test_finalize_conversation_request_with_skills( + self, mock_experiment_manager + ): + """Test _finalize_conversation_request with skills loading.""" + # Arrange + mock_agent = Mock(spec=Agent) + mock_updated_agent = Mock(spec=Agent) + mock_experiment_manager.run_agent_variant_tests__v1.return_value = ( + mock_updated_agent + ) + + conversation_id = uuid4() + workspace = LocalWorkspace(working_dir='/test') + initial_message = Mock(spec=SendMessageRequest) + secrets = {'test': StaticSecret(value='secret')} + remote_workspace = Mock(spec=AsyncRemoteWorkspace) + + # Mock the skills loading method + self.service._load_skills_and_update_agent = AsyncMock( + return_value=mock_updated_agent + ) + + # Act + result = await self.service._finalize_conversation_request( + mock_agent, + conversation_id, + self.mock_user, + workspace, + initial_message, + secrets, + self.mock_sandbox, + remote_workspace, + 'test_repo', + '/test/dir', + ) + + # Assert + assert isinstance(result, StartConversationRequest) + assert result.conversation_id == conversation_id + assert result.agent == mock_updated_agent + assert result.workspace == workspace + assert result.initial_message == initial_message + assert result.secrets == secrets + + mock_experiment_manager.run_agent_variant_tests__v1.assert_called_once_with( + self.mock_user.id, conversation_id, mock_agent + ) + self.service._load_skills_and_update_agent.assert_called_once_with( + self.mock_sandbox, + mock_updated_agent, + remote_workspace, + 'test_repo', + '/test/dir', + ) + + @pytest.mark.asyncio + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl' + ) + async def test_finalize_conversation_request_without_skills( + self, mock_experiment_manager + ): + """Test _finalize_conversation_request without remote workspace (no skills).""" + # Arrange + mock_agent = Mock(spec=Agent) + mock_updated_agent = Mock(spec=Agent) + mock_experiment_manager.run_agent_variant_tests__v1.return_value = ( + mock_updated_agent + ) + + workspace = LocalWorkspace(working_dir='/test') + secrets = {'test': StaticSecret(value='secret')} + + # Act + result = await self.service._finalize_conversation_request( + mock_agent, + None, + self.mock_user, + workspace, + None, + secrets, + self.mock_sandbox, + None, + None, + '/test/dir', + ) + + # Assert + assert isinstance(result, StartConversationRequest) + assert isinstance(result.conversation_id, UUID) + assert result.agent == mock_updated_agent + mock_experiment_manager.run_agent_variant_tests__v1.assert_called_once() + + @pytest.mark.asyncio + @patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl' + ) + async def test_finalize_conversation_request_skills_loading_fails( + self, mock_experiment_manager + ): + """Test _finalize_conversation_request when skills loading fails.""" + # Arrange + mock_agent = Mock(spec=Agent) + mock_updated_agent = Mock(spec=Agent) + mock_experiment_manager.run_agent_variant_tests__v1.return_value = ( + mock_updated_agent + ) + + workspace = LocalWorkspace(working_dir='/test') + secrets = {'test': StaticSecret(value='secret')} + remote_workspace = Mock(spec=AsyncRemoteWorkspace) + + # Mock skills loading to raise an exception + self.service._load_skills_and_update_agent = AsyncMock( + side_effect=Exception('Skills loading failed') + ) + + # Act + with patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service._logger' + ) as mock_logger: + result = await self.service._finalize_conversation_request( + mock_agent, + None, + self.mock_user, + workspace, + None, + secrets, + self.mock_sandbox, + remote_workspace, + 'test_repo', + '/test/dir', + ) + + # Assert + assert isinstance(result, StartConversationRequest) + assert ( + result.agent == mock_updated_agent + ) # Should still use the experiment-modified agent + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_build_start_conversation_request_for_user_integration(self): + """Test the main _build_start_conversation_request_for_user method integration.""" + # Arrange + self.mock_user_context.get_user_info.return_value = self.mock_user + + # Mock all the helper methods + mock_secrets = {'GITHUB_TOKEN': Mock()} + mock_llm = Mock(spec=LLM) + mock_mcp_config = {'default': {'url': 'test'}} + mock_agent = Mock(spec=Agent) + mock_final_request = Mock(spec=StartConversationRequest) + + self.service._setup_secrets_for_git_provider = AsyncMock( + return_value=mock_secrets + ) + self.service._configure_llm_and_mcp = AsyncMock( + return_value=(mock_llm, mock_mcp_config) + ) + self.service._create_agent_with_context = Mock(return_value=mock_agent) + self.service._finalize_conversation_request = AsyncMock( + return_value=mock_final_request + ) + + # Act + result = await self.service._build_start_conversation_request_for_user( + sandbox=self.mock_sandbox, + initial_message=None, + system_message_suffix='Test suffix', + git_provider=ProviderType.GITHUB, + working_dir='/test/dir', + agent_type=AgentType.DEFAULT, + llm_model='gpt-4', + conversation_id=None, + remote_workspace=None, + selected_repository='test/repo', + ) + + # Assert + assert result == mock_final_request + + self.service._setup_secrets_for_git_provider.assert_called_once_with( + ProviderType.GITHUB, self.mock_user + ) + self.service._configure_llm_and_mcp.assert_called_once_with( + self.mock_user, 'gpt-4' + ) + self.service._create_agent_with_context.assert_called_once_with( + mock_llm, AgentType.DEFAULT, 'Test suffix', mock_mcp_config + ) + self.service._finalize_conversation_request.assert_called_once() diff --git a/tests/unit/experiments/test_experiment_manager.py b/tests/unit/experiments/test_experiment_manager.py index eb48e5336f..00435c597c 100644 --- a/tests/unit/experiments/test_experiment_manager.py +++ b/tests/unit/experiments/test_experiment_manager.py @@ -200,8 +200,24 @@ class TestExperimentManagerIntegration: # Patch the pieces invoked by the service with ( - patch( - 'openhands.app_server.app_conversation.live_status_app_conversation_service.get_default_agent', + patch.object( + service, + '_setup_secrets_for_git_provider', + return_value={}, + ), + patch.object( + service, + '_configure_llm_and_mcp', + return_value=(mock_llm, {}), + ), + patch.object( + service, + '_create_agent_with_context', + return_value=mock_agent, + ), + patch.object( + service, + '_load_skills_and_update_agent', return_value=mock_agent, ), patch( diff --git a/tests/unit/server/routes/test_settings_api.py b/tests/unit/server/routes/test_settings_api.py index f01b1d77df..6ea4080388 100644 --- a/tests/unit/server/routes/test_settings_api.py +++ b/tests/unit/server/routes/test_settings_api.py @@ -46,6 +46,9 @@ class MockUserAuth(UserAuth): async def get_secrets(self) -> Secrets | None: return None + async def get_mcp_api_key(self) -> str | None: + return None + @classmethod async def get_instance(cls, request: Request) -> UserAuth: return MockUserAuth() diff --git a/tests/unit/server/test_openapi_schema_generation.py b/tests/unit/server/test_openapi_schema_generation.py index 2aa798e1e6..eb967e496c 100644 --- a/tests/unit/server/test_openapi_schema_generation.py +++ b/tests/unit/server/test_openapi_schema_generation.py @@ -46,6 +46,9 @@ class MockUserAuth(UserAuth): async def get_secrets(self) -> Secrets | None: return None + async def get_mcp_api_key(self) -> str | None: + return None + @classmethod async def get_instance(cls, request: Request) -> UserAuth: return MockUserAuth()