diff --git a/AGENTS.md b/AGENTS.md index 425ca5a1a6..878a26e884 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -165,7 +165,7 @@ Each integration follows a consistent pattern with service classes, storage mode **Import Patterns:** - Use relative imports without `enterprise.` prefix in enterprise code -- Example: `from storage.database import session_maker` not `from enterprise.storage.database import session_maker` +- Example: `from storage.database import a_session_maker` not `from enterprise.storage.database import a_session_maker` - This ensures code works in both OpenHands and enterprise contexts **Test Structure:** diff --git a/enterprise/doc/design-doc/openhands-enterprise-telemetry-design.md b/enterprise/doc/design-doc/openhands-enterprise-telemetry-design.md index 4fc9f72c00..5ed22d782b 100644 --- a/enterprise/doc/design-doc/openhands-enterprise-telemetry-design.md +++ b/enterprise/doc/design-doc/openhands-enterprise-telemetry-design.md @@ -200,7 +200,7 @@ class MetricsCollector(ABC): """Base class for metrics collectors.""" @abstractmethod - def collect(self) -> List[MetricResult]: + async def collect(self) -> List[MetricResult]: """Collect metrics and return results.""" pass @@ -264,12 +264,13 @@ class SystemMetricsCollector(MetricsCollector): def collector_name(self) -> str: return "system_metrics" - def collect(self) -> List[MetricResult]: + async def collect(self) -> List[MetricResult]: results = [] # Collect user count - with session_maker() as session: - user_count = session.query(UserSettings).count() + async with a_session_maker() as session: + user_count_result = await session.execute(select(func.count()).select_from(UserSettings)) + user_count = user_count_result.scalar() results.append(MetricResult( key="total_users", value=user_count @@ -277,9 +278,11 @@ class SystemMetricsCollector(MetricsCollector): # Collect conversation count (last 30 days) thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30) - conversation_count = session.query(StoredConversationMetadata)\ - .filter(StoredConversationMetadata.created_at >= thirty_days_ago)\ - .count() + conversation_count_result = await session.execute( + select(func.count()).select_from(StoredConversationMetadata) + .where(StoredConversationMetadata.created_at >= thirty_days_ago) + ) + conversation_count = conversation_count_result.scalar() results.append(MetricResult( key="conversations_30d", @@ -303,7 +306,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor): """Collect metrics from all registered collectors.""" # Check if collection is needed - if not self._should_collect(): + if not await self._should_collect(): return {"status": "skipped", "reason": "too_recent"} # Collect metrics from all registered collectors @@ -313,7 +316,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor): for collector in collector_registry.get_all_collectors(): try: if collector.should_collect(): - results = collector.collect() + results = await collector.collect() for result in results: all_metrics[result.key] = result.value collector_results[collector.collector_name] = len(results) @@ -322,13 +325,13 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor): collector_results[collector.collector_name] = f"error: {e}" # Store metrics in database - with session_maker() as session: + async with a_session_maker() as session: telemetry_record = TelemetryMetrics( metrics_data=all_metrics, collected_at=datetime.now(timezone.utc) ) session.add(telemetry_record) - session.commit() + await session.commit() # Note: No need to track last_collection_at separately # Can be derived from MAX(collected_at) in telemetry_metrics @@ -339,11 +342,12 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor): "collectors_run": collector_results } - def _should_collect(self) -> bool: + async def _should_collect(self) -> bool: """Check if collection is needed based on interval.""" - with session_maker() as session: + async with a_session_maker() as session: # Get last collection time from metrics table - last_collected = session.query(func.max(TelemetryMetrics.collected_at)).scalar() + result = await session.execute(select(func.max(TelemetryMetrics.collected_at))) + last_collected = result.scalar() if not last_collected: return True @@ -366,17 +370,19 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): """Upload pending metrics to Replicated.""" # Get pending metrics - with session_maker() as session: - pending_metrics = session.query(TelemetryMetrics)\ - .filter(TelemetryMetrics.uploaded_at.is_(None))\ - .order_by(TelemetryMetrics.collected_at)\ - .all() + async with a_session_maker() as session: + result = await session.execute( + select(TelemetryMetrics) + .where(TelemetryMetrics.uploaded_at.is_(None)) + .order_by(TelemetryMetrics.collected_at) + ) + pending_metrics = result.scalars().all() if not pending_metrics: return {"status": "no_pending_metrics"} # Get admin email - skip if not available - admin_email = self._get_admin_email() + admin_email = await self._get_admin_email() if not admin_email: logger.info("Skipping telemetry upload - no admin email available") return { @@ -413,13 +419,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): await instance.set_status(InstanceStatus.RUNNING) # Mark as uploaded - with session_maker() as session: - record = session.query(TelemetryMetrics)\ - .filter(TelemetryMetrics.id == metric_record.id)\ - .first() + async with a_session_maker() as session: + result = await session.execute( + select(TelemetryMetrics) + .where(TelemetryMetrics.id == metric_record.id) + ) + record = result.scalar_one_or_none() if record: record.uploaded_at = datetime.now(timezone.utc) - session.commit() + await session.commit() uploaded_count += 1 @@ -427,14 +435,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): logger.error(f"Failed to upload metrics {metric_record.id}: {e}") # Update error info - with session_maker() as session: - record = session.query(TelemetryMetrics)\ - .filter(TelemetryMetrics.id == metric_record.id)\ - .first() + async with a_session_maker() as session: + result = await session.execute( + select(TelemetryMetrics) + .where(TelemetryMetrics.id == metric_record.id) + ) + record = result.scalar_one_or_none() if record: record.upload_attempts += 1 record.last_upload_error = str(e) - session.commit() + await session.commit() failed_count += 1 @@ -448,7 +458,7 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): "total_processed": len(pending_metrics) } - def _get_admin_email(self) -> str | None: + async def _get_admin_email(self) -> str | None: """Get administrator email for customer identification.""" # 1. Check environment variable first env_admin_email = os.getenv('OPENHANDS_ADMIN_EMAIL') @@ -457,12 +467,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): return env_admin_email # 2. Use first active user's email (earliest accepted_tos) - with session_maker() as session: - first_user = session.query(UserSettings)\ - .filter(UserSettings.email.isnot(None))\ - .filter(UserSettings.accepted_tos.isnot(None))\ - .order_by(UserSettings.accepted_tos.asc())\ - .first() + async with a_session_maker() as session: + result = await session.execute( + select(UserSettings) + .where(UserSettings.email.isnot(None)) + .where(UserSettings.accepted_tos.isnot(None)) + .order_by(UserSettings.accepted_tos.asc()) + .limit(1) + ) + first_user = result.scalar_one_or_none() if first_user and first_user.email: logger.info(f"Using first active user email: {first_user.email}") @@ -474,15 +487,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor): async def _update_telemetry_identity(self, customer_id: str, instance_id: str) -> None: """Update or create telemetry identity record.""" - with session_maker() as session: - identity = session.query(TelemetryIdentity).first() + async with a_session_maker() as session: + result = await session.execute(select(TelemetryIdentity).limit(1)) + identity = result.scalar_one_or_none() if not identity: identity = TelemetryIdentity() session.add(identity) identity.customer_id = customer_id identity.instance_id = instance_id - session.commit() + await session.commit() ``` ### 4.4 License Warning System @@ -503,11 +517,13 @@ async def get_license_status(): if not _is_openhands_enterprise(): return {"warn": False, "message": ""} - with session_maker() as session: + async with a_session_maker() as session: # Get last successful upload time from metrics table - last_upload = session.query(func.max(TelemetryMetrics.uploaded_at))\ - .filter(TelemetryMetrics.uploaded_at.isnot(None))\ - .scalar() + result = await session.execute( + select(func.max(TelemetryMetrics.uploaded_at)) + .where(TelemetryMetrics.uploaded_at.isnot(None)) + ) + last_upload = result.scalar() if not last_upload: # No successful uploads yet - show warning after 4 days @@ -521,10 +537,13 @@ async def get_license_status(): if days_since_upload > 4: # Find oldest unsent batch - oldest_unsent = session.query(TelemetryMetrics)\ - .filter(TelemetryMetrics.uploaded_at.is_(None))\ - .order_by(TelemetryMetrics.collected_at)\ - .first() + result = await session.execute( + select(TelemetryMetrics) + .where(TelemetryMetrics.uploaded_at.is_(None)) + .order_by(TelemetryMetrics.collected_at) + .limit(1) + ) + oldest_unsent = result.scalar_one_or_none() if oldest_unsent: # Calculate expiration date (oldest unsent + 34 days) @@ -630,19 +649,23 @@ spec: - python - -c - | + import asyncio from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus - from enterprise.storage.database import session_maker + from enterprise.storage.database import a_session_maker from enterprise.server.telemetry.collection_processor import TelemetryCollectionProcessor - # Create collection task - processor = TelemetryCollectionProcessor() - task = MaintenanceTask() - task.set_processor(processor) - task.status = MaintenanceTaskStatus.PENDING + async def main(): + # Create collection task + processor = TelemetryCollectionProcessor() + task = MaintenanceTask() + task.set_processor(processor) + task.status = MaintenanceTaskStatus.PENDING - with session_maker() as session: - session.add(task) - session.commit() + async with a_session_maker() as session: + session.add(task) + await session.commit() + + asyncio.run(main()) restartPolicy: OnFailure ``` @@ -680,23 +703,27 @@ spec: - python - -c - | + import asyncio from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus - from enterprise.storage.database import session_maker + from enterprise.storage.database import a_session_maker from enterprise.server.telemetry.upload_processor import TelemetryUploadProcessor import os - # Create upload task - processor = TelemetryUploadProcessor( - replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'), - replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise') - ) - task = MaintenanceTask() - task.set_processor(processor) - task.status = MaintenanceTaskStatus.PENDING + async def main(): + # Create upload task + processor = TelemetryUploadProcessor( + replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'), + replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise') + ) + task = MaintenanceTask() + task.set_processor(processor) + task.status = MaintenanceTaskStatus.PENDING - with session_maker() as session: - session.add(task) - session.commit() + async with a_session_maker() as session: + session.add(task) + await session.commit() + + asyncio.run(main()) restartPolicy: OnFailure ``` diff --git a/enterprise/downgrade_migrated_users.py b/enterprise/downgrade_migrated_users.py deleted file mode 100644 index a9798476bc..0000000000 --- a/enterprise/downgrade_migrated_users.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python -""" -This script can be removed once orgs is established - probably after Feb 15 2026 - -Downgrade script for migrated users. - -This script identifies users who have been migrated (already_migrated=True) -and reverts them back to the pre-migration state. - -Usage: - # Dry run - just list the users that would be downgraded - python downgrade_migrated_users.py --dry-run - - # Downgrade a specific user by their keycloak_user_id - python downgrade_migrated_users.py --user-id - - # Downgrade all migrated users (with confirmation) - python downgrade_migrated_users.py --all - - # Downgrade all migrated users without confirmation (dangerous!) - python downgrade_migrated_users.py --all --no-confirm -""" - -import argparse -import asyncio -import sys - -# Add the enterprise directory to the path -sys.path.insert(0, '/workspace/project/OpenHands/enterprise') - -from server.logger import logger -from sqlalchemy import select, text -from storage.database import session_maker -from storage.user_settings import UserSettings -from storage.user_store import UserStore - - -def get_migrated_users() -> list[str]: - """Get list of keycloak_user_ids for users who have been migrated. - - This includes: - 1. Users with already_migrated=True in user_settings (migrated users) - 2. Users in the 'user' table who don't have a user_settings entry (new sign-ups) - """ - with session_maker() as session: - # Get users from user_settings with already_migrated=True - migrated_result = session.execute( - select(UserSettings.keycloak_user_id).where( - UserSettings.already_migrated.is_(True) - ) - ) - migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]} - - # Get users from the 'user' table (new sign-ups won't have user_settings) - # These are users who signed up after the migration was deployed - new_signup_result = session.execute( - text(""" - SELECT CAST(u.id AS VARCHAR) - FROM "user" u - WHERE NOT EXISTS ( - SELECT 1 FROM user_settings us - WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR) - ) - """) - ) - new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]} - - # Combine both sets - all_users = migrated_users | new_signups - return list(all_users) - - -async def downgrade_user(user_id: str) -> bool: - """Downgrade a single user. - - Args: - user_id: The keycloak_user_id to downgrade - - Returns: - True if successful, False otherwise - """ - try: - result = await UserStore.downgrade_user(user_id) - if result: - print(f'✓ Successfully downgraded user: {user_id}') - return True - else: - print(f'✗ Failed to downgrade user: {user_id}') - return False - except Exception as e: - print(f'✗ Error downgrading user {user_id}: {e}') - logger.exception( - 'downgrade_script:error', - extra={'user_id': user_id, 'error': str(e)}, - ) - return False - - -async def main(): - parser = argparse.ArgumentParser( - description='Downgrade migrated users back to pre-migration state' - ) - parser.add_argument( - '--dry-run', - action='store_true', - help='Just list users that would be downgraded, without making changes', - ) - parser.add_argument( - '--user-id', - type=str, - help='Downgrade a specific user by keycloak_user_id', - ) - parser.add_argument( - '--all', - action='store_true', - help='Downgrade all migrated users', - ) - parser.add_argument( - '--no-confirm', - action='store_true', - help='Skip confirmation prompt (use with caution!)', - ) - - args = parser.parse_args() - - # Get list of migrated users - migrated_users = get_migrated_users() - print(f'\nFound {len(migrated_users)} migrated user(s).') - - if args.dry_run: - print('\n--- DRY RUN MODE ---') - print('The following users would be downgraded:') - for user_id in migrated_users: - print(f' - {user_id}') - print('\nNo changes were made.') - return - - if args.user_id: - # Downgrade a specific user - if args.user_id not in migrated_users: - print(f'\nUser {args.user_id} is not in the migrated users list.') - print('Either the user was not migrated, or the user_id is incorrect.') - return - - print(f'\nDowngrading user: {args.user_id}') - if not args.no_confirm: - confirm = input('Are you sure? (yes/no): ') - if confirm.lower() != 'yes': - print('Cancelled.') - return - - success = await downgrade_user(args.user_id) - if success: - print('\nDowngrade completed successfully.') - else: - print('\nDowngrade failed. Check logs for details.') - sys.exit(1) - - elif args.all: - # Downgrade all migrated users - if not migrated_users: - print('\nNo migrated users to downgrade.') - return - - print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).') - if not args.no_confirm: - print('\nThis will:') - print(' - Revert LiteLLM team/user budget settings') - print(' - Delete organization entries') - print(' - Delete user entries in the new schema') - print(' - Reset the already_migrated flag') - print('\nUsers to downgrade:') - for user_id in migrated_users[:10]: # Show first 10 - print(f' - {user_id}') - if len(migrated_users) > 10: - print(f' ... and {len(migrated_users) - 10} more') - - confirm = input('\nType "yes" to proceed: ') - if confirm.lower() != 'yes': - print('Cancelled.') - return - - print('\nStarting downgrade...\n') - success_count = 0 - fail_count = 0 - - for user_id in migrated_users: - success = await downgrade_user(user_id) - if success: - success_count += 1 - else: - fail_count += 1 - - print('\n--- Summary ---') - print(f'Successful: {success_count}') - print(f'Failed: {fail_count}') - - if fail_count > 0: - sys.exit(1) - - else: - parser.print_help() - print('\nPlease specify --dry-run, --user-id, or --all') - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 4fd2a6b569..106ca93200 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -27,7 +27,6 @@ from server.rate_limit import setup_rate_limit_handler # noqa: E402 from server.routes.api_keys import api_router as api_keys_router # noqa: E402 from server.routes.auth import api_router, oauth_router # noqa: E402 from server.routes.billing import billing_router # noqa: E402 -from server.routes.debugging import add_debugging_routes # noqa: E402 from server.routes.email import api_router as email_router # noqa: E402 from server.routes.event_webhook import event_webhook_router # noqa: E402 from server.routes.feedback import router as feedback_router # noqa: E402 @@ -124,9 +123,6 @@ override_llm_models_dependency(base_app) base_app.include_router(invitation_router) # Add routes for org invitation management base_app.include_router(invitation_accept_router) # Add route for accepting invitations add_github_proxy_routes(base_app) -add_debugging_routes( - base_app -) # Add diagnostic routes for testing and debugging (disabled in production) base_app.include_router(slack_router) if ENABLE_JIRA: base_app.include_router(jira_integration_router) diff --git a/enterprise/server/clustered_conversation_manager.py b/enterprise/server/clustered_conversation_manager.py index 69ac2f3bd6..b8b6e04b63 100644 --- a/enterprise/server/clustered_conversation_manager.py +++ b/enterprise/server/clustered_conversation_manager.py @@ -7,7 +7,8 @@ from uuid import uuid4 import socketio from server.logger import logger from server.utils.conversation_callback_utils import invoke_conversation_callbacks -from storage.database import session_maker +from sqlalchemy import select +from storage.database import a_session_maker from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas from openhands.core.config import LLMConfig @@ -523,15 +524,14 @@ class ClusteredConversationManager(StandaloneConversationManager): f'local_connection_to_stopped_conversation:{connection_id}:{conversation_id}' ) # Look up the user_id from the database - with session_maker() as session: - conversation_metadata_saas = ( - session.query(StoredConversationMetadataSaas) - .filter( + async with a_session_maker() as session: + result = await session.execute( + select(StoredConversationMetadataSaas).where( StoredConversationMetadataSaas.conversation_id == conversation_id ) - .first() ) + conversation_metadata_saas = result.scalars().first() user_id = ( str(conversation_metadata_saas.user_id) if conversation_metadata_saas diff --git a/enterprise/server/routes/debugging.py b/enterprise/server/routes/debugging.py deleted file mode 100644 index cb49254976..0000000000 --- a/enterprise/server/routes/debugging.py +++ /dev/null @@ -1,163 +0,0 @@ -import asyncio -import os -import time -from threading import Thread - -from fastapi import APIRouter, FastAPI -from sqlalchemy import func, select -from storage.database import a_session_maker, get_engine, session_maker -from storage.user import User - -from openhands.core.logger import openhands_logger as logger -from openhands.utils.async_utils import wait_all - -# Safety flag to prevent chaos routes from being added in production environments -# Only enables these routes in non-production environments -ADD_DEBUGGING_ROUTES = os.environ.get('ADD_DEBUGGING_ROUTES') in ('1', 'true') - - -def add_debugging_routes(api: FastAPI): - """ - # HERE BE DRAGONS! - Chaos scripts for debugging and stress testing the system. - - This module contains endpoints that deliberately stress test and potentially break - the system to help identify weaknesses and bottlenecks. It includes a safety check - to ensure these routes are never deployed to production environments. - - The routes in this module are specifically designed for: - - Testing connection pool behavior under load - - Simulating database connection exhaustion - - Testing async vs sync database access patterns - - Simulating event loop blocking - """ - - if not ADD_DEBUGGING_ROUTES: - return - - chaos_router = APIRouter(prefix='/debugging') - - @chaos_router.get('/pool-stats') - def pool_stats() -> dict[str, int]: - """ - Returns current database connection pool statistics. - - This endpoint provides real-time metrics about the SQLAlchemy connection pool: - - checked_in: Number of connections currently available in the pool - - checked_out: Number of connections currently in use - - overflow: Number of overflow connections created beyond pool_size - """ - engine = get_engine() - return { - 'checked_in': engine.pool.checkedin(), - 'checked_out': engine.pool.checkedout(), - 'overflow': engine.pool.overflow(), - } - - @chaos_router.get('/test-db') - def test_db(num_tests: int = 10, delay: int = 1) -> str: - """ - Stress tests the database connection pool using multiple threads. - - Creates multiple threads that each open a database connection, perform a query, - hold the connection for the specified delay, and then release it. - - Parameters: - num_tests: Number of concurrent database connections to create - delay: Number of seconds each connection is held open - - This test helps identify connection pool exhaustion issues and connection - leaks under concurrent load. - """ - threads = [Thread(target=_db_check, args=(delay,)) for _ in range(num_tests)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - return 'success' - - @chaos_router.get('/a-test-db') - async def a_chaos_monkey(num_tests: int = 10, delay: int = 1) -> str: - """ - Stress tests the async database connection pool. - - Similar to /test-db but uses async connections and coroutines instead of threads. - This endpoint helps compare the behavior of async vs sync connection pools - under similar load conditions. - - Parameters: - num_tests: Number of concurrent async database connections to create - delay: Number of seconds each connection is held open - """ - await wait_all((_a_db_check(delay) for _ in range(num_tests))) - return 'success' - - @chaos_router.get('/lock-main-runloop') - async def lock_main_runloop(duration: int = 10) -> str: - """ - Deliberately blocks the main asyncio event loop. - - This endpoint uses a synchronous sleep operation in an async function, - which blocks the entire FastAPI server's event loop for the specified duration. - This simulates what happens when CPU-intensive operations or blocking I/O - operations are incorrectly used in async code. - - Parameters: - duration: Number of seconds to block the event loop - - WARNING: This will make the entire server unresponsive for the duration! - """ - time.sleep(duration) - return 'success' - - api.include_router(chaos_router) # Add routes for readiness checks - - -def _db_check(delay: int): - """ - Executes a single request against the database with an artificial delay. - - This helper function: - 1. Opens a database connection from the pool - 2. Executes a simple query to count users - 3. Holds the connection for the specified delay - 4. Logs connection pool statistics - 5. Implicitly returns the connection to the pool when the session closes - - Args: - delay: Number of seconds to hold the database connection - """ - with session_maker() as session: - num_users = session.query(User).count() - time.sleep(delay) - engine = get_engine() - logger.info( - 'check', - extra={ - 'num_users': num_users, - 'checked_in': engine.pool.checkedin(), - 'checked_out': engine.pool.checkedout(), - 'overflow': engine.pool.overflow(), - }, - ) - - -async def _a_db_check(delay: int): - """ - Executes a single async request against the database with an artificial delay. - - This is the async version of _db_check that: - 1. Opens an async database connection from the pool - 2. Executes a simple query to count users using SQLAlchemy's async API - 3. Holds the connection for the specified delay using asyncio.sleep - 4. Logs the results - 5. Implicitly returns the connection to the pool when the async session closes - - Args: - delay: Number of seconds to hold the database connection - """ - async with a_session_maker() as a_session: - stmt = select(func.count(User.id)) - num_users = await a_session.execute(stmt) - await asyncio.sleep(delay) - logger.info(f'a_num_users:{num_users.scalar_one()}') diff --git a/enterprise/server/routes/integration/slack.py b/enterprise/server/routes/integration/slack.py index 3cda0bcb9a..e0d7f53f46 100644 --- a/enterprise/server/routes/integration/slack.py +++ b/enterprise/server/routes/integration/slack.py @@ -31,7 +31,8 @@ from server.logger import logger from slack_sdk.oauth import AuthorizeUrlGenerator from slack_sdk.signature import SignatureVerifier from slack_sdk.web.async_client import AsyncWebClient -from storage.database import session_maker +from sqlalchemy import delete +from storage.database import a_session_maker from storage.slack_team_store import SlackTeamStore from storage.slack_user import SlackUser from storage.user_store import UserStore @@ -239,15 +240,15 @@ async def keycloak_callback( slack_display_name=slack_display_name, ) - with session_maker(expire_on_commit=False) as session: + async with a_session_maker(expire_on_commit=False) as session: # First delete any existing tokens - session.query(SlackUser).filter( - SlackUser.slack_user_id == slack_user_id - ).delete() + await session.execute( + delete(SlackUser).where(SlackUser.slack_user_id == slack_user_id) + ) # Store the token session.add(slack_user) - session.commit() + await session.commit() message = Message(source=SourceType.SLACK, message=payload) diff --git a/enterprise/server/saas_nested_conversation_manager.py b/enterprise/server/saas_nested_conversation_manager.py index d4479da0b2..be5f787b10 100644 --- a/enterprise/server/saas_nested_conversation_manager.py +++ b/enterprise/server/saas_nested_conversation_manager.py @@ -19,9 +19,9 @@ from server.utils.conversation_callback_utils import ( process_event, update_conversation_metadata, ) -from sqlalchemy import orm +from sqlalchemy import select from storage.api_key_store import ApiKeyStore -from storage.database import session_maker +from storage.database import a_session_maker from storage.stored_conversation_metadata import StoredConversationMetadata from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas @@ -59,7 +59,6 @@ from openhands.storage.locations import ( get_conversation_event_filename, get_conversation_events_dir, ) -from openhands.utils.async_utils import call_sync_from_async from openhands.utils.http_session import httpx_verify_option from openhands.utils.import_utils import get_impl from openhands.utils.shutdown_listener import should_continue @@ -166,8 +165,8 @@ class SaasNestedConversationManager(ConversationManager): } if user_id: - user_conversation_ids = await call_sync_from_async( - self._get_recent_conversation_ids_for_user, user_id + user_conversation_ids = await self._get_recent_conversation_ids_for_user( + user_id ) conversation_ids = conversation_ids.intersection(user_conversation_ids) @@ -643,19 +642,18 @@ class SaasNestedConversationManager(ConversationManager): }, ) - def _get_user_id_from_conversation(self, conversation_id: str) -> str: + async def _get_user_id_from_conversation(self, conversation_id: str) -> str: """ Get user_id from conversation_id. """ - with session_maker() as session: - conversation_metadata_saas = ( - session.query(StoredConversationMetadataSaas) - .filter( + async with a_session_maker() as session: + result = await session.execute( + select(StoredConversationMetadataSaas).where( StoredConversationMetadataSaas.conversation_id == conversation_id ) - .first() ) + conversation_metadata_saas = result.scalars().first() if not conversation_metadata_saas: raise ValueError(f'No conversation found {conversation_id}') @@ -753,8 +751,8 @@ class SaasNestedConversationManager(ConversationManager): user_id_for_convo = user_id if not user_id_for_convo: try: - user_id_for_convo = await call_sync_from_async( - self._get_user_id_from_conversation, conversation_id + user_id_for_convo = await self._get_user_id_from_conversation( + conversation_id ) except Exception: continue @@ -995,23 +993,23 @@ class SaasNestedConversationManager(ConversationManager): } return conversation_ids - def _get_recent_conversation_ids_for_user(self, user_id: str) -> set[str]: - with session_maker() as session: + async def _get_recent_conversation_ids_for_user(self, user_id: str) -> set[str]: + async with a_session_maker() as session: # Only include conversations updated in the past week one_week_ago = datetime.now(UTC) - timedelta(days=7) - query = ( - session.query(StoredConversationMetadata.conversation_id) + result = await session.execute( + select(StoredConversationMetadata.conversation_id) .join( StoredConversationMetadataSaas, StoredConversationMetadata.conversation_id == StoredConversationMetadataSaas.conversation_id, ) - .filter( + .where( StoredConversationMetadataSaas.user_id == user_id, StoredConversationMetadata.last_updated_at >= one_week_ago, ) ) - user_conversation_ids = set(query) + user_conversation_ids = set(result.scalars().all()) return user_conversation_ids async def _get_runtime(self, sid: str) -> dict | None: @@ -1055,14 +1053,13 @@ class SaasNestedConversationManager(ConversationManager): await asyncio.sleep(_POLLING_INTERVAL) agent_loop_infos = await self.get_agent_loop_info() - with session_maker() as session: - for agent_loop_info in agent_loop_infos: - if agent_loop_info.status != ConversationStatus.RUNNING: - continue - try: - await self._poll_agent_loop_events(agent_loop_info, session) - except Exception as e: - logger.exception(f'error_polling_events:{str(e)}') + for agent_loop_info in agent_loop_infos: + if agent_loop_info.status != ConversationStatus.RUNNING: + continue + try: + await self._poll_agent_loop_events(agent_loop_info) + except Exception as e: + logger.exception(f'error_polling_events:{str(e)}') except Exception as e: try: asyncio.get_running_loop() @@ -1071,23 +1068,27 @@ class SaasNestedConversationManager(ConversationManager): # Loop has been shut down, exit gracefully return - async def _poll_agent_loop_events( - self, agent_loop_info: AgentLoopInfo, session: orm.Session - ): + async def _poll_agent_loop_events(self, agent_loop_info: AgentLoopInfo): """This method is typically only run in localhost, where the webhook callbacks from the remote runtime are unavailable""" if agent_loop_info.status != ConversationStatus.RUNNING: return conversation_id = agent_loop_info.conversation_id - conversation_metadata = ( - session.query(StoredConversationMetadata) - .filter(StoredConversationMetadata.conversation_id == conversation_id) - .first() - ) - conversation_metadata_saas = ( - session.query(StoredConversationMetadataSaas) - .filter(StoredConversationMetadataSaas.conversation_id == conversation_id) - .first() - ) + + async with a_session_maker() as session: + result = await session.execute( + select(StoredConversationMetadata).where( + StoredConversationMetadata.conversation_id == conversation_id + ) + ) + conversation_metadata = result.scalars().first() + + result = await session.execute( + select(StoredConversationMetadataSaas).where( + StoredConversationMetadataSaas.conversation_id == conversation_id + ) + ) + conversation_metadata_saas = result.scalars().first() + if conversation_metadata is None or conversation_metadata_saas is None: # Conversation is running in different server return diff --git a/enterprise/tests/unit/test_clustered_conversation_manager.py b/enterprise/tests/unit/test_clustered_conversation_manager.py index fefa29732d..0503d360cf 100644 --- a/enterprise/tests/unit/test_clustered_conversation_manager.py +++ b/enterprise/tests/unit/test_clustered_conversation_manager.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import json import time from dataclasses import dataclass @@ -444,11 +445,19 @@ async def test_disconnect_from_stopped_with_stopped_remote(): # Create a mock SIO with scan results for only remote_session1 sio = get_mock_sio(scan_keys=[b'ohcnv:user1:remote_session1']) - # Mock the database connection to avoid actual database connections - db_mock = MagicMock() - db_session_mock = MagicMock() - db_mock.__enter__.return_value = db_session_mock - session_maker_mock = MagicMock(return_value=db_mock) + # Mock the async database session + mock_user = MagicMock() + mock_user.user_id = 'user1' + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_user + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + @contextlib.asynccontextmanager + async def mock_a_session_maker(): + yield mock_session with ( patch( @@ -456,8 +465,8 @@ async def test_disconnect_from_stopped_with_stopped_remote(): AsyncMock(), ), patch( - 'server.clustered_conversation_manager.session_maker', - session_maker_mock, + 'server.clustered_conversation_manager.a_session_maker', + mock_a_session_maker, ), patch('asyncio.create_task', MagicMock()), ): @@ -484,11 +493,6 @@ async def test_disconnect_from_stopped_with_stopped_remote(): MagicMock() ) - # Create a mock for the database query result - mock_user = MagicMock() - mock_user.user_id = 'user1' - db_session_mock.query.return_value.filter.return_value.first.return_value = mock_user - # Mock the _handle_remote_conversation_stopped method with the correct signature conversation_manager._handle_remote_conversation_stopped = AsyncMock()