diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 5cd6a1e2c4..a0172c9d69 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -216,9 +216,9 @@ class SaasUserAuth(UserAuth): async def get_mcp_api_key(self) -> str: api_key_store = ApiKeyStore.get_instance() - mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id) + mcp_api_key = await api_key_store.retrieve_mcp_api_key(self.user_id) if not mcp_api_key: - mcp_api_key = api_key_store.create_api_key( + mcp_api_key = await api_key_store.create_api_key( self.user_id, 'MCP_API_KEY', None ) return mcp_api_key diff --git a/enterprise/server/mcp/mcp_config.py b/enterprise/server/mcp/mcp_config.py index a4d60b8f7e..8231ba51a6 100644 --- a/enterprise/server/mcp/mcp_config.py +++ b/enterprise/server/mcp/mcp_config.py @@ -22,7 +22,7 @@ from openhands.core.logger import openhands_logger as logger # NOTE: these details are specific to the MCP protocol class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig): @staticmethod - def create_default_mcp_server_config( + async def create_default_mcp_server_config( host: str, config: 'OpenHandsConfig', user_id: str | None = None ) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]: """ @@ -38,10 +38,12 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig): api_key_store = ApiKeyStore.get_instance() if user_id: - api_key = api_key_store.retrieve_mcp_api_key(user_id) + api_key = await api_key_store.retrieve_mcp_api_key(user_id) if not api_key: - api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None) + api_key = await api_key_store.create_api_key( + user_id, 'MCP_API_KEY', None + ) if not api_key: logger.error(f'Could not provision MCP API Key for user: {user_id}') diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index c07ae3fc51..049df8dd4f 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -10,53 +10,44 @@ from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id -from openhands.utils.async_utils import call_sync_from_async # Helper functions for BYOR API key management async def get_byor_key_from_db(user_id: str) -> str | None: """Get the BYOR key from the database for a user.""" - - def _get_byor_key(): - user = UserStore.get_user_by_id(user_id) - if not user: - return None - - current_org_id = user.current_org_id - current_org_member: OrgMember = None - for org_member in user.org_members: - if org_member.org_id == current_org_id: - current_org_member = org_member - break - if not current_org_member: - return None - if current_org_member.llm_api_key_for_byor: - return current_org_member.llm_api_key_for_byor.get_secret_value() + user = await UserStore.get_user_by_id_async(user_id) + if not user: return None - return await call_sync_from_async(_get_byor_key) + current_org_id = user.current_org_id + current_org_member: OrgMember = None + for org_member in user.org_members: + if org_member.org_id == current_org_id: + current_org_member = org_member + break + if not current_org_member: + return None + if current_org_member.llm_api_key_for_byor: + return current_org_member.llm_api_key_for_byor.get_secret_value() + return None async def store_byor_key_in_db(user_id: str, key: str) -> None: """Store the BYOR key in the database for a user.""" + user = await UserStore.get_user_by_id_async(user_id) + if not user: + return None - def _update_user_settings(): - user = UserStore.get_user_by_id(user_id) - if not user: - return None - - current_org_id = user.current_org_id - current_org_member: OrgMember = None - for org_member in user.org_members: - if org_member.org_id == current_org_id: - current_org_member = org_member - break - if not current_org_member: - return None - current_org_member.llm_api_key_for_byor = key - OrgMemberStore.update_org_member(current_org_member) - - await call_sync_from_async(_update_user_settings) + current_org_id = user.current_org_id + current_org_member: OrgMember = None + for org_member in user.org_members: + if org_member.org_id == current_org_id: + current_org_member = org_member + break + if not current_org_member: + return None + current_org_member.llm_api_key_for_byor = key + OrgMemberStore.update_org_member(current_org_member) async def generate_byor_key(user_id: str) -> str | None: @@ -161,11 +152,11 @@ class LlmApiKeyResponse(BaseModel): async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)): """Create a new API key for the authenticated user.""" try: - api_key = api_key_store.create_api_key( + api_key = await api_key_store.create_api_key( user_id, key_data.name, key_data.expires_at ) # Get the created key details - keys = api_key_store.list_api_keys(user_id) + keys = await api_key_store.list_api_keys(user_id) for key in keys: if key['name'] == key_data.name: return { @@ -193,7 +184,7 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user async def list_api_keys(user_id: str = Depends(get_user_id)): """List all API keys for the authenticated user.""" try: - keys = api_key_store.list_api_keys(user_id) + keys = await api_key_store.list_api_keys(user_id) return [ { **key, @@ -222,7 +213,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)): """Delete an API key.""" try: # First, verify the key belongs to the user - keys = api_key_store.list_api_keys(user_id) + keys = await api_key_store.list_api_keys(user_id) key_to_delete = None for key in keys: diff --git a/enterprise/server/routes/oauth_device.py b/enterprise/server/routes/oauth_device.py index 3e9425a012..25033a46f3 100644 --- a/enterprise/server/routes/oauth_device.py +++ b/enterprise/server/routes/oauth_device.py @@ -272,7 +272,7 @@ async def device_verification_authenticated( try: # Create a unique API key for this device using user_code in the name device_key_name = f'{API_KEY_NAME} ({user_code})' - api_key_store.create_api_key( + await api_key_store.create_api_key( user_id, name=device_key_name, expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME, diff --git a/enterprise/server/saas_nested_conversation_manager.py b/enterprise/server/saas_nested_conversation_manager.py index c642366304..e757570113 100644 --- a/enterprise/server/saas_nested_conversation_manager.py +++ b/enterprise/server/saas_nested_conversation_manager.py @@ -516,11 +516,13 @@ class SaasNestedConversationManager(ConversationManager): ) raise - def _get_mcp_config(self, user_id: str) -> MCPConfig | None: + async def _get_mcp_config(self, user_id: str) -> MCPConfig | None: api_key_store = ApiKeyStore.get_instance() - mcp_api_key = api_key_store.retrieve_mcp_api_key(user_id) + mcp_api_key = await api_key_store.retrieve_mcp_api_key(user_id) if not mcp_api_key: - mcp_api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None) + mcp_api_key = await api_key_store.create_api_key( + user_id, 'MCP_API_KEY', None + ) if not mcp_api_key: return None web_host = os.environ.get('WEB_HOST', 'app.all-hands.dev') @@ -547,7 +549,7 @@ class SaasNestedConversationManager(ConversationManager): 'conversation_id': sid, } - mcp_config = self._get_mcp_config(user_id) + mcp_config = await self._get_mcp_config(user_id) if mcp_config: # Merge with any MCP config from settings if settings.mcp_config: diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index e753d1324e..7e32ace2b6 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -12,6 +12,7 @@ from storage.database import session_maker from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger +from openhands.utils.async_utils import call_sync_from_async @dataclass @@ -26,7 +27,7 @@ class ApiKeyStore: random_part = ''.join(secrets.choice(alphabet) for _ in range(length)) return f'{self.API_KEY_PREFIX}{random_part}' - def create_api_key( + async def create_api_key( self, user_id: str, name: str | None = None, expires_at: datetime | None = None ) -> str: """Create a new API key for a user. @@ -40,8 +41,23 @@ class ApiKeyStore: The generated API key """ api_key = self.generate_api_key() - user = UserStore.get_user_by_id(user_id) + user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id + await call_sync_from_async( + self._store_api_key, user_id, org_id, api_key, name, expires_at + ) + + return api_key + + def _store_api_key( + self, + user_id: str, + org_id: str, + api_key: str, + name: str | None, + expires_at: datetime | None = None, + ) -> None: + """Store an existing API key in the database.""" with self.session_maker() as session: key_record = ApiKey( key=api_key, @@ -53,8 +69,6 @@ class ApiKeyStore: session.add(key_record) session.commit() - return api_key - def validate_api_key(self, api_key: str) -> str | None: """Validate an API key and return the associated user_id if valid.""" now = datetime.now(UTC) @@ -112,10 +126,13 @@ class ApiKeyStore: return True - def list_api_keys(self, user_id: str) -> list[dict]: + async def list_api_keys(self, user_id: str) -> list[dict]: """List all API keys for a user.""" - user = UserStore.get_user_by_id(user_id) + user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id + return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id) + + def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]: with self.session_maker() as session: keys = ( session.query(ApiKey) @@ -136,9 +153,14 @@ class ApiKeyStore: if 'MCP_API_KEY' != key.name ] - def retrieve_mcp_api_key(self, user_id: str) -> str | None: - user = UserStore.get_user_by_id(user_id) + async def retrieve_mcp_api_key(self, user_id: str) -> str | None: + user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id + return await call_sync_from_async( + self._retrieve_mcp_api_key_from_db, user_id, org_id + ) + + def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None: with self.session_maker() as session: keys: list[ApiKey] = ( session.query(ApiKey) diff --git a/enterprise/storage/encrypt_utils.py b/enterprise/storage/encrypt_utils.py index a8240e5d95..a612a2ca0a 100644 --- a/enterprise/storage/encrypt_utils.py +++ b/enterprise/storage/encrypt_utils.py @@ -98,6 +98,29 @@ def decrypt_legacy_value(value: str | SecretStr) -> str: return get_fernet().decrypt(b64decode(value.encode())).decode() +def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict: + return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance)) + + +def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict: + for key, value in kwargs.items(): + if value is None: + continue + if key in encrypt_keys: + value = encrypt_legacy_value(value) + kwargs[key] = value + return kwargs + + +def encrypt_legacy_value(value: str | SecretStr) -> str: + if isinstance(value, SecretStr): + return b64encode( + get_fernet().encrypt(value.get_secret_value().encode()) + ).decode() + else: + return b64encode(get_fernet().encrypt(value.encode())).decode() + + def get_fernet(): global _fernet if _fernet is None: diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index e74624314c..7e43dfe471 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -34,11 +34,10 @@ class SaasConversationStore(ConversationStore): session_maker: sessionmaker org_id: UUID | None = None # will be fetched automatically - def __init__(self, user_id: str, session_maker: sessionmaker): + def __init__(self, user_id: str, org_id: UUID, session_maker: sessionmaker): self.user_id = user_id + self.org_id = org_id self.session_maker = session_maker - user = UserStore.get_user_by_id(user_id) - self.org_id = user.current_org_id if user else None def _select_by_id(self, session, conversation_id: str): # Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org @@ -235,4 +234,6 @@ class SaasConversationStore(ConversationStore): cls, config: OpenHandsConfig, user_id: str | None ) -> ConversationStore: # user_id should not be None in SaaS, should we raise? - return SaasConversationStore(str(user_id), session_maker) + user = await UserStore.get_user_by_id_async(user_id) + org_id = user.current_org_id if user else None + return SaasConversationStore(str(user_id), org_id, session_maker) diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 7a42677a44..267107d660 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -17,7 +17,11 @@ from server.logger import logger from sqlalchemy import select, text from sqlalchemy.orm import joinedload from storage.database import a_session_maker, session_maker -from storage.encrypt_utils import decrypt_legacy_model +from storage.encrypt_utils import ( + decrypt_legacy_model, + decrypt_legacy_value, + encrypt_legacy_value, +) from storage.org import Org from storage.org_member import OrgMember from storage.role_store import RoleStore @@ -127,6 +131,25 @@ class UserStore: ) return bool(lock_acquired) + @staticmethod + async def _release_user_creation_lock(user_id: str) -> bool: + """Release the distributed lock for user creation. + + Returns True if the lock was released or if Redis is unavailable. + Returns False if the lock could not be released. + """ + redis_client = UserStore._get_redis_client() + if redis_client is None: + logger.warning( + 'user_store:_release_user_creation_lock:no_redis_client', + extra={'user_id': user_id}, + ) + return True # Nothing to release if Redis is unavailable + + user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}' + deleted = await redis_client.delete(user_key) + return bool(deleted) + @staticmethod async def migrate_user( user_id: str, @@ -238,7 +261,6 @@ class UserStore: if not custom_settings: del org_member_kwargs['llm_model'] del org_member_kwargs['llm_base_url'] - del org_member_kwargs['llm_api_key_for_byor'] org_member = OrgMember( org_id=org.id, @@ -423,20 +445,10 @@ class UserStore: org_member = org_members[0] is_new_signup = True - # Create a new user_settings entry from org_member data + # Create a new user_settings entry from OrgMember, User, and Org data # This is needed for new sign-ups who don't have user_settings - user_settings = UserSettings( - keycloak_user_id=user_id, - llm_api_key=org_member.llm_api_key.get_secret_value() - if org_member.llm_api_key - else None, - llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value() - if org_member.llm_api_key_for_byor - else None, - llm_model=org_member.llm_model, - llm_base_url=org_member.llm_base_url, - max_iterations=org_member.max_iterations, - already_migrated=False, # Will be set correctly below + user_settings = UserStore._create_user_settings_from_entities( + user_id, org_member, user, org ) session.add(user_settings) session.flush() @@ -565,8 +577,21 @@ class UserStore: {'org_id': user_uuid}, ) - # Step 8: Set already_migrated=False on user_settings + # Step 8: Set already_migrated=False on user_settings and encrypt fields user_settings.already_migrated = False + + # Re-encrypt the sensitive fields before storing in the DB + encrypt_keys = [ + 'llm_api_key', + 'llm_api_key_for_byor', + 'search_api_key', + 'sandbox_api_key', + ] + for key in encrypt_keys: + value = getattr(user_settings, key, None) + if value is not None and not _is_legacy_value_encrypted(value): + setattr(user_settings, key, encrypt_legacy_value(value)) + session.merge(user_settings) session.commit() @@ -608,41 +633,46 @@ class UserStore: asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS ) - # Check for user again as migration could have happened while trying to get the lock. - user = ( - session.query(User) - .options(joinedload(User.org_members)) - .filter(User.id == uuid.UUID(user_id)) - .first() - ) - if user: - return user + try: + # Check for user again as migration could have happened while trying to get the lock. + user = ( + session.query(User) + .options(joinedload(User.org_members)) + .filter(User.id == uuid.UUID(user_id)) + .first() + ) + if user: + return user - user_settings = ( - session.query(UserSettings) - .filter( - UserSettings.keycloak_user_id == user_id, - UserSettings.already_migrated.is_(False), + user_settings = ( + session.query(UserSettings) + .filter( + UserSettings.keycloak_user_id == user_id, + UserSettings.already_migrated.is_(False), + ) + .first() ) - .first() - ) - if user_settings: - token_manager = TokenManager() - user_info = call_async_from_sync( - token_manager.get_user_info_from_user_id, - GENERAL_TIMEOUT, - user_id, + if user_settings: + token_manager = TokenManager() + user_info = call_async_from_sync( + token_manager.get_user_info_from_user_id, + GENERAL_TIMEOUT, + user_id, + ) + user = call_async_from_sync( + UserStore.migrate_user, + GENERAL_TIMEOUT, + user_id, + user_settings, + user_info, + ) + return user + else: + return None + finally: + call_async_from_sync( + UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id ) - user = call_async_from_sync( - UserStore.migrate_user, - GENERAL_TIMEOUT, - user_id, - user_settings, - user_info, - ) - return user - else: - return None @staticmethod async def get_user_by_id_async(user_id: str) -> Optional[User]: @@ -670,42 +700,45 @@ class UserStore: ) await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS) - # Check for user again as migration could have happened while trying to get the lock. - result = await session.execute( - select(User) - .options(joinedload(User.org_members)) - .filter(User.id == uuid.UUID(user_id)) - ) - user = result.scalars().first() - if user: - return user - - logger.info( - 'user_store:get_user_by_id_async:start_migration', - extra={'user_id': user_id}, - ) - result = await session.execute( - select(UserSettings).filter( - UserSettings.keycloak_user_id == user_id, - UserSettings.already_migrated.is_(False), + try: + # Check for user again as migration could have happened while trying to get the lock. + result = await session.execute( + select(User) + .options(joinedload(User.org_members)) + .filter(User.id == uuid.UUID(user_id)) ) - ) - user_settings = result.scalars().first() - if user_settings: - token_manager = TokenManager() - user_info = await token_manager.get_user_info_from_user_id(user_id) + user = result.scalars().first() + if user: + return user + logger.info( - 'user_store:get_user_by_id_async:calling_migrate_user', + 'user_store:get_user_by_id_async:start_migration', extra={'user_id': user_id}, ) - user = await UserStore.migrate_user( - user_id, - user_settings, - user_info, + result = await session.execute( + select(UserSettings).filter( + UserSettings.keycloak_user_id == user_id, + UserSettings.already_migrated.is_(False), + ) ) - return user - else: - return None + user_settings = result.scalars().first() + if user_settings: + token_manager = TokenManager() + user_info = await token_manager.get_user_info_from_user_id(user_id) + logger.info( + 'user_store:get_user_by_id_async:calling_migrate_user', + extra={'user_id': user_id}, + ) + user = await UserStore.migrate_user( + user_id, + user_settings, + user_info, + ) + return user + else: + return None + finally: + await UserStore._release_user_creation_lock(user_id) @staticmethod def list_users() -> list[User]: @@ -767,6 +800,96 @@ class UserStore: } return kwargs + @staticmethod + def _create_user_settings_from_entities( + user_id: str, org_member: OrgMember, user: User, org: Org + ) -> UserSettings: + """Create UserSettings from OrgMember, User, and Org data. + + Uses OrgMember values first. If an OrgMember field is None and there's + a corresponding "default_" field in Org, use the Org value. + Also pulls relevant fields from User. + + Args: + user_id: The Keycloak user ID + org_member: The OrgMember entity + user: The User entity + org: The Org entity + + Returns: + A new UserSettings object populated from the entities + """ + # Mapping from OrgMember fields to corresponding Org "default_" fields + org_member_to_org_default = { + 'llm_model': 'default_llm_model', + 'llm_base_url': 'default_llm_base_url', + 'max_iterations': 'default_max_iterations', + } + + def get_value_with_org_fallback(field_name: str, org_member_value): + """Get value from OrgMember, falling back to Org default if None.""" + if org_member_value is not None: + return org_member_value + org_default_field = org_member_to_org_default.get(field_name) + if org_default_field and hasattr(org, org_default_field): + return getattr(org, org_default_field) + return None + + # Get values from OrgMember with Org fallback for fields with default_ prefix + llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model) + llm_base_url = get_value_with_org_fallback( + 'llm_base_url', org_member.llm_base_url + ) + max_iterations = get_value_with_org_fallback( + 'max_iterations', org_member.max_iterations + ) + + return UserSettings( + keycloak_user_id=user_id, + # OrgMember fields + llm_api_key=org_member.llm_api_key.get_secret_value() + if org_member.llm_api_key + else None, + llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value() + if org_member.llm_api_key_for_byor + else None, + llm_model=llm_model, + llm_base_url=llm_base_url, + max_iterations=max_iterations, + # User fields + accepted_tos=user.accepted_tos, + enable_sound_notifications=user.enable_sound_notifications, + language=user.language, + user_consents_to_analytics=user.user_consents_to_analytics, + email=user.email, + email_verified=user.email_verified, + git_user_name=user.git_user_name, + git_user_email=user.git_user_email, + # Org fields + agent=org.agent, + security_analyzer=org.security_analyzer, + confirmation_mode=org.confirmation_mode, + remote_runtime_resource_factor=org.remote_runtime_resource_factor, + enable_default_condenser=org.enable_default_condenser, + billing_margin=org.billing_margin, + enable_proactive_conversation_starters=org.enable_proactive_conversation_starters, + sandbox_base_container_image=org.sandbox_base_container_image, + sandbox_runtime_container_image=org.sandbox_runtime_container_image, + user_version=org.org_version, + mcp_config=org.mcp_config, + search_api_key=org.search_api_key.get_secret_value() + if org.search_api_key + else None, + sandbox_api_key=org.sandbox_api_key.get_secret_value() + if org.sandbox_api_key + else None, + max_budget_per_task=org.max_budget_per_task, + enable_solvability_analysis=org.enable_solvability_analysis, + v1_enabled=org.v1_enabled, + condenser_max_size=org.condenser_max_size, + already_migrated=False, + ) + @staticmethod def _has_custom_settings( user_settings: UserSettings, old_user_version: int | None @@ -812,3 +935,12 @@ class UserStore: return False # Matches old default return True # Custom model + + +def _is_legacy_value_encrypted(value: str) -> bool: + """Check if a legacy value is encrypted by trying to decrypt it""" + try: + decrypt_legacy_value(value) + return True + except Exception: + return False diff --git a/enterprise/tests/unit/server/routes/test_oauth_device.py b/enterprise/tests/unit/server/routes/test_oauth_device.py index 53682e65f0..7ee8a7282e 100644 --- a/enterprise/tests/unit/server/routes/test_oauth_device.py +++ b/enterprise/tests/unit/server/routes/test_oauth_device.py @@ -1,7 +1,7 @@ """Unit tests for OAuth2 Device Flow endpoints.""" from datetime import UTC, datetime, timedelta -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException, Request @@ -22,8 +22,10 @@ def mock_device_code_store(): @pytest.fixture def mock_api_key_store(): - """Mock API key store.""" - return MagicMock() + """Mock API key store with async create_api_key.""" + mock = MagicMock() + mock.create_api_key = AsyncMock() + return mock @pytest.fixture @@ -204,8 +206,9 @@ class TestDeviceVerificationAuthenticated: mock_store.get_by_user_code.return_value = mock_device mock_store.authorize_device_code.return_value = True - # Mock API key store + # Mock API key store with async create_api_key mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key = AsyncMock() mock_api_key_class.get_instance.return_value = mock_api_key_store result = await device_verification_authenticated( @@ -228,8 +231,9 @@ class TestDeviceVerificationAuthenticated: @patch('server.routes.oauth_device.device_code_store') async def test_multiple_device_authentication(self, mock_store, mock_api_key_class): """Test that multiple devices can authenticate simultaneously.""" - # Mock API key store + # Mock API key store with async create_api_key mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key = AsyncMock() mock_api_key_class.get_instance.return_value = mock_api_key_store # Simulate two different devices @@ -486,8 +490,9 @@ class TestDeviceVerificationTransactionIntegrity: mock_store.get_by_user_code.return_value = mock_device mock_store.authorize_device_code.return_value = False # Authorization fails - # Mock API key store + # Mock API key store with async create_api_key mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key = AsyncMock() mock_api_key_class.get_instance.return_value = mock_api_key_store # Should raise HTTPException due to authorization failure @@ -518,9 +523,11 @@ class TestDeviceVerificationTransactionIntegrity: mock_store.authorize_device_code.return_value = True # Authorization succeeds mock_store.deny_device_code.return_value = True # Cleanup succeeds - # Mock API key store to fail on creation + # Mock API key store to fail on creation (async) mock_api_key_store = MagicMock() - mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_store.create_api_key = AsyncMock( + side_effect=Exception('Database error') + ) mock_api_key_class.get_instance.return_value = mock_api_key_store # Should raise HTTPException due to API key creation failure @@ -558,9 +565,11 @@ class TestDeviceVerificationTransactionIntegrity: 'Cleanup failed' ) # Cleanup fails - # Mock API key store to fail on creation + # Mock API key store to fail on creation (async) mock_api_key_store = MagicMock() - mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_store.create_api_key = AsyncMock( + side_effect=Exception('Database error') + ) mock_api_key_class.get_instance.return_value = mock_api_key_store # Should still raise HTTPException for the original API key creation failure @@ -589,8 +598,9 @@ class TestDeviceVerificationTransactionIntegrity: mock_store.get_by_user_code.return_value = mock_device mock_store.authorize_device_code.return_value = True # Authorization succeeds - # Mock API key store + # Mock API key store with async create_api_key mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key = AsyncMock() mock_api_key_class.get_instance.return_value = mock_api_key_store result = await device_verification_authenticated( diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index 39d0288191..151312211f 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -32,6 +32,11 @@ def api_key_store(mock_session_maker): return ApiKeyStore(mock_session_maker) +def run_sync(func, *args, **kwargs): + """Helper to execute sync functions directly (mocks call_sync_from_async).""" + return func(*args, **kwargs) + + def test_generate_api_key(api_key_store): """Test that generate_api_key returns a string with sk-oh- prefix and expected length.""" key = api_key_store.generate_api_key(length=32) @@ -41,8 +46,12 @@ def test_generate_api_key(api_key_store): assert len(key) == len('sk-oh-') + 32 -@patch('storage.api_key_store.UserStore.get_user_by_id') -def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user): +@pytest.mark.asyncio +@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync) +@patch('storage.api_key_store.UserStore.get_user_by_id_async') +async def test_create_api_key( + mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user +): """Test creating an API key.""" # Setup user_id = 'test-user-123' @@ -51,7 +60,7 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user): api_key_store.generate_api_key = MagicMock(return_value='test-api-key') # Execute - result = api_key_store.create_api_key(user_id, name) + result = await api_key_store.create_api_key(user_id, name) # Verify assert result == 'test-api-key' @@ -219,8 +228,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session): mock_session.commit.assert_called_once() -@patch('storage.api_key_store.UserStore.get_user_by_id') -def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user): +@pytest.mark.asyncio +@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync) +@patch('storage.api_key_store.UserStore.get_user_by_id_async') +async def test_list_api_keys( + mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user +): """Test listing API keys for a user.""" # Setup user_id = 'test-user-123' @@ -247,7 +260,7 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user): mock_filter_org.all.return_value = [mock_key1, mock_key2] # Execute - result = api_key_store.list_api_keys(user_id) + result = await api_key_store.list_api_keys(user_id) # Verify mock_get_user.assert_called_once_with(user_id) @@ -265,8 +278,12 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user): assert result[1]['expires_at'] is None -@patch('storage.api_key_store.UserStore.get_user_by_id') -def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user): +@pytest.mark.asyncio +@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync) +@patch('storage.api_key_store.UserStore.get_user_by_id_async') +async def test_retrieve_mcp_api_key( + mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user +): """Test retrieving MCP API key for a user.""" # Setup user_id = 'test-user-123' @@ -287,16 +304,18 @@ def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_u mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key] # Execute - result = api_key_store.retrieve_mcp_api_key(user_id) + result = await api_key_store.retrieve_mcp_api_key(user_id) # Verify mock_get_user.assert_called_once_with(user_id) assert result == 'mcp-test-key' -@patch('storage.api_key_store.UserStore.get_user_by_id') -def test_retrieve_mcp_api_key_not_found( - mock_get_user, api_key_store, mock_session, mock_user +@pytest.mark.asyncio +@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync) +@patch('storage.api_key_store.UserStore.get_user_by_id_async') +async def test_retrieve_mcp_api_key_not_found( + mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user ): """Test retrieving MCP API key when none exists.""" # Setup @@ -314,7 +333,7 @@ def test_retrieve_mcp_api_key_not_found( mock_filter_org.all.return_value = [mock_other_key] # Execute - result = api_key_store.retrieve_mcp_api_key(user_id) + result = await api_key_store.retrieve_mcp_api_key(user_id) # Verify mock_get_user.assert_called_once_with(user_id) diff --git a/enterprise/tests/unit/test_saas_conversation_store.py b/enterprise/tests/unit/test_saas_conversation_store.py index f74079ef48..f4f9a7afe6 100644 --- a/enterprise/tests/unit/test_saas_conversation_store.py +++ b/enterprise/tests/unit/test_saas_conversation_store.py @@ -37,7 +37,11 @@ def mock_user_store(): @pytest.mark.asyncio async def test_save_and_get(session_maker): - store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker) + store = SaasConversationStore( + '5594c7b6-f959-4b81-92e9-b09c206f5081', + UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'), + session_maker, + ) metadata = ConversationMetadata( conversation_id='my-conversation-id', user_id='5594c7b6-f959-4b81-92e9-b09c206f5081', @@ -62,7 +66,11 @@ async def test_save_and_get(session_maker): @pytest.mark.asyncio async def test_search(session_maker): - store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker) + store = SaasConversationStore( + '5594c7b6-f959-4b81-92e9-b09c206f5081', + UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'), + session_maker, + ) # Create test conversations with different timestamps conversations = [ @@ -107,7 +115,11 @@ async def test_search(session_maker): @pytest.mark.asyncio async def test_delete_metadata(session_maker): - store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker) + store = SaasConversationStore( + '5594c7b6-f959-4b81-92e9-b09c206f5081', + UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'), + session_maker, + ) metadata = ConversationMetadata( conversation_id='to-delete', user_id='5594c7b6-f959-4b81-92e9-b09c206f5081', @@ -127,14 +139,22 @@ async def test_delete_metadata(session_maker): @pytest.mark.asyncio async def test_get_nonexistent_metadata(session_maker): - store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker) + store = SaasConversationStore( + '5594c7b6-f959-4b81-92e9-b09c206f5081', + UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'), + session_maker, + ) with pytest.raises(FileNotFoundError): await store.get_metadata('nonexistent-id') @pytest.mark.asyncio async def test_exists(session_maker): - store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker) + store = SaasConversationStore( + '5594c7b6-f959-4b81-92e9-b09c206f5081', + UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'), + session_maker, + ) metadata = ConversationMetadata( conversation_id='exists-test', user_id='5594c7b6-f959-4b81-92e9-b09c206f5081', diff --git a/openhands/core/config/mcp_config.py b/openhands/core/config/mcp_config.py index 00b739cff7..da271b24d5 100644 --- a/openhands/core/config/mcp_config.py +++ b/openhands/core/config/mcp_config.py @@ -360,7 +360,7 @@ class OpenHandsMCPConfig: return None @staticmethod - def create_default_mcp_server_config( + async def create_default_mcp_server_config( host: str, config: 'OpenHandsConfig', user_id: str | None = None ) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]: """Create a default MCP server configuration. diff --git a/openhands/core/main.py b/openhands/core/main.py index 03dfca0f1a..f0042bc2a9 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -153,10 +153,11 @@ async def run_controller( # Add MCP tools to the agent if agent.config.enable_mcp: # Add OpenHands' MCP server by default - _, openhands_mcp_stdio_servers = ( - OpenHandsMCPConfigImpl.create_default_mcp_server_config( - config.mcp_host, config, None - ) + ( + _, + openhands_mcp_stdio_servers, + ) = await OpenHandsMCPConfigImpl.create_default_mcp_server_config( + config.mcp_host, config, None ) runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index c27c0ff2d7..e55618993e 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -202,10 +202,11 @@ class WebSession: self.logger.debug(f'Merged custom MCP Config: {mcp_config}') # Add OpenHands' MCP server by default - openhands_mcp_server, openhands_mcp_stdio_servers = ( - OpenHandsMCPConfigImpl.create_default_mcp_server_config( - self.config.mcp_host, self.config, self.user_id - ) + ( + openhands_mcp_server, + openhands_mcp_stdio_servers, + ) = await OpenHandsMCPConfigImpl.create_default_mcp_server_config( + self.config.mcp_host, self.config, self.user_id ) if openhands_mcp_server: