mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
APP-443: Fix key generation (#12726)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
@@ -24,13 +24,23 @@ from storage.user_settings import UserSettings
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# Timeout in seconds for key verification requests to LiteLLM
|
||||
KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for OpenHands Cloud managed keys."""
|
||||
return f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}'
|
||||
|
||||
|
||||
def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for BYOR (Bring Your Own Runtime) keys."""
|
||||
return f'BYOR Key - user {keycloak_user_id}, org {org_id}'
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@@ -79,7 +89,7 @@ class LiteLlmManager:
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -251,7 +261,7 @@ class LiteLlmManager:
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
if new_key:
|
||||
@@ -1044,7 +1054,7 @@ class LiteLlmManager:
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
timeout=KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
@@ -1058,7 +1068,7 @@ class LiteLlmManager:
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
'Key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
@@ -1066,7 +1076,7 @@ class LiteLlmManager:
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
'Key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
@@ -1079,7 +1089,7 @@ class LiteLlmManager:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
'Key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
@@ -1123,6 +1133,103 @@ class LiteLlmManager:
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _get_all_keys_for_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
) -> list[dict]:
|
||||
"""Get all keys for a user from LiteLLM.
|
||||
|
||||
Returns a list of key info dictionaries containing:
|
||||
- token: the key value (hashed or partial)
|
||||
- key_alias: the alias for the key
|
||||
- key_name: the name of the key
|
||||
- spend: the amount spent on this key
|
||||
- max_budget: the max budget for this key
|
||||
- team_id: the team the key belongs to
|
||||
- metadata: any metadata associated with the key
|
||||
|
||||
Returns an empty list if no keys found or on error.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={keycloak_user_id}',
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_json = response.json()
|
||||
# The user/info endpoint returns keys in the 'keys' field
|
||||
return user_json.get('keys', [])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'LiteLlmManager:_get_all_keys_for_user:error',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _verify_existing_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_value: str,
|
||||
keycloak_user_id: str,
|
||||
org_id: str,
|
||||
openhands_type: bool = False,
|
||||
) -> bool:
|
||||
"""Check if an existing key exists for the user/org in LiteLLM.
|
||||
|
||||
Verifies the provided key_value matches a key registered in LiteLLM for
|
||||
the given user and organization. For openhands_type=True, looks for keys
|
||||
with metadata type='openhands' and matching team_id. For openhands_type=False,
|
||||
looks for keys with matching alias and team_id.
|
||||
|
||||
Returns True if the key is found and valid, False otherwise.
|
||||
"""
|
||||
found = False
|
||||
keys = await LiteLlmManager._get_all_keys_for_user(client, keycloak_user_id)
|
||||
for key_info in keys:
|
||||
metadata = key_info.get('metadata') or {}
|
||||
team_id = key_info.get('team_id')
|
||||
key_alias = key_info.get('key_alias')
|
||||
token = None
|
||||
if (
|
||||
openhands_type
|
||||
and metadata.get('type') == 'openhands'
|
||||
and team_id == org_id
|
||||
):
|
||||
# Found an existing OpenHands key for this org
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
if (
|
||||
not openhands_type
|
||||
and team_id == org_id
|
||||
and (
|
||||
key_alias == get_openhands_cloud_key_alias(keycloak_user_id, org_id)
|
||||
or key_alias == get_byor_key_alias(keycloak_user_id, org_id)
|
||||
)
|
||||
):
|
||||
# Found an existing key for this org (regardless of type)
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
|
||||
return found
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key_by_alias(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -1220,6 +1327,8 @@ class LiteLlmManager:
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
verify_existing_key = staticmethod(with_http_client(_verify_existing_key))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
get_user_keys = staticmethod(with_http_client(_get_user_keys))
|
||||
delete_key_by_alias = staticmethod(with_http_client(_delete_key_by_alias))
|
||||
update_user_keys = staticmethod(with_http_client(_update_user_keys))
|
||||
|
||||
@@ -5,7 +5,8 @@ Store class for managing organization-member relationships.
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
@@ -38,7 +39,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
@@ -48,7 +49,18 @@ class OrgMemberStore:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
async def get_org_member_async(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter(
|
||||
OrgMember.org_id == org_id, OrgMember.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
@@ -68,7 +80,7 @@ class OrgMemberStore:
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
@@ -90,7 +102,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
|
||||
@@ -8,10 +8,11 @@ from dataclasses import dataclass
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.constants import LITE_LLM_API_URL
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
@@ -143,23 +144,34 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
org: Org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
llm_base_url = (
|
||||
org_member.llm_base_url
|
||||
if org_member.llm_base_url
|
||||
else org.default_llm_base_url
|
||||
)
|
||||
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_api_key(item, str(org_id), openhands_type=True)
|
||||
elif llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(item, str(org_id))
|
||||
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
@@ -227,48 +239,49 @@ class SaasSettingsStore(SettingsStore):
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
async def _ensure_api_key(
|
||||
self, item: Settings, org_id: str, openhands_type: bool = False
|
||||
) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key exists for the user and verifies it
|
||||
is valid in LiteLLM. If valid, reuses it. Otherwise, generates a new key.
|
||||
"""
|
||||
# Check if user already has keys in LiteLLM
|
||||
existing_keys = await LiteLlmManager.get_user_keys(self.user_id)
|
||||
if existing_keys:
|
||||
# Verify the first key is actually valid in LiteLLM before reusing
|
||||
# This handles cases where keys exist in our DB but were orphaned in LiteLLM
|
||||
key_to_reuse = existing_keys[0]
|
||||
if await LiteLlmManager.verify_key(key_to_reuse, self.user_id):
|
||||
item.llm_api_key = SecretStr(key_to_reuse)
|
||||
logger.info(
|
||||
'saas_settings_store:store:reusing_verified_key',
|
||||
extra={'user_id': self.user_id, 'key_count': len(existing_keys)},
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:existing_key_invalid',
|
||||
extra={'user_id': self.user_id, 'key_count': len(existing_keys)},
|
||||
)
|
||||
# Fall through to generate a new key
|
||||
|
||||
# Generate new key if none exists or existing keys are invalid
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
# First, check if our current key is valid
|
||||
if item.llm_api_key and not await LiteLlmManager.verify_existing_key(
|
||||
item.llm_api_key.get_secret_value(),
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
openhands_type=openhands_type,
|
||||
):
|
||||
generated_key = None
|
||||
if openhands_type:
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
else:
|
||||
# Must delete any existing key with the same alias first
|
||||
key_alias = get_openhands_cloud_key_alias(self.user_id, org_id)
|
||||
await LiteLlmManager.delete_key_by_alias(key_alias=key_alias)
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
key_alias,
|
||||
None,
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
@@ -11,7 +11,11 @@ from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.lite_llm_manager import (
|
||||
LiteLlmManager,
|
||||
get_byor_key_alias,
|
||||
get_openhands_cloud_key_alias,
|
||||
)
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
@@ -1547,3 +1551,321 @@ class TestLiteLlmManager:
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
|
||||
class TestGetAllKeysForUser:
|
||||
"""Test cases for _get_all_keys_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_missing_config(self):
|
||||
"""Test _get_all_keys_for_user when LiteLLM config is missing."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
mock_client.get.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_success(self):
|
||||
"""Test _get_all_keys_for_user returns keys on success."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'keys': [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'test-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
{
|
||||
'key_name': 'sk-test5678',
|
||||
'key_alias': 'another-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': None,
|
||||
},
|
||||
]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]['key_name'] == 'sk-test1234'
|
||||
assert result[1]['key_name'] == 'sk-test5678'
|
||||
|
||||
# Verify API key header is included
|
||||
mock_client.get.assert_called_once()
|
||||
call_kwargs = mock_client.get.call_args
|
||||
assert call_kwargs.kwargs['headers'] == {
|
||||
'x-goog-api-key': 'test-api-key'
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_empty_response(self):
|
||||
"""Test _get_all_keys_for_user returns empty list when user has no keys."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {'keys': []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_api_error(self):
|
||||
"""Test _get_all_keys_for_user handles API errors gracefully."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.side_effect = Exception('API Error')
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestVerifyExistingKey:
|
||||
"""Test cases for _verify_existing_key method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_openhands_type_found(self):
|
||||
"""Test _verify_existing_key finds matching OpenHands key."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Key ending with '1234' should match 'sk-test1234'
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-1234',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_openhands_type_not_found(self):
|
||||
"""Test _verify_existing_key returns False when key doesn't match."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Key ending with '5678' should NOT match 'sk-test1234'
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-5678',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_by_alias_openhands_cloud(self):
|
||||
"""Test _verify_existing_key finds key by OpenHands Cloud alias."""
|
||||
user_id = 'test-user-id'
|
||||
org_id = 'test-org'
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-testABCD',
|
||||
'key_alias': get_openhands_cloud_key_alias(user_id, org_id),
|
||||
'team_id': org_id,
|
||||
'metadata': None,
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-ABCD',
|
||||
user_id,
|
||||
org_id,
|
||||
openhands_type=False,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_by_alias_byor(self):
|
||||
"""Test _verify_existing_key finds key by BYOR alias."""
|
||||
user_id = 'test-user-id'
|
||||
org_id = 'test-org'
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-testXYZW',
|
||||
'key_alias': get_byor_key_alias(user_id, org_id),
|
||||
'team_id': org_id,
|
||||
'metadata': None,
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-XYZW',
|
||||
user_id,
|
||||
org_id,
|
||||
openhands_type=False,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_wrong_team(self):
|
||||
"""Test _verify_existing_key returns False for wrong team_id."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'different-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-1234',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_no_keys(self):
|
||||
"""Test _verify_existing_key returns False when user has no keys."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = []
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_handles_none_key_name(self):
|
||||
"""Test _verify_existing_key handles None key_name gracefully."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': None,
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Should not raise TypeError, should return False
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_handles_empty_key_name(self):
|
||||
"""Test _verify_existing_key handles empty key_name gracefully."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': '',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Should not raise error, should return False
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@@ -180,48 +180,40 @@ async def test_encryption(settings_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_openhands_api_key_sets_key_when_reusing_verified_key(mock_config):
|
||||
"""The old code returned early without setting item.llm_api_key."""
|
||||
async def test_ensure_api_key_keeps_valid_key(mock_config):
|
||||
"""When the existing key is valid, it should be kept unchanged."""
|
||||
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
|
||||
existing_key = 'sk-existing-key'
|
||||
item = DataSettings(llm_model='openhands/gpt-4')
|
||||
item = DataSettings(
|
||||
llm_model='openhands/gpt-4', llm_api_key=SecretStr(existing_key)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.saas_settings_store.LiteLlmManager.get_user_keys',
|
||||
new_callable=AsyncMock,
|
||||
return_value=[existing_key],
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.LiteLlmManager.verify_key',
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
),
|
||||
with patch(
|
||||
'storage.saas_settings_store.LiteLlmManager.verify_existing_key',
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
await store._ensure_openhands_api_key(item, 'org-123')
|
||||
await store._ensure_api_key(item, 'org-123', openhands_type=True)
|
||||
|
||||
# This assertion failed with the old code
|
||||
# Key should remain unchanged when it's valid
|
||||
assert item.llm_api_key is not None
|
||||
assert item.llm_api_key.get_secret_value() == existing_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_openhands_api_key_generates_new_key_when_verification_fails(
|
||||
async def test_ensure_api_key_generates_new_key_when_verification_fails(
|
||||
mock_config,
|
||||
):
|
||||
"""Handles orphaned keys that exist in our DB but not in LiteLLM."""
|
||||
"""When verification fails, a new key should be generated."""
|
||||
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
|
||||
new_key = 'sk-new-key'
|
||||
item = DataSettings(llm_model='openhands/gpt-4')
|
||||
item = DataSettings(
|
||||
llm_model='openhands/gpt-4', llm_api_key=SecretStr('sk-invalid-key')
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.saas_settings_store.LiteLlmManager.get_user_keys',
|
||||
new_callable=AsyncMock,
|
||||
return_value=['sk-orphaned-key'],
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.LiteLlmManager.verify_key',
|
||||
'storage.saas_settings_store.LiteLlmManager.verify_existing_key',
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
@@ -231,7 +223,7 @@ async def test_ensure_openhands_api_key_generates_new_key_when_verification_fail
|
||||
return_value=new_key,
|
||||
),
|
||||
):
|
||||
await store._ensure_openhands_api_key(item, 'org-123')
|
||||
await store._ensure_api_key(item, 'org-123', openhands_type=True)
|
||||
|
||||
assert item.llm_api_key is not None
|
||||
assert item.llm_api_key.get_secret_value() == new_key
|
||||
|
||||
Reference in New Issue
Block a user