APP-216 Support multiple git providers in conversation secrets (#11908)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-12-05 11:50:45 -07:00
committed by GitHub
parent 7811a62491
commit 72c7d9c497
7 changed files with 134 additions and 75 deletions

View File

@@ -86,6 +86,9 @@ class GithubUserContext(UserContext):
# For now, return a basic HTTPS URL
return f'https://github.com/{repository}.git'
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
return self.git_provider_tokens
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
# Return the appropriate token from git_provider_tokens
if provider_type == ProviderType.GITHUB and self.git_provider_tokens:

View File

@@ -526,13 +526,10 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
if not request.llm_model and parent_info.llm_model:
request.llm_model = parent_info.llm_model
async def _setup_secrets_for_git_provider(
self, git_provider: ProviderType | None, user: UserInfo
) -> dict:
"""Set up secrets for git provider authentication.
async def _setup_secrets_for_git_providers(self, user: UserInfo) -> dict:
"""Set up secrets for all git provider authentication.
Args:
git_provider: The git provider type (GitHub, GitLab, etc.)
user: User information containing authentication details
Returns:
@@ -540,35 +537,42 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
"""
secrets = await self.user_context.get_secrets()
if not git_provider:
# Get all provider tokens from user authentication
provider_tokens = await self.user_context.get_provider_tokens()
if not provider_tokens:
return secrets
secret_name = f'{git_provider.name}_TOKEN'
# Create secrets for each provider token
for provider_type, provider_token in provider_tokens.items():
if not provider_token.token:
continue
if self.web_url:
# Create an access token for web-based authentication
access_token = self.jwt_service.create_jws_token(
payload={
'user_id': user.id,
'provider_type': git_provider.value,
},
expires_in=self.access_token_hard_timeout,
)
headers = {'X-Access-Token': access_token}
secret_name = f'{provider_type.name}_TOKEN'
# Include keycloak_auth cookie in headers if app_mode is SaaS
if self.app_mode == 'saas' and self.keycloak_auth_cookie:
headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}'
if self.web_url:
# Create an access token for web-based authentication
access_token = self.jwt_service.create_jws_token(
payload={
'user_id': user.id,
'provider_type': provider_type.value,
},
expires_in=self.access_token_hard_timeout,
)
headers = {'X-Access-Token': access_token}
secrets[secret_name] = LookupSecret(
url=self.web_url + '/api/v1/webhooks/secrets',
headers=headers,
)
else:
# Use static token for environments without web URL access
static_token = await self.user_context.get_latest_token(git_provider)
if static_token:
secrets[secret_name] = StaticSecret(value=static_token)
# Include keycloak_auth cookie in headers if app_mode is SaaS
if self.app_mode == 'saas' and self.keycloak_auth_cookie:
headers['Cookie'] = f'keycloak_auth={self.keycloak_auth_cookie}'
secrets[secret_name] = LookupSecret(
url=self.web_url + '/api/v1/webhooks/secrets',
headers=headers,
)
else:
# Use static token for environments without web URL access
static_token = await self.user_context.get_latest_token(provider_type)
if static_token:
secrets[secret_name] = StaticSecret(value=static_token)
return secrets
@@ -768,8 +772,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
user = await self.user_context.get_user_info()
workspace = LocalWorkspace(working_dir=working_dir)
# Set up secrets for git provider
secrets = await self._setup_secrets_for_git_provider(git_provider, user)
# Set up secrets for all git providers
secrets = await self._setup_secrets_for_git_providers(user)
# Configure LLM and MCP
llm, mcp_config = await self._configure_llm_and_mcp(user, llm_model)

View File

@@ -9,7 +9,11 @@ from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user.user_context import UserContext, UserContextInjector
from openhands.app_server.user.user_models import UserInfo
from openhands.integrations.provider import ProviderHandler, ProviderType
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
ProviderType,
)
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
from openhands.server.user_auth.user_auth import UserAuth, get_user_auth
@@ -44,6 +48,9 @@ class AuthUserContext(UserContext):
self._user_info = user_info
return user_info
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
return await self.user_auth.get_provider_tokens()
async def get_provider_handler(self):
provider_handler = self._provider_handler
if not provider_handler:

View File

@@ -5,7 +5,7 @@ from fastapi import Request
from openhands.app_server.errors import OpenHandsError
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.user.user_models import UserInfo
from openhands.integrations.provider import ProviderType
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.sdk.conversation.secret_source import SecretSource
@@ -24,6 +24,9 @@ class SpecifyUserContext(UserContext):
async def get_authenticated_git_url(self, repository: str) -> str:
raise NotImplementedError()
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
raise NotImplementedError()
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
raise NotImplementedError()

View File

@@ -4,7 +4,7 @@ from openhands.app_server.services.injector import Injector
from openhands.app_server.user.user_models import (
UserInfo,
)
from openhands.integrations.provider import ProviderType
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.sdk.conversation.secret_source import SecretSource
from openhands.sdk.utils.models import DiscriminatedUnionMixin
@@ -26,6 +26,10 @@ class UserContext(ABC):
async def get_authenticated_git_url(self, repository: str) -> str:
"""Get the provider tokens for the user"""
@abstractmethod
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
"""Get the latest tokens for all provider types"""
@abstractmethod
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
"""Get the latest token for the provider type given"""

View File

@@ -28,6 +28,8 @@ class TestLiveStatusAppConversationService:
"""Set up test fixtures."""
# Create mock dependencies
self.mock_user_context = Mock(spec=UserContext)
self.mock_user_auth = Mock()
self.mock_user_context.user_auth = self.mock_user_auth
self.mock_jwt_service = Mock()
self.mock_sandbox_service = Mock()
self.mock_sandbox_spec_service = Mock()
@@ -71,67 +73,83 @@ class TestLiveStatusAppConversationService:
self.mock_sandbox.status = SandboxStatus.RUNNING
@pytest.mark.asyncio
async def test_setup_secrets_for_git_provider_no_provider(self):
"""Test _setup_secrets_for_git_provider with no git provider."""
async def test_setup_secrets_for_git_providers_no_provider_tokens(self):
"""Test _setup_secrets_for_git_providers with no provider tokens."""
# Arrange
base_secrets = {'existing': 'secret'}
self.mock_user_context.get_secrets.return_value = base_secrets
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
# Act
result = await self.service._setup_secrets_for_git_provider(
None, self.mock_user
)
result = await self.service._setup_secrets_for_git_providers(self.mock_user)
# Assert
assert result == base_secrets
self.mock_user_context.get_secrets.assert_called_once()
self.mock_user_context.get_provider_tokens.assert_called_once()
@pytest.mark.asyncio
async def test_setup_secrets_for_git_provider_with_web_url(self):
"""Test _setup_secrets_for_git_provider with web URL (creates access token)."""
async def test_setup_secrets_for_git_providers_with_web_url(self):
"""Test _setup_secrets_for_git_providers with web URL (creates access token)."""
# Arrange
from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken
base_secrets = {}
self.mock_user_context.get_secrets.return_value = base_secrets
self.mock_jwt_service.create_jws_token.return_value = 'test_access_token'
git_provider = ProviderType.GITHUB
# Mock provider tokens
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')),
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
}
self.mock_user_context.get_provider_tokens = AsyncMock(
return_value=provider_tokens
)
# Act
result = await self.service._setup_secrets_for_git_provider(
git_provider, self.mock_user
)
result = await self.service._setup_secrets_for_git_providers(self.mock_user)
# Assert
assert 'GITHUB_TOKEN' in result
assert 'GITLAB_TOKEN' in result
assert isinstance(result['GITHUB_TOKEN'], LookupSecret)
assert isinstance(result['GITLAB_TOKEN'], LookupSecret)
assert (
result['GITHUB_TOKEN'].url
== 'https://test.example.com/api/v1/webhooks/secrets'
)
assert result['GITHUB_TOKEN'].headers['X-Access-Token'] == 'test_access_token'
self.mock_jwt_service.create_jws_token.assert_called_once_with(
payload={
'user_id': self.mock_user.id,
'provider_type': git_provider.value,
},
expires_in=None,
)
# Should be called twice, once for each provider
assert self.mock_jwt_service.create_jws_token.call_count == 2
@pytest.mark.asyncio
async def test_setup_secrets_for_git_provider_with_saas_mode(self):
"""Test _setup_secrets_for_git_provider with SaaS mode (includes keycloak cookie)."""
async def test_setup_secrets_for_git_providers_with_saas_mode(self):
"""Test _setup_secrets_for_git_providers with SaaS mode (includes keycloak cookie)."""
# Arrange
from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken
self.service.app_mode = 'saas'
self.service.keycloak_auth_cookie = 'test_cookie'
base_secrets = {}
self.mock_user_context.get_secrets.return_value = base_secrets
self.mock_jwt_service.create_jws_token.return_value = 'test_access_token'
git_provider = ProviderType.GITLAB
# Mock provider tokens
provider_tokens = {
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
}
self.mock_user_context.get_provider_tokens = AsyncMock(
return_value=provider_tokens
)
# Act
result = await self.service._setup_secrets_for_git_provider(
git_provider, self.mock_user
)
result = await self.service._setup_secrets_for_git_providers(self.mock_user)
# Assert
assert 'GITLAB_TOKEN' in result
@@ -141,40 +159,60 @@ class TestLiveStatusAppConversationService:
assert lookup_secret.headers['Cookie'] == 'keycloak_auth=test_cookie'
@pytest.mark.asyncio
async def test_setup_secrets_for_git_provider_without_web_url(self):
"""Test _setup_secrets_for_git_provider without web URL (uses static token)."""
async def test_setup_secrets_for_git_providers_without_web_url(self):
"""Test _setup_secrets_for_git_providers without web URL (uses static token)."""
# Arrange
from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken
self.service.web_url = None
base_secrets = {}
self.mock_user_context.get_secrets.return_value = base_secrets
self.mock_user_context.get_latest_token.return_value = 'static_token_value'
git_provider = ProviderType.GITHUB
# Mock provider tokens
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')),
}
self.mock_user_context.get_provider_tokens = AsyncMock(
return_value=provider_tokens
)
# Act
result = await self.service._setup_secrets_for_git_provider(
git_provider, self.mock_user
)
result = await self.service._setup_secrets_for_git_providers(self.mock_user)
# Assert
assert 'GITHUB_TOKEN' in result
assert isinstance(result['GITHUB_TOKEN'], StaticSecret)
assert result['GITHUB_TOKEN'].value.get_secret_value() == 'static_token_value'
self.mock_user_context.get_latest_token.assert_called_once_with(git_provider)
self.mock_user_context.get_latest_token.assert_called_once_with(
ProviderType.GITHUB
)
@pytest.mark.asyncio
async def test_setup_secrets_for_git_provider_no_static_token(self):
"""Test _setup_secrets_for_git_provider when no static token is available."""
async def test_setup_secrets_for_git_providers_no_static_token(self):
"""Test _setup_secrets_for_git_providers when no static token is available."""
# Arrange
from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken
self.service.web_url = None
base_secrets = {}
self.mock_user_context.get_secrets.return_value = base_secrets
self.mock_user_context.get_latest_token.return_value = None
git_provider = ProviderType.GITHUB
# Mock provider tokens
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')),
}
self.mock_user_context.get_provider_tokens = AsyncMock(
return_value=provider_tokens
)
# Act
result = await self.service._setup_secrets_for_git_provider(
git_provider, self.mock_user
)
result = await self.service._setup_secrets_for_git_providers(self.mock_user)
# Assert
assert 'GITHUB_TOKEN' not in result
@@ -677,7 +715,7 @@ class TestLiveStatusAppConversationService:
mock_agent = Mock(spec=Agent)
mock_final_request = Mock(spec=StartConversationRequest)
self.service._setup_secrets_for_git_provider = AsyncMock(
self.service._setup_secrets_for_git_providers = AsyncMock(
return_value=mock_secrets
)
self.service._configure_llm_and_mcp = AsyncMock(
@@ -705,8 +743,8 @@ class TestLiveStatusAppConversationService:
# Assert
assert result == mock_final_request
self.service._setup_secrets_for_git_provider.assert_called_once_with(
ProviderType.GITHUB, self.mock_user
self.service._setup_secrets_for_git_providers.assert_called_once_with(
self.mock_user
)
self.service._configure_llm_and_mcp.assert_called_once_with(
self.mock_user, 'gpt-4'

View File

@@ -203,7 +203,7 @@ class TestExperimentManagerIntegration:
with (
patch.object(
service,
'_setup_secrets_for_git_provider',
'_setup_secrets_for_git_providers',
return_value={},
),
patch.object(