From 0e825c38d70946af0c82a1ea532680b7e7c4bca0 Mon Sep 17 00:00:00 2001 From: chuckbutkus Date: Tue, 3 Feb 2026 12:50:40 -0500 Subject: [PATCH] APP-443: Fix key generation (#12726) Co-authored-by: openhands Co-authored-by: tofarr --- enterprise/storage/lite_llm_manager.py | 125 ++++++- enterprise/storage/org_member_store.py | 22 +- enterprise/storage/saas_settings_store.py | 97 +++--- .../tests/unit/test_lite_llm_manager.py | 324 +++++++++++++++++- .../tests/unit/test_saas_settings_store.py | 44 +-- openhands/server/routes/settings.py | 6 +- .../routes/test_settings_store_functions.py | 21 +- 7 files changed, 537 insertions(+), 102 deletions(-) diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index 7517c3633f..e4380e29f3 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -24,13 +24,23 @@ from storage.user_settings import UserSettings from openhands.server.settings import Settings from openhands.utils.http_session import httpx_verify_option -# Timeout in seconds for BYOR key verification requests to LiteLLM -BYOR_KEY_VERIFICATION_TIMEOUT = 5.0 +# Timeout in seconds for key verification requests to LiteLLM +KEY_VERIFICATION_TIMEOUT = 5.0 # A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug. UNLIMITED_BUDGET_SETTING = 1000000000.0 +def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str: + """Generate the key alias for OpenHands Cloud managed keys.""" + return f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}' + + +def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str: + """Generate the key alias for BYOR (Bring Your Own Runtime) keys.""" + return f'BYOR Key - user {keycloak_user_id}, org {org_id}' + + class LiteLlmManager: """Manage LiteLLM interactions.""" @@ -79,7 +89,7 @@ class LiteLlmManager: client, keycloak_user_id, org_id, - f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}', + get_openhands_cloud_key_alias(keycloak_user_id, org_id), None, ) @@ -251,7 +261,7 @@ class LiteLlmManager: client, keycloak_user_id, org_id, - f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}', + get_openhands_cloud_key_alias(keycloak_user_id, org_id), None, ) if new_key: @@ -1044,7 +1054,7 @@ class LiteLlmManager: try: async with httpx.AsyncClient( verify=httpx_verify_option(), - timeout=BYOR_KEY_VERIFICATION_TIMEOUT, + timeout=KEY_VERIFICATION_TIMEOUT, ) as client: # Make a lightweight request to verify the key # Using /v1/models endpoint as it's lightweight and requires authentication @@ -1058,7 +1068,7 @@ class LiteLlmManager: # Only 200 status code indicates valid key if response.status_code == 200: logger.debug( - 'BYOR key verification successful', + 'Key verification successful', extra={'user_id': user_id}, ) return True @@ -1066,7 +1076,7 @@ class LiteLlmManager: # All other status codes (401, 403, 500, etc.) are treated as invalid # This includes authentication errors and server errors logger.warning( - 'BYOR key verification failed - treating as invalid', + 'Key verification failed - treating as invalid', extra={ 'user_id': user_id, 'status_code': response.status_code, @@ -1079,7 +1089,7 @@ class LiteLlmManager: # Any exception (timeout, network error, etc.) means we can't verify # Return False to trigger regeneration rather than returning potentially invalid key logger.warning( - 'BYOR key verification error - treating as invalid to ensure key validity', + 'Key verification error - treating as invalid to ensure key validity', extra={ 'user_id': user_id, 'error': str(e), @@ -1123,6 +1133,103 @@ class LiteLlmManager: 'key_spend': key_info.get('spend'), } + @staticmethod + async def _get_all_keys_for_user( + client: httpx.AsyncClient, + keycloak_user_id: str, + ) -> list[dict]: + """Get all keys for a user from LiteLLM. + + Returns a list of key info dictionaries containing: + - token: the key value (hashed or partial) + - key_alias: the alias for the key + - key_name: the name of the key + - spend: the amount spent on this key + - max_budget: the max budget for this key + - team_id: the team the key belongs to + - metadata: any metadata associated with the key + + Returns an empty list if no keys found or on error. + """ + if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: + logger.warning('LiteLLM API configuration not found') + return [] + + try: + response = await client.get( + f'{LITE_LLM_API_URL}/user/info?user_id={keycloak_user_id}', + headers={'x-goog-api-key': LITE_LLM_API_KEY}, + ) + response.raise_for_status() + user_json = response.json() + # The user/info endpoint returns keys in the 'keys' field + return user_json.get('keys', []) + except Exception as e: + logger.warning( + 'LiteLlmManager:_get_all_keys_for_user:error', + extra={ + 'user_id': keycloak_user_id, + 'error': str(e), + }, + ) + return [] + + @staticmethod + async def _verify_existing_key( + client: httpx.AsyncClient, + key_value: str, + keycloak_user_id: str, + org_id: str, + openhands_type: bool = False, + ) -> bool: + """Check if an existing key exists for the user/org in LiteLLM. + + Verifies the provided key_value matches a key registered in LiteLLM for + the given user and organization. For openhands_type=True, looks for keys + with metadata type='openhands' and matching team_id. For openhands_type=False, + looks for keys with matching alias and team_id. + + Returns True if the key is found and valid, False otherwise. + """ + found = False + keys = await LiteLlmManager._get_all_keys_for_user(client, keycloak_user_id) + for key_info in keys: + metadata = key_info.get('metadata') or {} + team_id = key_info.get('team_id') + key_alias = key_info.get('key_alias') + token = None + if ( + openhands_type + and metadata.get('type') == 'openhands' + and team_id == org_id + ): + # Found an existing OpenHands key for this org + key_name = key_info.get('key_name') + token = key_name[-4:] if key_name else None # last 4 digits of key + if token and key_value.endswith( + token + ): # check if this is our current key + found = True + break + if ( + not openhands_type + and team_id == org_id + and ( + key_alias == get_openhands_cloud_key_alias(keycloak_user_id, org_id) + or key_alias == get_byor_key_alias(keycloak_user_id, org_id) + ) + ): + # Found an existing key for this org (regardless of type) + key_name = key_info.get('key_name') + token = key_name[-4:] if key_name else None # last 4 digits of key + if token and key_value.endswith( + token + ): # check if this is our current key + found = True + break + + return found + @staticmethod async def _delete_key_by_alias( client: httpx.AsyncClient, @@ -1220,6 +1327,8 @@ class LiteLlmManager: update_user_in_team = staticmethod(with_http_client(_update_user_in_team)) generate_key = staticmethod(with_http_client(_generate_key)) get_key_info = staticmethod(with_http_client(_get_key_info)) + verify_existing_key = staticmethod(with_http_client(_verify_existing_key)) delete_key = staticmethod(with_http_client(_delete_key)) get_user_keys = staticmethod(with_http_client(_get_user_keys)) + delete_key_by_alias = staticmethod(with_http_client(_delete_key_by_alias)) update_user_keys = staticmethod(with_http_client(_update_user_keys)) diff --git a/enterprise/storage/org_member_store.py b/enterprise/storage/org_member_store.py index 39841180ed..9ff4485d2f 100644 --- a/enterprise/storage/org_member_store.py +++ b/enterprise/storage/org_member_store.py @@ -5,7 +5,8 @@ Store class for managing organization-member relationships. from typing import Optional from uuid import UUID -from storage.database import session_maker +from sqlalchemy import select +from storage.database import a_session_maker, session_maker from storage.org_member import OrgMember from storage.user_settings import UserSettings @@ -38,7 +39,7 @@ class OrgMemberStore: return org_member @staticmethod - def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]: + def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]: """Get organization-user relationship.""" with session_maker() as session: return ( @@ -48,7 +49,18 @@ class OrgMemberStore: ) @staticmethod - def get_user_orgs(user_id: int) -> list[OrgMember]: + async def get_org_member_async(org_id: UUID, user_id: UUID) -> Optional[OrgMember]: + """Get organization-user relationship.""" + async with a_session_maker() as session: + result = await session.execute( + select(OrgMember).filter( + OrgMember.org_id == org_id, OrgMember.user_id == user_id + ) + ) + return result.scalars().first() + + @staticmethod + def get_user_orgs(user_id: UUID) -> list[OrgMember]: """Get all organizations for a user.""" with session_maker() as session: return session.query(OrgMember).filter(OrgMember.user_id == user_id).all() @@ -68,7 +80,7 @@ class OrgMemberStore: @staticmethod def update_user_role_in_org( - org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None + org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None ) -> Optional[OrgMember]: """Update user's role in an organization.""" with session_maker() as session: @@ -90,7 +102,7 @@ class OrgMemberStore: return org_member @staticmethod - def remove_user_from_org(org_id: UUID, user_id: int) -> bool: + def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool: """Remove a user from an organization.""" with session_maker() as session: org_member = ( diff --git a/enterprise/storage/saas_settings_store.py b/enterprise/storage/saas_settings_store.py index 0af2cc6330..78925a9f72 100644 --- a/enterprise/storage/saas_settings_store.py +++ b/enterprise/storage/saas_settings_store.py @@ -8,10 +8,11 @@ from dataclasses import dataclass from cryptography.fernet import Fernet from pydantic import SecretStr +from server.constants import LITE_LLM_API_URL from server.logger import logger from sqlalchemy.orm import joinedload, sessionmaker from storage.database import session_maker -from storage.lite_llm_manager import LiteLlmManager +from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias from storage.org import Org from storage.org_member import OrgMember from storage.org_store import OrgStore @@ -143,23 +144,34 @@ class SaasSettingsStore(SettingsStore): return None org_id = user.current_org_id - # Check if provider is OpenHands and generate API key if needed - if self._is_openhands_provider(item): - await self._ensure_openhands_api_key(item, str(org_id)) - org_member = None + + org_member: OrgMember = None for om in user.org_members: if om.org_id == org_id: org_member = om break if not org_member or not org_member.llm_api_key: return None - org = session.query(Org).filter(Org.id == org_id).first() + + org: Org = session.query(Org).filter(Org.id == org_id).first() if not org: logger.error( f'Org not found for ID {org_id} as the current org for user {self.user_id}' ) return None + llm_base_url = ( + org_member.llm_base_url + if org_member.llm_base_url + else org.default_llm_base_url + ) + + # Check if provider is OpenHands and generate API key if needed + if self._is_openhands_provider(item): + await self._ensure_api_key(item, str(org_id), openhands_type=True) + elif llm_base_url == LITE_LLM_API_URL: + await self._ensure_api_key(item, str(org_id)) + kwargs = item.model_dump(context={'expose_secrets': True}) for model in (user, org, org_member): for key, value in kwargs.items(): @@ -227,48 +239,49 @@ class SaasSettingsStore(SettingsStore): """Check if the settings use the OpenHands provider.""" return bool(item.llm_model and item.llm_model.startswith('openhands/')) - async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None: + async def _ensure_api_key( + self, item: Settings, org_id: str, openhands_type: bool = False + ) -> None: """Generate and set the OpenHands API key for the given settings. First checks if an existing key exists for the user and verifies it is valid in LiteLLM. If valid, reuses it. Otherwise, generates a new key. """ - # Check if user already has keys in LiteLLM - existing_keys = await LiteLlmManager.get_user_keys(self.user_id) - if existing_keys: - # Verify the first key is actually valid in LiteLLM before reusing - # This handles cases where keys exist in our DB but were orphaned in LiteLLM - key_to_reuse = existing_keys[0] - if await LiteLlmManager.verify_key(key_to_reuse, self.user_id): - item.llm_api_key = SecretStr(key_to_reuse) - logger.info( - 'saas_settings_store:store:reusing_verified_key', - extra={'user_id': self.user_id, 'key_count': len(existing_keys)}, - ) - return - else: - logger.warning( - 'saas_settings_store:store:existing_key_invalid', - extra={'user_id': self.user_id, 'key_count': len(existing_keys)}, - ) - # Fall through to generate a new key - # Generate new key if none exists or existing keys are invalid - generated_key = await LiteLlmManager.generate_key( + # First, check if our current key is valid + if item.llm_api_key and not await LiteLlmManager.verify_existing_key( + item.llm_api_key.get_secret_value(), self.user_id, org_id, - None, - {'type': 'openhands'}, - ) + openhands_type=openhands_type, + ): + generated_key = None + if openhands_type: + generated_key = await LiteLlmManager.generate_key( + self.user_id, + org_id, + None, + {'type': 'openhands'}, + ) + else: + # Must delete any existing key with the same alias first + key_alias = get_openhands_cloud_key_alias(self.user_id, org_id) + await LiteLlmManager.delete_key_by_alias(key_alias=key_alias) + generated_key = await LiteLlmManager.generate_key( + self.user_id, + org_id, + key_alias, + None, + ) - if generated_key: - item.llm_api_key = SecretStr(generated_key) - logger.info( - 'saas_settings_store:store:generated_openhands_key', - extra={'user_id': self.user_id}, - ) - else: - logger.warning( - 'saas_settings_store:store:failed_to_generate_openhands_key', - extra={'user_id': self.user_id}, - ) + if generated_key: + item.llm_api_key = SecretStr(generated_key) + logger.info( + 'saas_settings_store:store:generated_openhands_key', + extra={'user_id': self.user_id}, + ) + else: + logger.warning( + 'saas_settings_store:store:failed_to_generate_openhands_key', + extra={'user_id': self.user_id}, + ) diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index 7ad9986242..c89a89d6ba 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -11,7 +11,11 @@ from pydantic import SecretStr from server.constants import ( get_default_litellm_model, ) -from storage.lite_llm_manager import LiteLlmManager +from storage.lite_llm_manager import ( + LiteLlmManager, + get_byor_key_alias, + get_openhands_cloud_key_alias, +) from storage.user_settings import UserSettings from openhands.server.settings import Settings @@ -1547,3 +1551,321 @@ class TestLiteLlmManager: # making any LiteLLM calls assert result is not None assert result.agent == 'TestAgent' + + +class TestGetAllKeysForUser: + """Test cases for _get_all_keys_for_user method.""" + + @pytest.mark.asyncio + async def test_get_all_keys_missing_config(self): + """Test _get_all_keys_for_user when LiteLLM config is missing.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None): + result = await LiteLlmManager._get_all_keys_for_user( + mock_client, 'test-user-id' + ) + assert result == [] + mock_client.get.assert_not_called() + + @pytest.mark.asyncio + async def test_get_all_keys_success(self): + """Test _get_all_keys_for_user returns keys on success.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'keys': [ + { + 'key_name': 'sk-test1234', + 'key_alias': 'test-alias', + 'team_id': 'test-org', + 'metadata': {'type': 'openhands'}, + }, + { + 'key_name': 'sk-test5678', + 'key_alias': 'another-alias', + 'team_id': 'test-org', + 'metadata': None, + }, + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.get.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + result = await LiteLlmManager._get_all_keys_for_user( + mock_client, 'test-user-id' + ) + + assert len(result) == 2 + assert result[0]['key_name'] == 'sk-test1234' + assert result[1]['key_name'] == 'sk-test5678' + + # Verify API key header is included + mock_client.get.assert_called_once() + call_kwargs = mock_client.get.call_args + assert call_kwargs.kwargs['headers'] == { + 'x-goog-api-key': 'test-api-key' + } + + @pytest.mark.asyncio + async def test_get_all_keys_empty_response(self): + """Test _get_all_keys_for_user returns empty list when user has no keys.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'keys': []} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.get.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + result = await LiteLlmManager._get_all_keys_for_user( + mock_client, 'test-user-id' + ) + assert result == [] + + @pytest.mark.asyncio + async def test_get_all_keys_api_error(self): + """Test _get_all_keys_for_user handles API errors gracefully.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.get.side_effect = Exception('API Error') + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + result = await LiteLlmManager._get_all_keys_for_user( + mock_client, 'test-user-id' + ) + assert result == [] + + +class TestVerifyExistingKey: + """Test cases for _verify_existing_key method.""" + + @pytest.mark.asyncio + async def test_verify_existing_key_openhands_type_found(self): + """Test _verify_existing_key finds matching OpenHands key.""" + mock_keys = [ + { + 'key_name': 'sk-test1234', + 'key_alias': 'some-alias', + 'team_id': 'test-org', + 'metadata': {'type': 'openhands'}, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + # Key ending with '1234' should match 'sk-test1234' + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'my-key-ending-with-1234', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is True + + @pytest.mark.asyncio + async def test_verify_existing_key_openhands_type_not_found(self): + """Test _verify_existing_key returns False when key doesn't match.""" + mock_keys = [ + { + 'key_name': 'sk-test1234', + 'key_alias': 'some-alias', + 'team_id': 'test-org', + 'metadata': {'type': 'openhands'}, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + # Key ending with '5678' should NOT match 'sk-test1234' + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'my-key-ending-with-5678', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is False + + @pytest.mark.asyncio + async def test_verify_existing_key_by_alias_openhands_cloud(self): + """Test _verify_existing_key finds key by OpenHands Cloud alias.""" + user_id = 'test-user-id' + org_id = 'test-org' + mock_keys = [ + { + 'key_name': 'sk-testABCD', + 'key_alias': get_openhands_cloud_key_alias(user_id, org_id), + 'team_id': org_id, + 'metadata': None, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'my-key-ending-with-ABCD', + user_id, + org_id, + openhands_type=False, + ) + assert result is True + + @pytest.mark.asyncio + async def test_verify_existing_key_by_alias_byor(self): + """Test _verify_existing_key finds key by BYOR alias.""" + user_id = 'test-user-id' + org_id = 'test-org' + mock_keys = [ + { + 'key_name': 'sk-testXYZW', + 'key_alias': get_byor_key_alias(user_id, org_id), + 'team_id': org_id, + 'metadata': None, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'my-key-ending-with-XYZW', + user_id, + org_id, + openhands_type=False, + ) + assert result is True + + @pytest.mark.asyncio + async def test_verify_existing_key_wrong_team(self): + """Test _verify_existing_key returns False for wrong team_id.""" + mock_keys = [ + { + 'key_name': 'sk-test1234', + 'key_alias': 'some-alias', + 'team_id': 'different-org', + 'metadata': {'type': 'openhands'}, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'my-key-ending-with-1234', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is False + + @pytest.mark.asyncio + async def test_verify_existing_key_no_keys(self): + """Test _verify_existing_key returns False when user has no keys.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = [] + + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'some-key-value', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is False + + @pytest.mark.asyncio + async def test_verify_existing_key_handles_none_key_name(self): + """Test _verify_existing_key handles None key_name gracefully.""" + mock_keys = [ + { + 'key_name': None, + 'key_alias': 'some-alias', + 'team_id': 'test-org', + 'metadata': {'type': 'openhands'}, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + # Should not raise TypeError, should return False + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'some-key-value', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is False + + @pytest.mark.asyncio + async def test_verify_existing_key_handles_empty_key_name(self): + """Test _verify_existing_key handles empty key_name gracefully.""" + mock_keys = [ + { + 'key_name': '', + 'key_alias': 'some-alias', + 'team_id': 'test-org', + 'metadata': {'type': 'openhands'}, + } + ] + + mock_client = AsyncMock(spec=httpx.AsyncClient) + + with patch.object( + LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock + ) as mock_get_keys: + mock_get_keys.return_value = mock_keys + + # Should not raise error, should return False + result = await LiteLlmManager._verify_existing_key( + mock_client, + 'some-key-value', + 'test-user-id', + 'test-org', + openhands_type=True, + ) + assert result is False diff --git a/enterprise/tests/unit/test_saas_settings_store.py b/enterprise/tests/unit/test_saas_settings_store.py index a3b175e30b..344c260a83 100644 --- a/enterprise/tests/unit/test_saas_settings_store.py +++ b/enterprise/tests/unit/test_saas_settings_store.py @@ -180,48 +180,40 @@ async def test_encryption(settings_store): @pytest.mark.asyncio -async def test_ensure_openhands_api_key_sets_key_when_reusing_verified_key(mock_config): - """The old code returned early without setting item.llm_api_key.""" +async def test_ensure_api_key_keeps_valid_key(mock_config): + """When the existing key is valid, it should be kept unchanged.""" store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config) existing_key = 'sk-existing-key' - item = DataSettings(llm_model='openhands/gpt-4') + item = DataSettings( + llm_model='openhands/gpt-4', llm_api_key=SecretStr(existing_key) + ) - with ( - patch( - 'storage.saas_settings_store.LiteLlmManager.get_user_keys', - new_callable=AsyncMock, - return_value=[existing_key], - ), - patch( - 'storage.saas_settings_store.LiteLlmManager.verify_key', - new_callable=AsyncMock, - return_value=True, - ), + with patch( + 'storage.saas_settings_store.LiteLlmManager.verify_existing_key', + new_callable=AsyncMock, + return_value=True, ): - await store._ensure_openhands_api_key(item, 'org-123') + await store._ensure_api_key(item, 'org-123', openhands_type=True) - # This assertion failed with the old code + # Key should remain unchanged when it's valid assert item.llm_api_key is not None assert item.llm_api_key.get_secret_value() == existing_key @pytest.mark.asyncio -async def test_ensure_openhands_api_key_generates_new_key_when_verification_fails( +async def test_ensure_api_key_generates_new_key_when_verification_fails( mock_config, ): - """Handles orphaned keys that exist in our DB but not in LiteLLM.""" + """When verification fails, a new key should be generated.""" store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config) new_key = 'sk-new-key' - item = DataSettings(llm_model='openhands/gpt-4') + item = DataSettings( + llm_model='openhands/gpt-4', llm_api_key=SecretStr('sk-invalid-key') + ) with ( patch( - 'storage.saas_settings_store.LiteLlmManager.get_user_keys', - new_callable=AsyncMock, - return_value=['sk-orphaned-key'], - ), - patch( - 'storage.saas_settings_store.LiteLlmManager.verify_key', + 'storage.saas_settings_store.LiteLlmManager.verify_existing_key', new_callable=AsyncMock, return_value=False, ), @@ -231,7 +223,7 @@ async def test_ensure_openhands_api_key_generates_new_key_when_verification_fail return_value=new_key, ), ): - await store._ensure_openhands_api_key(item, 'org-123') + await store._ensure_api_key(item, 'org-123', openhands_type=True) assert item.llm_api_key is not None assert item.llm_api_key.get_secret_value() == new_key diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 7db5a54c4b..6063752dea 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -114,10 +114,8 @@ async def reset_settings() -> JSONResponse: async def store_llm_settings( - settings: Settings, settings_store: SettingsStore + settings: Settings, existing_settings: Settings ) -> Settings: - existing_settings = await settings_store.load() - # Convert to Settings model and merge with existing settings if existing_settings: # Keep existing LLM settings if not provided @@ -156,7 +154,7 @@ async def store_settings( # Convert to Settings model and merge with existing settings if existing_settings: - settings = await store_llm_settings(settings, settings_store) + settings = await store_llm_settings(settings, existing_settings) # Keep existing analytics consent if not provided if settings.user_consents_to_analytics is None: diff --git a/tests/unit/server/routes/test_settings_store_functions.py b/tests/unit/server/routes/test_settings_store_functions.py index c6eb6f5628..25c23a7daf 100644 --- a/tests/unit/server/routes/test_settings_store_functions.py +++ b/tests/unit/server/routes/test_settings_store_functions.py @@ -149,11 +149,10 @@ async def test_store_llm_settings_new_settings(): llm_base_url='https://api.example.com', ) - # Mock the settings store - mock_store = MagicMock() - mock_store.load = AsyncMock(return_value=None) # No existing settings + # No existing settings + existing_settings = None - result = await store_llm_settings(settings, mock_store) + result = await store_llm_settings(settings, existing_settings) # Should return settings with the provided values assert result.llm_model == 'gpt-4' @@ -170,9 +169,6 @@ async def test_store_llm_settings_update_existing(): llm_base_url='https://new.example.com', ) - # Mock the settings store - mock_store = MagicMock() - # Create existing settings existing_settings = Settings( llm_model='gpt-3.5', @@ -180,9 +176,7 @@ async def test_store_llm_settings_update_existing(): llm_base_url='https://old.example.com', ) - mock_store.load = AsyncMock(return_value=existing_settings) - - result = await store_llm_settings(settings, mock_store) + result = await store_llm_settings(settings, existing_settings) # Should return settings with the updated values assert result.llm_model == 'gpt-4' @@ -197,9 +191,6 @@ async def test_store_llm_settings_partial_update(): llm_model='gpt-4' # Only updating model ) - # Mock the settings store - mock_store = MagicMock() - # Create existing settings existing_settings = Settings( llm_model='gpt-3.5', @@ -207,9 +198,7 @@ async def test_store_llm_settings_partial_update(): llm_base_url='https://existing.example.com', ) - mock_store.load = AsyncMock(return_value=existing_settings) - - result = await store_llm_settings(settings, mock_store) + result = await store_llm_settings(settings, existing_settings) # Should return settings with updated model but keep other values assert result.llm_model == 'gpt-4'