Fix UserSettings creation from Org tables (#12635)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
chuckbutkus
2026-01-28 11:35:05 -05:00
committed by GitHub
parent 00a74731ae
commit 570ab904f6
15 changed files with 402 additions and 178 deletions

View File

@@ -216,9 +216,9 @@ class SaasUserAuth(UserAuth):
async def get_mcp_api_key(self) -> str:
api_key_store = ApiKeyStore.get_instance()
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
mcp_api_key = await api_key_store.retrieve_mcp_api_key(self.user_id)
if not mcp_api_key:
mcp_api_key = api_key_store.create_api_key(
mcp_api_key = await api_key_store.create_api_key(
self.user_id, 'MCP_API_KEY', None
)
return mcp_api_key

View File

@@ -22,7 +22,7 @@ from openhands.core.logger import openhands_logger as logger
# NOTE: these details are specific to the MCP protocol
class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
@staticmethod
def create_default_mcp_server_config(
async def create_default_mcp_server_config(
host: str, config: 'OpenHandsConfig', user_id: str | None = None
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
"""
@@ -38,10 +38,12 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
api_key_store = ApiKeyStore.get_instance()
if user_id:
api_key = api_key_store.retrieve_mcp_api_key(user_id)
api_key = await api_key_store.retrieve_mcp_api_key(user_id)
if not api_key:
api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None)
api_key = await api_key_store.create_api_key(
user_id, 'MCP_API_KEY', None
)
if not api_key:
logger.error(f'Could not provision MCP API Key for user: {user_id}')

View File

@@ -10,53 +10,44 @@ from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
def _get_byor_key():
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
if current_org_member.llm_api_key_for_byor:
return current_org_member.llm_api_key_for_byor.get_secret_value()
user = await UserStore.get_user_by_id_async(user_id)
if not user:
return None
return await call_sync_from_async(_get_byor_key)
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
if current_org_member.llm_api_key_for_byor:
return current_org_member.llm_api_key_for_byor.get_secret_value()
return None
async def store_byor_key_in_db(user_id: str, key: str) -> None:
"""Store the BYOR key in the database for a user."""
user = await UserStore.get_user_by_id_async(user_id)
if not user:
return None
def _update_user_settings():
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
current_org_member.llm_api_key_for_byor = key
OrgMemberStore.update_org_member(current_org_member)
await call_sync_from_async(_update_user_settings)
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
current_org_member.llm_api_key_for_byor = key
OrgMemberStore.update_org_member(current_org_member)
async def generate_byor_key(user_id: str) -> str | None:
@@ -161,11 +152,11 @@ class LlmApiKeyResponse(BaseModel):
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
"""Create a new API key for the authenticated user."""
try:
api_key = api_key_store.create_api_key(
api_key = await api_key_store.create_api_key(
user_id, key_data.name, key_data.expires_at
)
# Get the created key details
keys = api_key_store.list_api_keys(user_id)
keys = await api_key_store.list_api_keys(user_id)
for key in keys:
if key['name'] == key_data.name:
return {
@@ -193,7 +184,7 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user
async def list_api_keys(user_id: str = Depends(get_user_id)):
"""List all API keys for the authenticated user."""
try:
keys = api_key_store.list_api_keys(user_id)
keys = await api_key_store.list_api_keys(user_id)
return [
{
**key,
@@ -222,7 +213,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
"""Delete an API key."""
try:
# First, verify the key belongs to the user
keys = api_key_store.list_api_keys(user_id)
keys = await api_key_store.list_api_keys(user_id)
key_to_delete = None
for key in keys:

View File

@@ -272,7 +272,7 @@ async def device_verification_authenticated(
try:
# Create a unique API key for this device using user_code in the name
device_key_name = f'{API_KEY_NAME} ({user_code})'
api_key_store.create_api_key(
await api_key_store.create_api_key(
user_id,
name=device_key_name,
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,

View File

@@ -516,11 +516,13 @@ class SaasNestedConversationManager(ConversationManager):
)
raise
def _get_mcp_config(self, user_id: str) -> MCPConfig | None:
async def _get_mcp_config(self, user_id: str) -> MCPConfig | None:
api_key_store = ApiKeyStore.get_instance()
mcp_api_key = api_key_store.retrieve_mcp_api_key(user_id)
mcp_api_key = await api_key_store.retrieve_mcp_api_key(user_id)
if not mcp_api_key:
mcp_api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None)
mcp_api_key = await api_key_store.create_api_key(
user_id, 'MCP_API_KEY', None
)
if not mcp_api_key:
return None
web_host = os.environ.get('WEB_HOST', 'app.all-hands.dev')
@@ -547,7 +549,7 @@ class SaasNestedConversationManager(ConversationManager):
'conversation_id': sid,
}
mcp_config = self._get_mcp_config(user_id)
mcp_config = await self._get_mcp_config(user_id)
if mcp_config:
# Merge with any MCP config from settings
if settings.mcp_config:

View File

@@ -12,6 +12,7 @@ from storage.database import session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import call_sync_from_async
@dataclass
@@ -26,7 +27,7 @@ class ApiKeyStore:
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{self.API_KEY_PREFIX}{random_part}'
def create_api_key(
async def create_api_key(
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
) -> str:
"""Create a new API key for a user.
@@ -40,8 +41,23 @@ class ApiKeyStore:
The generated API key
"""
api_key = self.generate_api_key()
user = UserStore.get_user_by_id(user_id)
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
await call_sync_from_async(
self._store_api_key, user_id, org_id, api_key, name, expires_at
)
return api_key
def _store_api_key(
self,
user_id: str,
org_id: str,
api_key: str,
name: str | None,
expires_at: datetime | None = None,
) -> None:
"""Store an existing API key in the database."""
with self.session_maker() as session:
key_record = ApiKey(
key=api_key,
@@ -53,8 +69,6 @@ class ApiKeyStore:
session.add(key_record)
session.commit()
return api_key
def validate_api_key(self, api_key: str) -> str | None:
"""Validate an API key and return the associated user_id if valid."""
now = datetime.now(UTC)
@@ -112,10 +126,13 @@ class ApiKeyStore:
return True
def list_api_keys(self, user_id: str) -> list[dict]:
async def list_api_keys(self, user_id: str) -> list[dict]:
"""List all API keys for a user."""
user = UserStore.get_user_by_id(user_id)
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys = (
session.query(ApiKey)
@@ -136,9 +153,14 @@ class ApiKeyStore:
if 'MCP_API_KEY' != key.name
]
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = UserStore.get_user_by_id(user_id)
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(
self._retrieve_mcp_api_key_from_db, user_id, org_id
)
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)

View File

@@ -98,6 +98,29 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
return get_fernet().decrypt(b64decode(value.encode())).decode()
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
if value is None:
continue
if key in encrypt_keys:
value = encrypt_legacy_value(value)
kwargs[key] = value
return kwargs
def encrypt_legacy_value(value: str | SecretStr) -> str:
if isinstance(value, SecretStr):
return b64encode(
get_fernet().encrypt(value.get_secret_value().encode())
).decode()
else:
return b64encode(get_fernet().encrypt(value.encode())).decode()
def get_fernet():
global _fernet
if _fernet is None:

View File

@@ -34,11 +34,10 @@ class SaasConversationStore(ConversationStore):
session_maker: sessionmaker
org_id: UUID | None = None # will be fetched automatically
def __init__(self, user_id: str, session_maker: sessionmaker):
def __init__(self, user_id: str, org_id: UUID, session_maker: sessionmaker):
self.user_id = user_id
self.org_id = org_id
self.session_maker = session_maker
user = UserStore.get_user_by_id(user_id)
self.org_id = user.current_org_id if user else None
def _select_by_id(self, session, conversation_id: str):
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
@@ -235,4 +234,6 @@ class SaasConversationStore(ConversationStore):
cls, config: OpenHandsConfig, user_id: str | None
) -> ConversationStore:
# user_id should not be None in SaaS, should we raise?
return SaasConversationStore(str(user_id), session_maker)
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id if user else None
return SaasConversationStore(str(user_id), org_id, session_maker)

View File

@@ -17,7 +17,11 @@ from server.logger import logger
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.encrypt_utils import (
decrypt_legacy_model,
decrypt_legacy_value,
encrypt_legacy_value,
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.role_store import RoleStore
@@ -127,6 +131,25 @@ class UserStore:
)
return bool(lock_acquired)
@staticmethod
async def _release_user_creation_lock(user_id: str) -> bool:
"""Release the distributed lock for user creation.
Returns True if the lock was released or if Redis is unavailable.
Returns False if the lock could not be released.
"""
redis_client = UserStore._get_redis_client()
if redis_client is None:
logger.warning(
'user_store:_release_user_creation_lock:no_redis_client',
extra={'user_id': user_id},
)
return True # Nothing to release if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
deleted = await redis_client.delete(user_key)
return bool(deleted)
@staticmethod
async def migrate_user(
user_id: str,
@@ -238,7 +261,6 @@ class UserStore:
if not custom_settings:
del org_member_kwargs['llm_model']
del org_member_kwargs['llm_base_url']
del org_member_kwargs['llm_api_key_for_byor']
org_member = OrgMember(
org_id=org.id,
@@ -423,20 +445,10 @@ class UserStore:
org_member = org_members[0]
is_new_signup = True
# Create a new user_settings entry from org_member data
# Create a new user_settings entry from OrgMember, User, and Org data
# This is needed for new sign-ups who don't have user_settings
user_settings = UserSettings(
keycloak_user_id=user_id,
llm_api_key=org_member.llm_api_key.get_secret_value()
if org_member.llm_api_key
else None,
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
if org_member.llm_api_key_for_byor
else None,
llm_model=org_member.llm_model,
llm_base_url=org_member.llm_base_url,
max_iterations=org_member.max_iterations,
already_migrated=False, # Will be set correctly below
user_settings = UserStore._create_user_settings_from_entities(
user_id, org_member, user, org
)
session.add(user_settings)
session.flush()
@@ -565,8 +577,21 @@ class UserStore:
{'org_id': user_uuid},
)
# Step 8: Set already_migrated=False on user_settings
# Step 8: Set already_migrated=False on user_settings and encrypt fields
user_settings.already_migrated = False
# Re-encrypt the sensitive fields before storing in the DB
encrypt_keys = [
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
]
for key in encrypt_keys:
value = getattr(user_settings, key, None)
if value is not None and not _is_legacy_value_encrypted(value):
setattr(user_settings, key, encrypt_legacy_value(value))
session.merge(user_settings)
session.commit()
@@ -608,41 +633,46 @@ class UserStore:
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
)
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
try:
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
.first()
)
if user_settings:
token_manager = TokenManager()
user_info = call_async_from_sync(
token_manager.get_user_info_from_user_id,
GENERAL_TIMEOUT,
user_id,
if user_settings:
token_manager = TokenManager()
user_info = call_async_from_sync(
token_manager.get_user_info_from_user_id,
GENERAL_TIMEOUT,
user_id,
)
user = call_async_from_sync(
UserStore.migrate_user,
GENERAL_TIMEOUT,
user_id,
user_settings,
user_info,
)
return user
else:
return None
finally:
call_async_from_sync(
UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id
)
user = call_async_from_sync(
UserStore.migrate_user,
GENERAL_TIMEOUT,
user_id,
user_settings,
user_info,
)
return user
else:
return None
@staticmethod
async def get_user_by_id_async(user_id: str) -> Optional[User]:
@@ -670,42 +700,45 @@ class UserStore:
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
# Check for user again as migration could have happened while trying to get the lock.
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if user:
return user
logger.info(
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
try:
# Check for user again as migration could have happened while trying to get the lock.
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
)
user_settings = result.scalars().first()
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
user = result.scalars().first()
if user:
return user
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
user_info,
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
)
return user
else:
return None
user_settings = result.scalars().first()
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
user_info,
)
return user
else:
return None
finally:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
def list_users() -> list[User]:
@@ -767,6 +800,96 @@ class UserStore:
}
return kwargs
@staticmethod
def _create_user_settings_from_entities(
user_id: str, org_member: OrgMember, user: User, org: Org
) -> UserSettings:
"""Create UserSettings from OrgMember, User, and Org data.
Uses OrgMember values first. If an OrgMember field is None and there's
a corresponding "default_" field in Org, use the Org value.
Also pulls relevant fields from User.
Args:
user_id: The Keycloak user ID
org_member: The OrgMember entity
user: The User entity
org: The Org entity
Returns:
A new UserSettings object populated from the entities
"""
# Mapping from OrgMember fields to corresponding Org "default_" fields
org_member_to_org_default = {
'llm_model': 'default_llm_model',
'llm_base_url': 'default_llm_base_url',
'max_iterations': 'default_max_iterations',
}
def get_value_with_org_fallback(field_name: str, org_member_value):
"""Get value from OrgMember, falling back to Org default if None."""
if org_member_value is not None:
return org_member_value
org_default_field = org_member_to_org_default.get(field_name)
if org_default_field and hasattr(org, org_default_field):
return getattr(org, org_default_field)
return None
# Get values from OrgMember with Org fallback for fields with default_ prefix
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
llm_base_url = get_value_with_org_fallback(
'llm_base_url', org_member.llm_base_url
)
max_iterations = get_value_with_org_fallback(
'max_iterations', org_member.max_iterations
)
return UserSettings(
keycloak_user_id=user_id,
# OrgMember fields
llm_api_key=org_member.llm_api_key.get_secret_value()
if org_member.llm_api_key
else None,
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
if org_member.llm_api_key_for_byor
else None,
llm_model=llm_model,
llm_base_url=llm_base_url,
max_iterations=max_iterations,
# User fields
accepted_tos=user.accepted_tos,
enable_sound_notifications=user.enable_sound_notifications,
language=user.language,
user_consents_to_analytics=user.user_consents_to_analytics,
email=user.email,
email_verified=user.email_verified,
git_user_name=user.git_user_name,
git_user_email=user.git_user_email,
# Org fields
agent=org.agent,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
user_version=org.org_version,
mcp_config=org.mcp_config,
search_api_key=org.search_api_key.get_secret_value()
if org.search_api_key
else None,
sandbox_api_key=org.sandbox_api_key.get_secret_value()
if org.sandbox_api_key
else None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
condenser_max_size=org.condenser_max_size,
already_migrated=False,
)
@staticmethod
def _has_custom_settings(
user_settings: UserSettings, old_user_version: int | None
@@ -812,3 +935,12 @@ class UserStore:
return False # Matches old default
return True # Custom model
def _is_legacy_value_encrypted(value: str) -> bool:
"""Check if a legacy value is encrypted by trying to decrypt it"""
try:
decrypt_legacy_value(value)
return True
except Exception:
return False

View File

@@ -1,7 +1,7 @@
"""Unit tests for OAuth2 Device Flow endpoints."""
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request
@@ -22,8 +22,10 @@ def mock_device_code_store():
@pytest.fixture
def mock_api_key_store():
"""Mock API key store."""
return MagicMock()
"""Mock API key store with async create_api_key."""
mock = MagicMock()
mock.create_api_key = AsyncMock()
return mock
@pytest.fixture
@@ -204,8 +206,9 @@ class TestDeviceVerificationAuthenticated:
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True
# Mock API key store
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key = AsyncMock()
mock_api_key_class.get_instance.return_value = mock_api_key_store
result = await device_verification_authenticated(
@@ -228,8 +231,9 @@ class TestDeviceVerificationAuthenticated:
@patch('server.routes.oauth_device.device_code_store')
async def test_multiple_device_authentication(self, mock_store, mock_api_key_class):
"""Test that multiple devices can authenticate simultaneously."""
# Mock API key store
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key = AsyncMock()
mock_api_key_class.get_instance.return_value = mock_api_key_store
# Simulate two different devices
@@ -486,8 +490,9 @@ class TestDeviceVerificationTransactionIntegrity:
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = False # Authorization fails
# Mock API key store
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key = AsyncMock()
mock_api_key_class.get_instance.return_value = mock_api_key_store
# Should raise HTTPException due to authorization failure
@@ -518,9 +523,11 @@ class TestDeviceVerificationTransactionIntegrity:
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.deny_device_code.return_value = True # Cleanup succeeds
# Mock API key store to fail on creation
# Mock API key store to fail on creation (async)
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
mock_api_key_store.create_api_key = AsyncMock(
side_effect=Exception('Database error')
)
mock_api_key_class.get_instance.return_value = mock_api_key_store
# Should raise HTTPException due to API key creation failure
@@ -558,9 +565,11 @@ class TestDeviceVerificationTransactionIntegrity:
'Cleanup failed'
) # Cleanup fails
# Mock API key store to fail on creation
# Mock API key store to fail on creation (async)
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
mock_api_key_store.create_api_key = AsyncMock(
side_effect=Exception('Database error')
)
mock_api_key_class.get_instance.return_value = mock_api_key_store
# Should still raise HTTPException for the original API key creation failure
@@ -589,8 +598,9 @@ class TestDeviceVerificationTransactionIntegrity:
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
# Mock API key store
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
mock_api_key_store.create_api_key = AsyncMock()
mock_api_key_class.get_instance.return_value = mock_api_key_store
result = await device_verification_authenticated(

View File

@@ -32,6 +32,11 @@ def api_key_store(mock_session_maker):
return ApiKeyStore(mock_session_maker)
def run_sync(func, *args, **kwargs):
"""Helper to execute sync functions directly (mocks call_sync_from_async)."""
return func(*args, **kwargs)
def test_generate_api_key(api_key_store):
"""Test that generate_api_key returns a string with sk-oh- prefix and expected length."""
key = api_key_store.generate_api_key(length=32)
@@ -41,8 +46,12 @@ def test_generate_api_key(api_key_store):
assert len(key) == len('sk-oh-') + 32
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_create_api_key(
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test creating an API key."""
# Setup
user_id = 'test-user-123'
@@ -51,7 +60,7 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
# Execute
result = api_key_store.create_api_key(user_id, name)
result = await api_key_store.create_api_key(user_id, name)
# Verify
assert result == 'test-api-key'
@@ -219,8 +228,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
mock_session.commit.assert_called_once()
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_list_api_keys(
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test listing API keys for a user."""
# Setup
user_id = 'test-user-123'
@@ -247,7 +260,7 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
mock_filter_org.all.return_value = [mock_key1, mock_key2]
# Execute
result = api_key_store.list_api_keys(user_id)
result = await api_key_store.list_api_keys(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
@@ -265,8 +278,12 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
assert result[1]['expires_at'] is None
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_retrieve_mcp_api_key(
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key for a user."""
# Setup
user_id = 'test-user-123'
@@ -287,16 +304,18 @@ def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_u
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
result = await api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result == 'mcp-test-key'
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key_not_found(
mock_get_user, api_key_store, mock_session, mock_user
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_retrieve_mcp_api_key_not_found(
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key when none exists."""
# Setup
@@ -314,7 +333,7 @@ def test_retrieve_mcp_api_key_not_found(
mock_filter_org.all.return_value = [mock_other_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
result = await api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)

View File

@@ -37,7 +37,11 @@ def mock_user_store():
@pytest.mark.asyncio
async def test_save_and_get(session_maker):
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
store = SaasConversationStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
session_maker,
)
metadata = ConversationMetadata(
conversation_id='my-conversation-id',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
@@ -62,7 +66,11 @@ async def test_save_and_get(session_maker):
@pytest.mark.asyncio
async def test_search(session_maker):
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
store = SaasConversationStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
session_maker,
)
# Create test conversations with different timestamps
conversations = [
@@ -107,7 +115,11 @@ async def test_search(session_maker):
@pytest.mark.asyncio
async def test_delete_metadata(session_maker):
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
store = SaasConversationStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
session_maker,
)
metadata = ConversationMetadata(
conversation_id='to-delete',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
@@ -127,14 +139,22 @@ async def test_delete_metadata(session_maker):
@pytest.mark.asyncio
async def test_get_nonexistent_metadata(session_maker):
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
store = SaasConversationStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
session_maker,
)
with pytest.raises(FileNotFoundError):
await store.get_metadata('nonexistent-id')
@pytest.mark.asyncio
async def test_exists(session_maker):
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
store = SaasConversationStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
session_maker,
)
metadata = ConversationMetadata(
conversation_id='exists-test',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',

View File

@@ -360,7 +360,7 @@ class OpenHandsMCPConfig:
return None
@staticmethod
def create_default_mcp_server_config(
async def create_default_mcp_server_config(
host: str, config: 'OpenHandsConfig', user_id: str | None = None
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
"""Create a default MCP server configuration.

View File

@@ -153,10 +153,11 @@ async def run_controller(
# Add MCP tools to the agent
if agent.config.enable_mcp:
# Add OpenHands' MCP server by default
_, openhands_mcp_stdio_servers = (
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
config.mcp_host, config, None
)
(
_,
openhands_mcp_stdio_servers,
) = await OpenHandsMCPConfigImpl.create_default_mcp_server_config(
config.mcp_host, config, None
)
runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)

View File

@@ -202,10 +202,11 @@ class WebSession:
self.logger.debug(f'Merged custom MCP Config: {mcp_config}')
# Add OpenHands' MCP server by default
openhands_mcp_server, openhands_mcp_stdio_servers = (
OpenHandsMCPConfigImpl.create_default_mcp_server_config(
self.config.mcp_host, self.config, self.user_id
)
(
openhands_mcp_server,
openhands_mcp_stdio_servers,
) = await OpenHandsMCPConfigImpl.create_default_mcp_server_config(
self.config.mcp_host, self.config, self.user_id
)
if openhands_mcp_server: