diff --git a/enterprise/integrations/github/github_view.py b/enterprise/integrations/github/github_view.py index 733cec6c2a..97352a233f 100644 --- a/enterprise/integrations/github/github_view.py +++ b/enterprise/integrations/github/github_view.py @@ -86,6 +86,9 @@ class GithubUserContext(UserContext): # For now, return a basic HTTPS URL return f'https://github.com/{repository}.git' + async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: + return self.git_provider_tokens + async def get_latest_token(self, provider_type: ProviderType) -> str | None: # Return the appropriate token from git_provider_tokens if provider_type == ProviderType.GITHUB and self.git_provider_tokens: diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 070640c907..9b290455c6 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -526,13 +526,10 @@ class LiveStatusAppConversationService(AppConversationServiceBase): if not request.llm_model and parent_info.llm_model: request.llm_model = parent_info.llm_model - async def _setup_secrets_for_git_provider( - self, git_provider: ProviderType | None, user: UserInfo - ) -> dict: - """Set up secrets for git provider authentication. + async def _setup_secrets_for_git_providers(self, user: UserInfo) -> dict: + """Set up secrets for all git provider authentication. Args: - git_provider: The git provider type (GitHub, GitLab, etc.) user: User information containing authentication details Returns: @@ -540,35 +537,42 @@ class LiveStatusAppConversationService(AppConversationServiceBase): """ secrets = await self.user_context.get_secrets() - if not git_provider: + # Get all provider tokens from user authentication + provider_tokens = await self.user_context.get_provider_tokens() + if not provider_tokens: return secrets - secret_name = f'{git_provider.name}_TOKEN' + # Create secrets for each provider token + for provider_type, provider_token in provider_tokens.items(): + if not provider_token.token: + continue - if self.web_url: - # Create an access token for web-based authentication - access_token = self.jwt_service.create_jws_token( - payload={ - 'user_id': user.id, - 'provider_type': git_provider.value, - }, - expires_in=self.access_token_hard_timeout, - ) - headers = {'X-Access-Token': access_token} + secret_name = f'{provider_type.name}_TOKEN' - # Include keycloak_auth cookie in headers if app_mode is SaaS - if self.app_mode == 'saas' and self.keycloak_auth_cookie: - headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}' + if self.web_url: + # Create an access token for web-based authentication + access_token = self.jwt_service.create_jws_token( + payload={ + 'user_id': user.id, + 'provider_type': provider_type.value, + }, + expires_in=self.access_token_hard_timeout, + ) + headers = {'X-Access-Token': access_token} - secrets[secret_name] = LookupSecret( - url=self.web_url + '/api/v1/webhooks/secrets', - headers=headers, - ) - else: - # Use static token for environments without web URL access - static_token = await self.user_context.get_latest_token(git_provider) - if static_token: - secrets[secret_name] = StaticSecret(value=static_token) + # Include keycloak_auth cookie in headers if app_mode is SaaS + if self.app_mode == 'saas' and self.keycloak_auth_cookie: + headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}' + + secrets[secret_name] = LookupSecret( + url=self.web_url + '/api/v1/webhooks/secrets', + headers=headers, + ) + else: + # Use static token for environments without web URL access + static_token = await self.user_context.get_latest_token(provider_type) + if static_token: + secrets[secret_name] = StaticSecret(value=static_token) return secrets @@ -768,8 +772,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase): user = await self.user_context.get_user_info() workspace = LocalWorkspace(working_dir=working_dir) - # Set up secrets for git provider - secrets = await self._setup_secrets_for_git_provider(git_provider, user) + # Set up secrets for all git providers + secrets = await self._setup_secrets_for_git_providers(user) # Configure LLM and MCP llm, mcp_config = await self._configure_llm_and_mcp(user, llm_model) diff --git a/openhands/app_server/user/auth_user_context.py b/openhands/app_server/user/auth_user_context.py index e0b8fdd35f..8ea95036f4 100644 --- a/openhands/app_server/user/auth_user_context.py +++ b/openhands/app_server/user/auth_user_context.py @@ -9,7 +9,11 @@ from openhands.app_server.services.injector import InjectorState from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR from openhands.app_server.user.user_context import UserContext, UserContextInjector from openhands.app_server.user.user_models import UserInfo -from openhands.integrations.provider import ProviderHandler, ProviderType +from openhands.integrations.provider import ( + PROVIDER_TOKEN_TYPE, + ProviderHandler, + ProviderType, +) from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret from openhands.server.user_auth.user_auth import UserAuth, get_user_auth @@ -44,6 +48,9 @@ class AuthUserContext(UserContext): self._user_info = user_info return user_info + async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: + return await self.user_auth.get_provider_tokens() + async def get_provider_handler(self): provider_handler = self._provider_handler if not provider_handler: diff --git a/openhands/app_server/user/specifiy_user_context.py b/openhands/app_server/user/specifiy_user_context.py index d940061466..87e2d74da2 100644 --- a/openhands/app_server/user/specifiy_user_context.py +++ b/openhands/app_server/user/specifiy_user_context.py @@ -5,7 +5,7 @@ from fastapi import Request from openhands.app_server.errors import OpenHandsError from openhands.app_server.user.user_context import UserContext from openhands.app_server.user.user_models import UserInfo -from openhands.integrations.provider import ProviderType +from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType from openhands.sdk.conversation.secret_source import SecretSource @@ -24,6 +24,9 @@ class SpecifyUserContext(UserContext): async def get_authenticated_git_url(self, repository: str) -> str: raise NotImplementedError() + async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: + raise NotImplementedError() + async def get_latest_token(self, provider_type: ProviderType) -> str | None: raise NotImplementedError() diff --git a/openhands/app_server/user/user_context.py b/openhands/app_server/user/user_context.py index 0971b71570..02c0ba8aaf 100644 --- a/openhands/app_server/user/user_context.py +++ b/openhands/app_server/user/user_context.py @@ -4,7 +4,7 @@ from openhands.app_server.services.injector import Injector from openhands.app_server.user.user_models import ( UserInfo, ) -from openhands.integrations.provider import ProviderType +from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType from openhands.sdk.conversation.secret_source import SecretSource from openhands.sdk.utils.models import DiscriminatedUnionMixin @@ -26,6 +26,10 @@ class UserContext(ABC): async def get_authenticated_git_url(self, repository: str) -> str: """Get the provider tokens for the user""" + @abstractmethod + async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: + """Get the latest tokens for all provider types""" + @abstractmethod async def get_latest_token(self, provider_type: ProviderType) -> str | None: """Get the latest token for the provider type given""" diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py index 00d80abf64..62cd0858f3 100644 --- a/tests/unit/app_server/test_live_status_app_conversation_service.py +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -28,6 +28,8 @@ class TestLiveStatusAppConversationService: """Set up test fixtures.""" # Create mock dependencies self.mock_user_context = Mock(spec=UserContext) + self.mock_user_auth = Mock() + self.mock_user_context.user_auth = self.mock_user_auth self.mock_jwt_service = Mock() self.mock_sandbox_service = Mock() self.mock_sandbox_spec_service = Mock() @@ -71,67 +73,83 @@ class TestLiveStatusAppConversationService: self.mock_sandbox.status = SandboxStatus.RUNNING @pytest.mark.asyncio - async def test_setup_secrets_for_git_provider_no_provider(self): - """Test _setup_secrets_for_git_provider with no git provider.""" + async def test_setup_secrets_for_git_providers_no_provider_tokens(self): + """Test _setup_secrets_for_git_providers with no provider tokens.""" # Arrange base_secrets = {'existing': 'secret'} self.mock_user_context.get_secrets.return_value = base_secrets + self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None) # Act - result = await self.service._setup_secrets_for_git_provider( - None, self.mock_user - ) + result = await self.service._setup_secrets_for_git_providers(self.mock_user) # Assert assert result == base_secrets self.mock_user_context.get_secrets.assert_called_once() + self.mock_user_context.get_provider_tokens.assert_called_once() @pytest.mark.asyncio - async def test_setup_secrets_for_git_provider_with_web_url(self): - """Test _setup_secrets_for_git_provider with web URL (creates access token).""" + async def test_setup_secrets_for_git_providers_with_web_url(self): + """Test _setup_secrets_for_git_providers with web URL (creates access token).""" # Arrange + from pydantic import SecretStr + + from openhands.integrations.provider import ProviderToken + base_secrets = {} self.mock_user_context.get_secrets.return_value = base_secrets self.mock_jwt_service.create_jws_token.return_value = 'test_access_token' - git_provider = ProviderType.GITHUB + + # Mock provider tokens + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')), + ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')), + } + self.mock_user_context.get_provider_tokens = AsyncMock( + return_value=provider_tokens + ) # Act - result = await self.service._setup_secrets_for_git_provider( - git_provider, self.mock_user - ) + result = await self.service._setup_secrets_for_git_providers(self.mock_user) # Assert assert 'GITHUB_TOKEN' in result + assert 'GITLAB_TOKEN' in result assert isinstance(result['GITHUB_TOKEN'], LookupSecret) + assert isinstance(result['GITLAB_TOKEN'], LookupSecret) assert ( result['GITHUB_TOKEN'].url == 'https://test.example.com/api/v1/webhooks/secrets' ) assert result['GITHUB_TOKEN'].headers['X-Access-Token'] == 'test_access_token' - self.mock_jwt_service.create_jws_token.assert_called_once_with( - payload={ - 'user_id': self.mock_user.id, - 'provider_type': git_provider.value, - }, - expires_in=None, - ) + # Should be called twice, once for each provider + assert self.mock_jwt_service.create_jws_token.call_count == 2 @pytest.mark.asyncio - async def test_setup_secrets_for_git_provider_with_saas_mode(self): - """Test _setup_secrets_for_git_provider with SaaS mode (includes keycloak cookie).""" + async def test_setup_secrets_for_git_providers_with_saas_mode(self): + """Test _setup_secrets_for_git_providers with SaaS mode (includes keycloak cookie).""" # Arrange + from pydantic import SecretStr + + from openhands.integrations.provider import ProviderToken + self.service.app_mode = 'saas' self.service.keycloak_auth_cookie = 'test_cookie' base_secrets = {} self.mock_user_context.get_secrets.return_value = base_secrets self.mock_jwt_service.create_jws_token.return_value = 'test_access_token' - git_provider = ProviderType.GITLAB + + # Mock provider tokens + provider_tokens = { + ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')), + } + self.mock_user_context.get_provider_tokens = AsyncMock( + return_value=provider_tokens + ) # Act - result = await self.service._setup_secrets_for_git_provider( - git_provider, self.mock_user - ) + result = await self.service._setup_secrets_for_git_providers(self.mock_user) # Assert assert 'GITLAB_TOKEN' in result @@ -141,40 +159,60 @@ class TestLiveStatusAppConversationService: assert lookup_secret.headers['Cookie'] == 'keycloak_auth=test_cookie' @pytest.mark.asyncio - async def test_setup_secrets_for_git_provider_without_web_url(self): - """Test _setup_secrets_for_git_provider without web URL (uses static token).""" + async def test_setup_secrets_for_git_providers_without_web_url(self): + """Test _setup_secrets_for_git_providers without web URL (uses static token).""" # Arrange + from pydantic import SecretStr + + from openhands.integrations.provider import ProviderToken + self.service.web_url = None base_secrets = {} self.mock_user_context.get_secrets.return_value = base_secrets self.mock_user_context.get_latest_token.return_value = 'static_token_value' - git_provider = ProviderType.GITHUB + + # Mock provider tokens + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')), + } + self.mock_user_context.get_provider_tokens = AsyncMock( + return_value=provider_tokens + ) # Act - result = await self.service._setup_secrets_for_git_provider( - git_provider, self.mock_user - ) + result = await self.service._setup_secrets_for_git_providers(self.mock_user) # Assert assert 'GITHUB_TOKEN' in result assert isinstance(result['GITHUB_TOKEN'], StaticSecret) assert result['GITHUB_TOKEN'].value.get_secret_value() == 'static_token_value' - self.mock_user_context.get_latest_token.assert_called_once_with(git_provider) + self.mock_user_context.get_latest_token.assert_called_once_with( + ProviderType.GITHUB + ) @pytest.mark.asyncio - async def test_setup_secrets_for_git_provider_no_static_token(self): - """Test _setup_secrets_for_git_provider when no static token is available.""" + async def test_setup_secrets_for_git_providers_no_static_token(self): + """Test _setup_secrets_for_git_providers when no static token is available.""" # Arrange + from pydantic import SecretStr + + from openhands.integrations.provider import ProviderToken + self.service.web_url = None base_secrets = {} self.mock_user_context.get_secrets.return_value = base_secrets self.mock_user_context.get_latest_token.return_value = None - git_provider = ProviderType.GITHUB + + # Mock provider tokens + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')), + } + self.mock_user_context.get_provider_tokens = AsyncMock( + return_value=provider_tokens + ) # Act - result = await self.service._setup_secrets_for_git_provider( - git_provider, self.mock_user - ) + result = await self.service._setup_secrets_for_git_providers(self.mock_user) # Assert assert 'GITHUB_TOKEN' not in result @@ -677,7 +715,7 @@ class TestLiveStatusAppConversationService: mock_agent = Mock(spec=Agent) mock_final_request = Mock(spec=StartConversationRequest) - self.service._setup_secrets_for_git_provider = AsyncMock( + self.service._setup_secrets_for_git_providers = AsyncMock( return_value=mock_secrets ) self.service._configure_llm_and_mcp = AsyncMock( @@ -705,8 +743,8 @@ class TestLiveStatusAppConversationService: # Assert assert result == mock_final_request - self.service._setup_secrets_for_git_provider.assert_called_once_with( - ProviderType.GITHUB, self.mock_user + self.service._setup_secrets_for_git_providers.assert_called_once_with( + self.mock_user ) self.service._configure_llm_and_mcp.assert_called_once_with( self.mock_user, 'gpt-4' diff --git a/tests/unit/experiments/test_experiment_manager.py b/tests/unit/experiments/test_experiment_manager.py index a21943ace2..f726c1eca3 100644 --- a/tests/unit/experiments/test_experiment_manager.py +++ b/tests/unit/experiments/test_experiment_manager.py @@ -203,7 +203,7 @@ class TestExperimentManagerIntegration: with ( patch.object( service, - '_setup_secrets_for_git_provider', + '_setup_secrets_for_git_providers', return_value={}, ), patch.object(