diff --git a/enterprise/server/saas_nested_conversation_manager.py b/enterprise/server/saas_nested_conversation_manager.py index 59c6d4e981..5576e21cba 100644 --- a/enterprise/server/saas_nested_conversation_manager.py +++ b/enterprise/server/saas_nested_conversation_manager.py @@ -12,6 +12,8 @@ from typing import Any, cast import httpx import socketio +from pydantic import SecretStr +from server.auth.token_manager import TokenManager from server.constants import PERMITTED_CORS_ORIGINS, WEB_HOST from server.utils.conversation_callback_utils import ( process_event, @@ -29,7 +31,11 @@ from openhands.core.logger import openhands_logger as logger from openhands.events.action import MessageAction from openhands.events.event_store import EventStore from openhands.events.serialization.event import event_to_dict -from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler +from openhands.integrations.provider import ( + PROVIDER_TOKEN_TYPE, + ProviderHandler, + ProviderToken, +) from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime from openhands.runtime.plugins.vscode import VSCodeRequirement from openhands.runtime.runtime_status import RuntimeStatus @@ -228,6 +234,102 @@ class SaasNestedConversationManager(ConversationManager): status=status, ) + async def _refresh_provider_tokens_after_runtime_init( + self, settings: Settings, sid: str, user_id: str | None = None + ) -> Settings: + """Refresh provider tokens after runtime initialization. + + During runtime initialization, tokens may be refreshed by Runtime.__init__(). + This method retrieves the fresh tokens from the database and creates a new + settings object with updated tokens to avoid sending stale tokens to the + nested runtime. + + The method handles two scenarios: + 1. ProviderToken has user_id (IDP user ID, e.g., GitLab user ID) + → Uses get_idp_token_from_idp_user_id() + 2. ProviderToken has no user_id but Keycloak user_id is available + → Uses load_offline_token() + get_idp_token_from_offline_token() + + Args: + settings: The conversation settings that may contain provider tokens + sid: The session ID for logging purposes + user_id: The Keycloak user ID (optional, used as fallback when + ProviderToken.user_id is not available) + + Returns: + Updated settings with fresh provider tokens, or original settings + if no update is needed + """ + if not isinstance(settings, ConversationInitData): + return settings + + if not settings.git_provider_tokens: + return settings + + token_manager = TokenManager() + updated_tokens = {} + tokens_refreshed = 0 + tokens_failed = 0 + + for provider_type, provider_token in settings.git_provider_tokens.items(): + fresh_token = None + + try: + if provider_token.user_id: + # Case 1: We have IDP user ID (e.g., GitLab user ID '32546706') + # Get the token that was just refreshed during runtime initialization + fresh_token = await token_manager.get_idp_token_from_idp_user_id( + provider_token.user_id, provider_type + ) + elif user_id: + # Case 2: We have Keycloak user ID but no IDP user ID + # This happens in web UI flow where ProviderToken.user_id is None + offline_token = await token_manager.load_offline_token(user_id) + if offline_token: + fresh_token = ( + await token_manager.get_idp_token_from_offline_token( + offline_token, provider_type + ) + ) + + if fresh_token: + updated_tokens[provider_type] = ProviderToken( + token=SecretStr(fresh_token), + user_id=provider_token.user_id, + host=provider_token.host, + ) + tokens_refreshed += 1 + else: + # Keep original token if we couldn't get a fresh one + updated_tokens[provider_type] = provider_token + + except Exception as e: + # If refresh fails, use original token to prevent conversation startup failure + logger.warning( + f'Failed to refresh {provider_type.value} token: {e}', + extra={'session_id': sid, 'provider': provider_type.value}, + exc_info=True, + ) + updated_tokens[provider_type] = provider_token + tokens_failed += 1 + + # Create new ConversationInitData with updated tokens + # We cannot modify the frozen field directly, so we create a new object + updated_settings = settings.model_copy( + update={'git_provider_tokens': MappingProxyType(updated_tokens)} + ) + + logger.info( + 'Updated provider tokens after runtime creation', + extra={ + 'session_id': sid, + 'providers': [p.value for p in updated_tokens.keys()], + 'refreshed': tokens_refreshed, + 'failed': tokens_failed, + }, + ) + return updated_settings + async def _start_agent_loop( self, sid, settings, user_id, initial_user_msg=None, replay_json=None ): @@ -249,6 +351,11 @@ class SaasNestedConversationManager(ConversationManager): session_api_key = runtime.session.headers['X-Session-API-Key'] + # Update provider tokens with fresh ones after runtime creation + settings = await self._refresh_provider_tokens_after_runtime_init( + settings, sid, user_id + ) + await self._start_conversation( sid, user_id, @@ -333,7 +440,12 @@ class SaasNestedConversationManager(ConversationManager): async def _setup_provider_tokens( self, client: httpx.AsyncClient, api_url: str, settings: Settings ): - """Setup provider tokens for the nested conversation.""" + """Setup provider tokens for the nested conversation. + + Note: Token validation happens in the nested runtime. If tokens are revoked, + the nested runtime will return 401. The caller should handle token refresh + and retry if needed. + """ provider_handler = self._get_provider_handler(settings) provider_tokens = provider_handler.provider_tokens if provider_tokens: diff --git a/enterprise/tests/unit/test_saas_nested_conversation_manager_token_refresh.py b/enterprise/tests/unit/test_saas_nested_conversation_manager_token_refresh.py new file mode 100644 index 0000000000..55780bd0a5 --- /dev/null +++ b/enterprise/tests/unit/test_saas_nested_conversation_manager_token_refresh.py @@ -0,0 +1,437 @@ +""" +TDD Tests for SaasNestedConversationManager token refresh functionality. + +This module tests the token refresh logic that prevents stale tokens from being +sent to nested runtimes after Runtime.__init__() refreshes them. + +Test Coverage: +- Token refresh with IDP user ID (GitLab webhook flow) +- Token refresh with Keycloak user ID (Web UI flow) +- Error handling and fallback behavior +- Settings immutability handling +""" + +from types import MappingProxyType +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from pydantic import SecretStr + +from enterprise.server.saas_nested_conversation_manager import ( + SaasNestedConversationManager, +) +from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.server.session.conversation_init_data import ConversationInitData +from openhands.storage.data_models.settings import Settings + + +class TestRefreshProviderTokensAfterRuntimeInit: + """Test suite for _refresh_provider_tokens_after_runtime_init method.""" + + @pytest.fixture + def conversation_manager(self): + """Create a minimal SaasNestedConversationManager instance for testing.""" + # Arrange: Create mock dependencies + mock_sio = Mock() + mock_config = Mock() + mock_config.max_concurrent_conversations = 5 + mock_server_config = Mock() + mock_file_store = Mock() + + # Create manager instance + manager = SaasNestedConversationManager( + sio=mock_sio, + config=mock_config, + server_config=mock_server_config, + file_store=mock_file_store, + event_retrieval=Mock(), + ) + return manager + + @pytest.fixture + def gitlab_provider_token_with_user_id(self): + """Create a GitLab ProviderToken with IDP user ID (webhook flow).""" + return ProviderToken( + token=SecretStr('old_token_abc123'), + user_id='32546706', # GitLab user ID + host=None, + ) + + @pytest.fixture + def gitlab_provider_token_without_user_id(self): + """Create a GitLab ProviderToken without IDP user ID (web UI flow).""" + return ProviderToken( + token=SecretStr('old_token_xyz789'), + user_id=None, + host=None, + ) + + @pytest.fixture + def conversation_init_data_with_user_id(self, gitlab_provider_token_with_user_id): + """Create ConversationInitData with provider token containing user_id.""" + return ConversationInitData( + git_provider_tokens=MappingProxyType( + {ProviderType.GITLAB: gitlab_provider_token_with_user_id} + ) + ) + + @pytest.fixture + def conversation_init_data_without_user_id( + self, gitlab_provider_token_without_user_id + ): + """Create ConversationInitData with provider token without user_id.""" + return ConversationInitData( + git_provider_tokens=MappingProxyType( + {ProviderType.GITLAB: gitlab_provider_token_without_user_id} + ) + ) + + @pytest.mark.asyncio + async def test_returns_original_settings_when_not_conversation_init_data( + self, conversation_manager + ): + """ + Test: Returns original settings when not ConversationInitData. + + Arrange: Create a Settings object (not ConversationInitData) + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Returns the same settings object unchanged + """ + # Arrange + settings = Settings() + sid = 'test_session_123' + + # Act + result = await conversation_manager._refresh_provider_tokens_after_runtime_init( + settings, sid + ) + + # Assert + assert result is settings + + @pytest.mark.asyncio + async def test_returns_original_settings_when_no_provider_tokens( + self, conversation_manager + ): + """ + Test: Returns original settings when no provider tokens present. + + Arrange: Create ConversationInitData without git_provider_tokens + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Returns the same settings object unchanged + """ + # Arrange + settings = ConversationInitData(git_provider_tokens=None) + sid = 'test_session_456' + + # Act + result = await conversation_manager._refresh_provider_tokens_after_runtime_init( + settings, sid + ) + + # Assert + assert result is settings + + @pytest.mark.asyncio + async def test_refreshes_token_with_idp_user_id( + self, conversation_manager, conversation_init_data_with_user_id + ): + """ + Test: Refreshes token using IDP user ID (GitLab webhook flow). + + Arrange: ConversationInitData with GitLab token containing user_id + Act: Call _refresh_provider_tokens_after_runtime_init with mocked TokenManager + Assert: Token is refreshed using get_idp_token_from_idp_user_id + """ + # Arrange + sid = 'test_session_789' + fresh_token = 'fresh_token_def456' + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + return_value=fresh_token + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + conversation_init_data_with_user_id, sid + ) + ) + + # Assert + mock_token_manager.get_idp_token_from_idp_user_id.assert_called_once_with( + '32546706', ProviderType.GITLAB + ) + assert ( + result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value() + == fresh_token + ) + assert result.git_provider_tokens[ProviderType.GITLAB].user_id == '32546706' + + @pytest.mark.asyncio + async def test_refreshes_token_with_keycloak_user_id( + self, conversation_manager, conversation_init_data_without_user_id + ): + """ + Test: Refreshes token using Keycloak user ID (Web UI flow). + + Arrange: ConversationInitData without IDP user_id, but with Keycloak user_id + Act: Call _refresh_provider_tokens_after_runtime_init with mocked TokenManager + Assert: Token is refreshed using load_offline_token + get_idp_token_from_offline_token + """ + # Arrange + sid = 'test_session_101' + keycloak_user_id = 'keycloak_user_abc' + offline_token = 'offline_token_xyz' + fresh_token = 'fresh_token_ghi789' + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.load_offline_token = AsyncMock( + return_value=offline_token + ) + mock_token_manager.get_idp_token_from_offline_token = AsyncMock( + return_value=fresh_token + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + conversation_init_data_without_user_id, sid, keycloak_user_id + ) + ) + + # Assert + mock_token_manager.load_offline_token.assert_called_once_with( + keycloak_user_id + ) + mock_token_manager.get_idp_token_from_offline_token.assert_called_once_with( + offline_token, ProviderType.GITLAB + ) + assert ( + result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value() + == fresh_token + ) + assert result.git_provider_tokens[ProviderType.GITLAB].user_id is None + + @pytest.mark.asyncio + async def test_keeps_original_token_when_refresh_fails( + self, conversation_manager, conversation_init_data_with_user_id + ): + """ + Test: Keeps original token when refresh fails (error handling). + + Arrange: ConversationInitData with token, TokenManager raises exception + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Original token is preserved, no exception raised + """ + # Arrange + sid = 'test_session_error' + original_token = conversation_init_data_with_user_id.git_provider_tokens[ + ProviderType.GITLAB + ].token.get_secret_value() + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + side_effect=Exception('Token refresh failed') + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + conversation_init_data_with_user_id, sid + ) + ) + + # Assert + assert ( + result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value() + == original_token + ) + + @pytest.mark.asyncio + async def test_keeps_original_token_when_no_fresh_token_available( + self, conversation_manager, conversation_init_data_with_user_id + ): + """ + Test: Keeps original token when no fresh token is available. + + Arrange: ConversationInitData with token, TokenManager returns None + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Original token is preserved + """ + # Arrange + sid = 'test_session_no_fresh' + original_token = conversation_init_data_with_user_id.git_provider_tokens[ + ProviderType.GITLAB + ].token.get_secret_value() + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + return_value=None + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + conversation_init_data_with_user_id, sid + ) + ) + + # Assert + assert ( + result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value() + == original_token + ) + + @pytest.mark.asyncio + async def test_creates_new_settings_object_preserving_immutability( + self, conversation_manager, conversation_init_data_with_user_id + ): + """ + Test: Creates new settings object (respects Pydantic frozen fields). + + Arrange: ConversationInitData with frozen git_provider_tokens field + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Returns a new ConversationInitData object, not the same instance + """ + # Arrange + sid = 'test_session_immutable' + fresh_token = 'fresh_token_new' + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + return_value=fresh_token + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + conversation_init_data_with_user_id, sid + ) + ) + + # Assert + assert result is not conversation_init_data_with_user_id + assert isinstance(result, ConversationInitData) + + @pytest.mark.asyncio + async def test_handles_multiple_providers(self, conversation_manager): + """ + Test: Handles multiple provider tokens correctly. + + Arrange: ConversationInitData with both GitLab and GitHub tokens + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Both tokens are refreshed independently + """ + # Arrange + sid = 'test_session_multi' + gitlab_token = ProviderToken( + token=SecretStr('old_gitlab_token'), user_id='gitlab_user_123', host=None + ) + github_token = ProviderToken( + token=SecretStr('old_github_token'), user_id='github_user_456', host=None + ) + settings = ConversationInitData( + git_provider_tokens=MappingProxyType( + {ProviderType.GITLAB: gitlab_token, ProviderType.GITHUB: github_token} + ) + ) + + fresh_gitlab_token = 'fresh_gitlab_token' + fresh_github_token = 'fresh_github_token' + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + + async def mock_get_token(user_id, provider_type): + if provider_type == ProviderType.GITLAB: + return fresh_gitlab_token + elif provider_type == ProviderType.GITHUB: + return fresh_github_token + return None + + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + side_effect=mock_get_token + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + settings, sid + ) + ) + + # Assert + assert ( + result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value() + == fresh_gitlab_token + ) + assert ( + result.git_provider_tokens[ProviderType.GITHUB].token.get_secret_value() + == fresh_github_token + ) + assert mock_token_manager.get_idp_token_from_idp_user_id.call_count == 2 + + @pytest.mark.asyncio + async def test_preserves_token_host_field(self, conversation_manager): + """ + Test: Preserves the host field from original token. + + Arrange: ProviderToken with custom host value + Act: Call _refresh_provider_tokens_after_runtime_init + Assert: Host field is preserved in the refreshed token + """ + # Arrange + sid = 'test_session_host' + custom_host = 'gitlab.example.com' + token_with_host = ProviderToken( + token=SecretStr('old_token'), user_id='user_789', host=custom_host + ) + settings = ConversationInitData( + git_provider_tokens=MappingProxyType({ProviderType.GITLAB: token_with_host}) + ) + + fresh_token = 'fresh_token_with_host' + + with patch( + 'enterprise.server.saas_nested_conversation_manager.TokenManager' + ) as mock_token_manager_class: + mock_token_manager = AsyncMock() + mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock( + return_value=fresh_token + ) + mock_token_manager_class.return_value = mock_token_manager + + # Act + result = ( + await conversation_manager._refresh_provider_tokens_after_runtime_init( + settings, sid + ) + ) + + # Assert + assert result.git_provider_tokens[ProviderType.GITLAB].host == custom_host