Refactor enterprise code to use async database sessions (Round 3) (#13148)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-03 06:35:19 -07:00
committed by GitHub
parent 4a3a42c858
commit f3026583d7
9 changed files with 167 additions and 508 deletions

View File

@@ -200,7 +200,7 @@ class MetricsCollector(ABC):
"""Base class for metrics collectors."""
@abstractmethod
def collect(self) -> List[MetricResult]:
async def collect(self) -> List[MetricResult]:
"""Collect metrics and return results."""
pass
@@ -264,12 +264,13 @@ class SystemMetricsCollector(MetricsCollector):
def collector_name(self) -> str:
return "system_metrics"
def collect(self) -> List[MetricResult]:
async def collect(self) -> List[MetricResult]:
results = []
# Collect user count
with session_maker() as session:
user_count = session.query(UserSettings).count()
async with a_session_maker() as session:
user_count_result = await session.execute(select(func.count()).select_from(UserSettings))
user_count = user_count_result.scalar()
results.append(MetricResult(
key="total_users",
value=user_count
@@ -277,9 +278,11 @@ class SystemMetricsCollector(MetricsCollector):
# Collect conversation count (last 30 days)
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
conversation_count = session.query(StoredConversationMetadata)\
.filter(StoredConversationMetadata.created_at >= thirty_days_ago)\
.count()
conversation_count_result = await session.execute(
select(func.count()).select_from(StoredConversationMetadata)
.where(StoredConversationMetadata.created_at >= thirty_days_ago)
)
conversation_count = conversation_count_result.scalar()
results.append(MetricResult(
key="conversations_30d",
@@ -303,7 +306,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
"""Collect metrics from all registered collectors."""
# Check if collection is needed
if not self._should_collect():
if not await self._should_collect():
return {"status": "skipped", "reason": "too_recent"}
# Collect metrics from all registered collectors
@@ -313,7 +316,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
for collector in collector_registry.get_all_collectors():
try:
if collector.should_collect():
results = collector.collect()
results = await collector.collect()
for result in results:
all_metrics[result.key] = result.value
collector_results[collector.collector_name] = len(results)
@@ -322,13 +325,13 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
collector_results[collector.collector_name] = f"error: {e}"
# Store metrics in database
with session_maker() as session:
async with a_session_maker() as session:
telemetry_record = TelemetryMetrics(
metrics_data=all_metrics,
collected_at=datetime.now(timezone.utc)
)
session.add(telemetry_record)
session.commit()
await session.commit()
# Note: No need to track last_collection_at separately
# Can be derived from MAX(collected_at) in telemetry_metrics
@@ -339,11 +342,12 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
"collectors_run": collector_results
}
def _should_collect(self) -> bool:
async def _should_collect(self) -> bool:
"""Check if collection is needed based on interval."""
with session_maker() as session:
async with a_session_maker() as session:
# Get last collection time from metrics table
last_collected = session.query(func.max(TelemetryMetrics.collected_at)).scalar()
result = await session.execute(select(func.max(TelemetryMetrics.collected_at)))
last_collected = result.scalar()
if not last_collected:
return True
@@ -366,17 +370,19 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
"""Upload pending metrics to Replicated."""
# Get pending metrics
with session_maker() as session:
pending_metrics = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.uploaded_at.is_(None))\
.order_by(TelemetryMetrics.collected_at)\
.all()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.uploaded_at.is_(None))
.order_by(TelemetryMetrics.collected_at)
)
pending_metrics = result.scalars().all()
if not pending_metrics:
return {"status": "no_pending_metrics"}
# Get admin email - skip if not available
admin_email = self._get_admin_email()
admin_email = await self._get_admin_email()
if not admin_email:
logger.info("Skipping telemetry upload - no admin email available")
return {
@@ -413,13 +419,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
await instance.set_status(InstanceStatus.RUNNING)
# Mark as uploaded
with session_maker() as session:
record = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.id == metric_record.id)\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.id == metric_record.id)
)
record = result.scalar_one_or_none()
if record:
record.uploaded_at = datetime.now(timezone.utc)
session.commit()
await session.commit()
uploaded_count += 1
@@ -427,14 +435,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
logger.error(f"Failed to upload metrics {metric_record.id}: {e}")
# Update error info
with session_maker() as session:
record = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.id == metric_record.id)\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.id == metric_record.id)
)
record = result.scalar_one_or_none()
if record:
record.upload_attempts += 1
record.last_upload_error = str(e)
session.commit()
await session.commit()
failed_count += 1
@@ -448,7 +458,7 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
"total_processed": len(pending_metrics)
}
def _get_admin_email(self) -> str | None:
async def _get_admin_email(self) -> str | None:
"""Get administrator email for customer identification."""
# 1. Check environment variable first
env_admin_email = os.getenv('OPENHANDS_ADMIN_EMAIL')
@@ -457,12 +467,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
return env_admin_email
# 2. Use first active user's email (earliest accepted_tos)
with session_maker() as session:
first_user = session.query(UserSettings)\
.filter(UserSettings.email.isnot(None))\
.filter(UserSettings.accepted_tos.isnot(None))\
.order_by(UserSettings.accepted_tos.asc())\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(UserSettings)
.where(UserSettings.email.isnot(None))
.where(UserSettings.accepted_tos.isnot(None))
.order_by(UserSettings.accepted_tos.asc())
.limit(1)
)
first_user = result.scalar_one_or_none()
if first_user and first_user.email:
logger.info(f"Using first active user email: {first_user.email}")
@@ -474,15 +487,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
async def _update_telemetry_identity(self, customer_id: str, instance_id: str) -> None:
"""Update or create telemetry identity record."""
with session_maker() as session:
identity = session.query(TelemetryIdentity).first()
async with a_session_maker() as session:
result = await session.execute(select(TelemetryIdentity).limit(1))
identity = result.scalar_one_or_none()
if not identity:
identity = TelemetryIdentity()
session.add(identity)
identity.customer_id = customer_id
identity.instance_id = instance_id
session.commit()
await session.commit()
```
### 4.4 License Warning System
@@ -503,11 +517,13 @@ async def get_license_status():
if not _is_openhands_enterprise():
return {"warn": False, "message": ""}
with session_maker() as session:
async with a_session_maker() as session:
# Get last successful upload time from metrics table
last_upload = session.query(func.max(TelemetryMetrics.uploaded_at))\
.filter(TelemetryMetrics.uploaded_at.isnot(None))\
.scalar()
result = await session.execute(
select(func.max(TelemetryMetrics.uploaded_at))
.where(TelemetryMetrics.uploaded_at.isnot(None))
)
last_upload = result.scalar()
if not last_upload:
# No successful uploads yet - show warning after 4 days
@@ -521,10 +537,13 @@ async def get_license_status():
if days_since_upload > 4:
# Find oldest unsent batch
oldest_unsent = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.uploaded_at.is_(None))\
.order_by(TelemetryMetrics.collected_at)\
.first()
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.uploaded_at.is_(None))
.order_by(TelemetryMetrics.collected_at)
.limit(1)
)
oldest_unsent = result.scalar_one_or_none()
if oldest_unsent:
# Calculate expiration date (oldest unsent + 34 days)
@@ -630,19 +649,23 @@ spec:
- python
- -c
- |
import asyncio
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from enterprise.storage.database import session_maker
from enterprise.storage.database import a_session_maker
from enterprise.server.telemetry.collection_processor import TelemetryCollectionProcessor
# Create collection task
processor = TelemetryCollectionProcessor()
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
async def main():
# Create collection task
processor = TelemetryCollectionProcessor()
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
with session_maker() as session:
session.add(task)
session.commit()
async with a_session_maker() as session:
session.add(task)
await session.commit()
asyncio.run(main())
restartPolicy: OnFailure
```
@@ -680,23 +703,27 @@ spec:
- python
- -c
- |
import asyncio
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from enterprise.storage.database import session_maker
from enterprise.storage.database import a_session_maker
from enterprise.server.telemetry.upload_processor import TelemetryUploadProcessor
import os
# Create upload task
processor = TelemetryUploadProcessor(
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
)
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
async def main():
# Create upload task
processor = TelemetryUploadProcessor(
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
)
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
with session_maker() as session:
session.add(task)
session.commit()
async with a_session_maker() as session:
session.add(task)
await session.commit()
asyncio.run(main())
restartPolicy: OnFailure
```

View File

@@ -1,207 +0,0 @@
#!/usr/bin/env python
"""
This script can be removed once orgs is established - probably after Feb 15 2026
Downgrade script for migrated users.
This script identifies users who have been migrated (already_migrated=True)
and reverts them back to the pre-migration state.
Usage:
# Dry run - just list the users that would be downgraded
python downgrade_migrated_users.py --dry-run
# Downgrade a specific user by their keycloak_user_id
python downgrade_migrated_users.py --user-id <user_id>
# Downgrade all migrated users (with confirmation)
python downgrade_migrated_users.py --all
# Downgrade all migrated users without confirmation (dangerous!)
python downgrade_migrated_users.py --all --no-confirm
"""
import argparse
import asyncio
import sys
# Add the enterprise directory to the path
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
from server.logger import logger
from sqlalchemy import select, text
from storage.database import session_maker
from storage.user_settings import UserSettings
from storage.user_store import UserStore
def get_migrated_users() -> list[str]:
"""Get list of keycloak_user_ids for users who have been migrated.
This includes:
1. Users with already_migrated=True in user_settings (migrated users)
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
"""
with session_maker() as session:
# Get users from user_settings with already_migrated=True
migrated_result = session.execute(
select(UserSettings.keycloak_user_id).where(
UserSettings.already_migrated.is_(True)
)
)
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
# Get users from the 'user' table (new sign-ups won't have user_settings)
# These are users who signed up after the migration was deployed
new_signup_result = session.execute(
text("""
SELECT CAST(u.id AS VARCHAR)
FROM "user" u
WHERE NOT EXISTS (
SELECT 1 FROM user_settings us
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
)
""")
)
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
# Combine both sets
all_users = migrated_users | new_signups
return list(all_users)
async def downgrade_user(user_id: str) -> bool:
"""Downgrade a single user.
Args:
user_id: The keycloak_user_id to downgrade
Returns:
True if successful, False otherwise
"""
try:
result = await UserStore.downgrade_user(user_id)
if result:
print(f'✓ Successfully downgraded user: {user_id}')
return True
else:
print(f'✗ Failed to downgrade user: {user_id}')
return False
except Exception as e:
print(f'✗ Error downgrading user {user_id}: {e}')
logger.exception(
'downgrade_script:error',
extra={'user_id': user_id, 'error': str(e)},
)
return False
async def main():
parser = argparse.ArgumentParser(
description='Downgrade migrated users back to pre-migration state'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Just list users that would be downgraded, without making changes',
)
parser.add_argument(
'--user-id',
type=str,
help='Downgrade a specific user by keycloak_user_id',
)
parser.add_argument(
'--all',
action='store_true',
help='Downgrade all migrated users',
)
parser.add_argument(
'--no-confirm',
action='store_true',
help='Skip confirmation prompt (use with caution!)',
)
args = parser.parse_args()
# Get list of migrated users
migrated_users = get_migrated_users()
print(f'\nFound {len(migrated_users)} migrated user(s).')
if args.dry_run:
print('\n--- DRY RUN MODE ---')
print('The following users would be downgraded:')
for user_id in migrated_users:
print(f' - {user_id}')
print('\nNo changes were made.')
return
if args.user_id:
# Downgrade a specific user
if args.user_id not in migrated_users:
print(f'\nUser {args.user_id} is not in the migrated users list.')
print('Either the user was not migrated, or the user_id is incorrect.')
return
print(f'\nDowngrading user: {args.user_id}')
if not args.no_confirm:
confirm = input('Are you sure? (yes/no): ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
success = await downgrade_user(args.user_id)
if success:
print('\nDowngrade completed successfully.')
else:
print('\nDowngrade failed. Check logs for details.')
sys.exit(1)
elif args.all:
# Downgrade all migrated users
if not migrated_users:
print('\nNo migrated users to downgrade.')
return
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
if not args.no_confirm:
print('\nThis will:')
print(' - Revert LiteLLM team/user budget settings')
print(' - Delete organization entries')
print(' - Delete user entries in the new schema')
print(' - Reset the already_migrated flag')
print('\nUsers to downgrade:')
for user_id in migrated_users[:10]: # Show first 10
print(f' - {user_id}')
if len(migrated_users) > 10:
print(f' ... and {len(migrated_users) - 10} more')
confirm = input('\nType "yes" to proceed: ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
print('\nStarting downgrade...\n')
success_count = 0
fail_count = 0
for user_id in migrated_users:
success = await downgrade_user(user_id)
if success:
success_count += 1
else:
fail_count += 1
print('\n--- Summary ---')
print(f'Successful: {success_count}')
print(f'Failed: {fail_count}')
if fail_count > 0:
sys.exit(1)
else:
parser.print_help()
print('\nPlease specify --dry-run, --user-id, or --all')
if __name__ == '__main__':
asyncio.run(main())

View File

@@ -27,7 +27,6 @@ from server.rate_limit import setup_rate_limit_handler # noqa: E402
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
from server.routes.auth import api_router, oauth_router # noqa: E402
from server.routes.billing import billing_router # noqa: E402
from server.routes.debugging import add_debugging_routes # noqa: E402
from server.routes.email import api_router as email_router # noqa: E402
from server.routes.event_webhook import event_webhook_router # noqa: E402
from server.routes.feedback import router as feedback_router # noqa: E402
@@ -124,9 +123,6 @@ override_llm_models_dependency(base_app)
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app
) # Add diagnostic routes for testing and debugging (disabled in production)
base_app.include_router(slack_router)
if ENABLE_JIRA:
base_app.include_router(jira_integration_router)

View File

@@ -7,7 +7,8 @@ from uuid import uuid4
import socketio
from server.logger import logger
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.core.config import LLMConfig
@@ -523,15 +524,14 @@ class ClusteredConversationManager(StandaloneConversationManager):
f'local_connection_to_stopped_conversation:{connection_id}:{conversation_id}'
)
# Look up the user_id from the database
with session_maker() as session:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id
== conversation_id
)
.first()
)
conversation_metadata_saas = result.scalars().first()
user_id = (
str(conversation_metadata_saas.user_id)
if conversation_metadata_saas

View File

@@ -1,163 +0,0 @@
import asyncio
import os
import time
from threading import Thread
from fastapi import APIRouter, FastAPI
from sqlalchemy import func, select
from storage.database import a_session_maker, get_engine, session_maker
from storage.user import User
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import wait_all
# Safety flag to prevent chaos routes from being added in production environments
# Only enables these routes in non-production environments
ADD_DEBUGGING_ROUTES = os.environ.get('ADD_DEBUGGING_ROUTES') in ('1', 'true')
def add_debugging_routes(api: FastAPI):
"""
# HERE BE DRAGONS!
Chaos scripts for debugging and stress testing the system.
This module contains endpoints that deliberately stress test and potentially break
the system to help identify weaknesses and bottlenecks. It includes a safety check
to ensure these routes are never deployed to production environments.
The routes in this module are specifically designed for:
- Testing connection pool behavior under load
- Simulating database connection exhaustion
- Testing async vs sync database access patterns
- Simulating event loop blocking
"""
if not ADD_DEBUGGING_ROUTES:
return
chaos_router = APIRouter(prefix='/debugging')
@chaos_router.get('/pool-stats')
def pool_stats() -> dict[str, int]:
"""
Returns current database connection pool statistics.
This endpoint provides real-time metrics about the SQLAlchemy connection pool:
- checked_in: Number of connections currently available in the pool
- checked_out: Number of connections currently in use
- overflow: Number of overflow connections created beyond pool_size
"""
engine = get_engine()
return {
'checked_in': engine.pool.checkedin(),
'checked_out': engine.pool.checkedout(),
'overflow': engine.pool.overflow(),
}
@chaos_router.get('/test-db')
def test_db(num_tests: int = 10, delay: int = 1) -> str:
"""
Stress tests the database connection pool using multiple threads.
Creates multiple threads that each open a database connection, perform a query,
hold the connection for the specified delay, and then release it.
Parameters:
num_tests: Number of concurrent database connections to create
delay: Number of seconds each connection is held open
This test helps identify connection pool exhaustion issues and connection
leaks under concurrent load.
"""
threads = [Thread(target=_db_check, args=(delay,)) for _ in range(num_tests)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
return 'success'
@chaos_router.get('/a-test-db')
async def a_chaos_monkey(num_tests: int = 10, delay: int = 1) -> str:
"""
Stress tests the async database connection pool.
Similar to /test-db but uses async connections and coroutines instead of threads.
This endpoint helps compare the behavior of async vs sync connection pools
under similar load conditions.
Parameters:
num_tests: Number of concurrent async database connections to create
delay: Number of seconds each connection is held open
"""
await wait_all((_a_db_check(delay) for _ in range(num_tests)))
return 'success'
@chaos_router.get('/lock-main-runloop')
async def lock_main_runloop(duration: int = 10) -> str:
"""
Deliberately blocks the main asyncio event loop.
This endpoint uses a synchronous sleep operation in an async function,
which blocks the entire FastAPI server's event loop for the specified duration.
This simulates what happens when CPU-intensive operations or blocking I/O
operations are incorrectly used in async code.
Parameters:
duration: Number of seconds to block the event loop
WARNING: This will make the entire server unresponsive for the duration!
"""
time.sleep(duration)
return 'success'
api.include_router(chaos_router) # Add routes for readiness checks
def _db_check(delay: int):
"""
Executes a single request against the database with an artificial delay.
This helper function:
1. Opens a database connection from the pool
2. Executes a simple query to count users
3. Holds the connection for the specified delay
4. Logs connection pool statistics
5. Implicitly returns the connection to the pool when the session closes
Args:
delay: Number of seconds to hold the database connection
"""
with session_maker() as session:
num_users = session.query(User).count()
time.sleep(delay)
engine = get_engine()
logger.info(
'check',
extra={
'num_users': num_users,
'checked_in': engine.pool.checkedin(),
'checked_out': engine.pool.checkedout(),
'overflow': engine.pool.overflow(),
},
)
async def _a_db_check(delay: int):
"""
Executes a single async request against the database with an artificial delay.
This is the async version of _db_check that:
1. Opens an async database connection from the pool
2. Executes a simple query to count users using SQLAlchemy's async API
3. Holds the connection for the specified delay using asyncio.sleep
4. Logs the results
5. Implicitly returns the connection to the pool when the async session closes
Args:
delay: Number of seconds to hold the database connection
"""
async with a_session_maker() as a_session:
stmt = select(func.count(User.id))
num_users = await a_session.execute(stmt)
await asyncio.sleep(delay)
logger.info(f'a_num_users:{num_users.scalar_one()}')

View File

@@ -31,7 +31,8 @@ from server.logger import logger
from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.signature import SignatureVerifier
from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from sqlalchemy import delete
from storage.database import a_session_maker
from storage.slack_team_store import SlackTeamStore
from storage.slack_user import SlackUser
from storage.user_store import UserStore
@@ -239,15 +240,15 @@ async def keycloak_callback(
slack_display_name=slack_display_name,
)
with session_maker(expire_on_commit=False) as session:
async with a_session_maker(expire_on_commit=False) as session:
# First delete any existing tokens
session.query(SlackUser).filter(
SlackUser.slack_user_id == slack_user_id
).delete()
await session.execute(
delete(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
)
# Store the token
session.add(slack_user)
session.commit()
await session.commit()
message = Message(source=SourceType.SLACK, message=payload)

View File

@@ -19,9 +19,9 @@ from server.utils.conversation_callback_utils import (
process_event,
update_conversation_metadata,
)
from sqlalchemy import orm
from sqlalchemy import select
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.database import a_session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
@@ -59,7 +59,6 @@ from openhands.storage.locations import (
get_conversation_event_filename,
get_conversation_events_dir,
)
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
from openhands.utils.import_utils import get_impl
from openhands.utils.shutdown_listener import should_continue
@@ -166,8 +165,8 @@ class SaasNestedConversationManager(ConversationManager):
}
if user_id:
user_conversation_ids = await call_sync_from_async(
self._get_recent_conversation_ids_for_user, user_id
user_conversation_ids = await self._get_recent_conversation_ids_for_user(
user_id
)
conversation_ids = conversation_ids.intersection(user_conversation_ids)
@@ -643,19 +642,18 @@ class SaasNestedConversationManager(ConversationManager):
},
)
def _get_user_id_from_conversation(self, conversation_id: str) -> str:
async def _get_user_id_from_conversation(self, conversation_id: str) -> str:
"""
Get user_id from conversation_id.
"""
with session_maker() as session:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == conversation_id
)
.first()
)
conversation_metadata_saas = result.scalars().first()
if not conversation_metadata_saas:
raise ValueError(f'No conversation found {conversation_id}')
@@ -753,8 +751,8 @@ class SaasNestedConversationManager(ConversationManager):
user_id_for_convo = user_id
if not user_id_for_convo:
try:
user_id_for_convo = await call_sync_from_async(
self._get_user_id_from_conversation, conversation_id
user_id_for_convo = await self._get_user_id_from_conversation(
conversation_id
)
except Exception:
continue
@@ -995,23 +993,23 @@ class SaasNestedConversationManager(ConversationManager):
}
return conversation_ids
def _get_recent_conversation_ids_for_user(self, user_id: str) -> set[str]:
with session_maker() as session:
async def _get_recent_conversation_ids_for_user(self, user_id: str) -> set[str]:
async with a_session_maker() as session:
# Only include conversations updated in the past week
one_week_ago = datetime.now(UTC) - timedelta(days=7)
query = (
session.query(StoredConversationMetadata.conversation_id)
result = await session.execute(
select(StoredConversationMetadata.conversation_id)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(
.where(
StoredConversationMetadataSaas.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
)
)
user_conversation_ids = set(query)
user_conversation_ids = set(result.scalars().all())
return user_conversation_ids
async def _get_runtime(self, sid: str) -> dict | None:
@@ -1055,14 +1053,13 @@ class SaasNestedConversationManager(ConversationManager):
await asyncio.sleep(_POLLING_INTERVAL)
agent_loop_infos = await self.get_agent_loop_info()
with session_maker() as session:
for agent_loop_info in agent_loop_infos:
if agent_loop_info.status != ConversationStatus.RUNNING:
continue
try:
await self._poll_agent_loop_events(agent_loop_info, session)
except Exception as e:
logger.exception(f'error_polling_events:{str(e)}')
for agent_loop_info in agent_loop_infos:
if agent_loop_info.status != ConversationStatus.RUNNING:
continue
try:
await self._poll_agent_loop_events(agent_loop_info)
except Exception as e:
logger.exception(f'error_polling_events:{str(e)}')
except Exception as e:
try:
asyncio.get_running_loop()
@@ -1071,23 +1068,27 @@ class SaasNestedConversationManager(ConversationManager):
# Loop has been shut down, exit gracefully
return
async def _poll_agent_loop_events(
self, agent_loop_info: AgentLoopInfo, session: orm.Session
):
async def _poll_agent_loop_events(self, agent_loop_info: AgentLoopInfo):
"""This method is typically only run in localhost, where the webhook callbacks from the remote runtime are unavailable"""
if agent_loop_info.status != ConversationStatus.RUNNING:
return
conversation_id = agent_loop_info.conversation_id
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.first()
)
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_id == conversation_id
)
)
conversation_metadata = result.scalars().first()
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == conversation_id
)
)
conversation_metadata_saas = result.scalars().first()
if conversation_metadata is None or conversation_metadata_saas is None:
# Conversation is running in different server
return

View File

@@ -1,4 +1,5 @@
import asyncio
import contextlib
import json
import time
from dataclasses import dataclass
@@ -444,11 +445,19 @@ async def test_disconnect_from_stopped_with_stopped_remote():
# Create a mock SIO with scan results for only remote_session1
sio = get_mock_sio(scan_keys=[b'ohcnv:user1:remote_session1'])
# Mock the database connection to avoid actual database connections
db_mock = MagicMock()
db_session_mock = MagicMock()
db_mock.__enter__.return_value = db_session_mock
session_maker_mock = MagicMock(return_value=db_mock)
# Mock the async database session
mock_user = MagicMock()
mock_user.user_id = 'user1'
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_user
mock_session = AsyncMock()
mock_session.execute = AsyncMock(return_value=mock_result)
@contextlib.asynccontextmanager
async def mock_a_session_maker():
yield mock_session
with (
patch(
@@ -456,8 +465,8 @@ async def test_disconnect_from_stopped_with_stopped_remote():
AsyncMock(),
),
patch(
'server.clustered_conversation_manager.session_maker',
session_maker_mock,
'server.clustered_conversation_manager.a_session_maker',
mock_a_session_maker,
),
patch('asyncio.create_task', MagicMock()),
):
@@ -484,11 +493,6 @@ async def test_disconnect_from_stopped_with_stopped_remote():
MagicMock()
)
# Create a mock for the database query result
mock_user = MagicMock()
mock_user.user_id = 'user1'
db_session_mock.query.return_value.filter.return_value.first.return_value = mock_user
# Mock the _handle_remote_conversation_stopped method with the correct signature
conversation_manager._handle_remote_conversation_stopped = AsyncMock()