mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Fix asyncio event loop conflict in get_user_by_id (#12475)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
@@ -63,10 +63,10 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
|
||||
try:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
current_org_id = str(user.current_org_id)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
current_org_id = str(user.current_org_id)
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id,
|
||||
current_org_id,
|
||||
|
||||
@@ -37,7 +37,6 @@ from openhands.server.services.conversation_service import create_provider_token
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth import get_access_token
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@@ -177,7 +176,7 @@ async def keycloak_callback(
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
@@ -104,7 +103,7 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
@@ -257,9 +256,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
user = await call_sync_from_async(
|
||||
UserStore.get_user_by_id, billing_session.user_id
|
||||
)
|
||||
user = await UserStore.get_user_by_id_async(billing_session.user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
@@ -38,7 +38,6 @@ from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
|
||||
slack_router = APIRouter(prefix='/slack')
|
||||
@@ -197,7 +196,7 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
|
||||
@@ -21,7 +21,6 @@ from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
@@ -676,7 +675,7 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -26,7 +25,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
@@ -53,7 +52,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
|
||||
@@ -6,6 +6,7 @@ import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
@@ -299,7 +300,12 @@ class UserStore:
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
|
||||
Note: This method uses call_async_from_sync internally which creates a new
|
||||
event loop. If you're already in an async context, use get_user_by_id_async
|
||||
instead to avoid event loop conflicts.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
@@ -342,8 +348,6 @@ class UserStore:
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
@@ -361,6 +365,62 @@ class UserStore:
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id_async(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (async version).
|
||||
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
|
||||
@@ -150,7 +150,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
@@ -185,7 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
@@ -257,7 +257,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
@@ -304,7 +304,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
@@ -344,7 +344,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
@@ -579,7 +579,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
@@ -642,7 +642,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
@@ -705,7 +705,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
@@ -766,7 +766,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
@@ -811,7 +811,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
@@ -855,7 +855,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
@@ -912,7 +912,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -969,7 +969,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1029,7 +1029,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1185,7 +1185,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1249,7 +1249,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
@@ -1331,7 +1331,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1418,7 +1418,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1502,7 +1502,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1665,7 +1665,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1731,7 +1731,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1803,7 +1803,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -1873,7 +1873,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
@@ -81,7 +81,8 @@ async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
@@ -112,7 +113,8 @@ async def test_get_credits_success():
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
@@ -301,7 +303,8 @@ async def test_success_callback_success():
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
@@ -358,7 +361,8 @@ async def test_success_callback_lite_llm_error():
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
|
||||
@@ -503,7 +503,9 @@ class TestLiteLlmManager:
|
||||
mock_org_member.org_id = 'test-ord-id'
|
||||
mock_org_member.llm_api_key = 'test-api-key'
|
||||
mock_user.org_members = [mock_org_member]
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(
|
||||
return_value=mock_user
|
||||
)
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
@@ -519,7 +521,7 @@ class TestLiteLlmManager:
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
mock_user_store.get_user_by_id.return_value = None
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=None)
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@@ -35,7 +35,10 @@ def secrets_store(session_maker, mock_config):
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
@@ -72,7 +75,10 @@ class TestSaasSecretsStore:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
@@ -169,7 +175,10 @@ class TestSaasSecretsStore:
|
||||
assert await secrets_store.load() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_update_existing_secrets(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user