Refactor enterprise code to use async database sessions (Round 3) (#13148)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-03 06:35:19 -07:00
committed by GitHub
parent 4a3a42c858
commit f3026583d7
9 changed files with 167 additions and 508 deletions

View File

@@ -200,7 +200,7 @@ class MetricsCollector(ABC):
"""Base class for metrics collectors."""
@abstractmethod
def collect(self) -> List[MetricResult]:
async def collect(self) -> List[MetricResult]:
"""Collect metrics and return results."""
pass
@@ -264,12 +264,13 @@ class SystemMetricsCollector(MetricsCollector):
def collector_name(self) -> str:
return "system_metrics"
def collect(self) -> List[MetricResult]:
async def collect(self) -> List[MetricResult]:
results = []
# Collect user count
with session_maker() as session:
user_count = session.query(UserSettings).count()
async with a_session_maker() as session:
user_count_result = await session.execute(select(func.count()).select_from(UserSettings))
user_count = user_count_result.scalar()
results.append(MetricResult(
key="total_users",
value=user_count
@@ -277,9 +278,11 @@ class SystemMetricsCollector(MetricsCollector):
# Collect conversation count (last 30 days)
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
conversation_count = session.query(StoredConversationMetadata)\
.filter(StoredConversationMetadata.created_at >= thirty_days_ago)\
.count()
conversation_count_result = await session.execute(
select(func.count()).select_from(StoredConversationMetadata)
.where(StoredConversationMetadata.created_at >= thirty_days_ago)
)
conversation_count = conversation_count_result.scalar()
results.append(MetricResult(
key="conversations_30d",
@@ -303,7 +306,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
"""Collect metrics from all registered collectors."""
# Check if collection is needed
if not self._should_collect():
if not await self._should_collect():
return {"status": "skipped", "reason": "too_recent"}
# Collect metrics from all registered collectors
@@ -313,7 +316,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
for collector in collector_registry.get_all_collectors():
try:
if collector.should_collect():
results = collector.collect()
results = await collector.collect()
for result in results:
all_metrics[result.key] = result.value
collector_results[collector.collector_name] = len(results)
@@ -322,13 +325,13 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
collector_results[collector.collector_name] = f"error: {e}"
# Store metrics in database
with session_maker() as session:
async with a_session_maker() as session:
telemetry_record = TelemetryMetrics(
metrics_data=all_metrics,
collected_at=datetime.now(timezone.utc)
)
session.add(telemetry_record)
session.commit()
await session.commit()
# Note: No need to track last_collection_at separately
# Can be derived from MAX(collected_at) in telemetry_metrics
@@ -339,11 +342,12 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
"collectors_run": collector_results
}
def _should_collect(self) -> bool:
async def _should_collect(self) -> bool:
"""Check if collection is needed based on interval."""
with session_maker() as session:
async with a_session_maker() as session:
# Get last collection time from metrics table
last_collected = session.query(func.max(TelemetryMetrics.collected_at)).scalar()
result = await session.execute(select(func.max(TelemetryMetrics.collected_at)))
last_collected = result.scalar()
if not last_collected:
return True
@@ -366,17 +370,19 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
"""Upload pending metrics to Replicated."""
# Get pending metrics
with session_maker() as session:
pending_metrics = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.uploaded_at.is_(None))\
.order_by(TelemetryMetrics.collected_at)\
.all()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.uploaded_at.is_(None))
.order_by(TelemetryMetrics.collected_at)
)
pending_metrics = result.scalars().all()
if not pending_metrics:
return {"status": "no_pending_metrics"}
# Get admin email - skip if not available
admin_email = self._get_admin_email()
admin_email = await self._get_admin_email()
if not admin_email:
logger.info("Skipping telemetry upload - no admin email available")
return {
@@ -413,13 +419,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
await instance.set_status(InstanceStatus.RUNNING)
# Mark as uploaded
with session_maker() as session:
record = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.id == metric_record.id)\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.id == metric_record.id)
)
record = result.scalar_one_or_none()
if record:
record.uploaded_at = datetime.now(timezone.utc)
session.commit()
await session.commit()
uploaded_count += 1
@@ -427,14 +435,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
logger.error(f"Failed to upload metrics {metric_record.id}: {e}")
# Update error info
with session_maker() as session:
record = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.id == metric_record.id)\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.id == metric_record.id)
)
record = result.scalar_one_or_none()
if record:
record.upload_attempts += 1
record.last_upload_error = str(e)
session.commit()
await session.commit()
failed_count += 1
@@ -448,7 +458,7 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
"total_processed": len(pending_metrics)
}
def _get_admin_email(self) -> str | None:
async def _get_admin_email(self) -> str | None:
"""Get administrator email for customer identification."""
# 1. Check environment variable first
env_admin_email = os.getenv('OPENHANDS_ADMIN_EMAIL')
@@ -457,12 +467,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
return env_admin_email
# 2. Use first active user's email (earliest accepted_tos)
with session_maker() as session:
first_user = session.query(UserSettings)\
.filter(UserSettings.email.isnot(None))\
.filter(UserSettings.accepted_tos.isnot(None))\
.order_by(UserSettings.accepted_tos.asc())\
.first()
async with a_session_maker() as session:
result = await session.execute(
select(UserSettings)
.where(UserSettings.email.isnot(None))
.where(UserSettings.accepted_tos.isnot(None))
.order_by(UserSettings.accepted_tos.asc())
.limit(1)
)
first_user = result.scalar_one_or_none()
if first_user and first_user.email:
logger.info(f"Using first active user email: {first_user.email}")
@@ -474,15 +487,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
async def _update_telemetry_identity(self, customer_id: str, instance_id: str) -> None:
"""Update or create telemetry identity record."""
with session_maker() as session:
identity = session.query(TelemetryIdentity).first()
async with a_session_maker() as session:
result = await session.execute(select(TelemetryIdentity).limit(1))
identity = result.scalar_one_or_none()
if not identity:
identity = TelemetryIdentity()
session.add(identity)
identity.customer_id = customer_id
identity.instance_id = instance_id
session.commit()
await session.commit()
```
### 4.4 License Warning System
@@ -503,11 +517,13 @@ async def get_license_status():
if not _is_openhands_enterprise():
return {"warn": False, "message": ""}
with session_maker() as session:
async with a_session_maker() as session:
# Get last successful upload time from metrics table
last_upload = session.query(func.max(TelemetryMetrics.uploaded_at))\
.filter(TelemetryMetrics.uploaded_at.isnot(None))\
.scalar()
result = await session.execute(
select(func.max(TelemetryMetrics.uploaded_at))
.where(TelemetryMetrics.uploaded_at.isnot(None))
)
last_upload = result.scalar()
if not last_upload:
# No successful uploads yet - show warning after 4 days
@@ -521,10 +537,13 @@ async def get_license_status():
if days_since_upload > 4:
# Find oldest unsent batch
oldest_unsent = session.query(TelemetryMetrics)\
.filter(TelemetryMetrics.uploaded_at.is_(None))\
.order_by(TelemetryMetrics.collected_at)\
.first()
result = await session.execute(
select(TelemetryMetrics)
.where(TelemetryMetrics.uploaded_at.is_(None))
.order_by(TelemetryMetrics.collected_at)
.limit(1)
)
oldest_unsent = result.scalar_one_or_none()
if oldest_unsent:
# Calculate expiration date (oldest unsent + 34 days)
@@ -630,19 +649,23 @@ spec:
- python
- -c
- |
import asyncio
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from enterprise.storage.database import session_maker
from enterprise.storage.database import a_session_maker
from enterprise.server.telemetry.collection_processor import TelemetryCollectionProcessor
# Create collection task
processor = TelemetryCollectionProcessor()
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
async def main():
# Create collection task
processor = TelemetryCollectionProcessor()
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
with session_maker() as session:
session.add(task)
session.commit()
async with a_session_maker() as session:
session.add(task)
await session.commit()
asyncio.run(main())
restartPolicy: OnFailure
```
@@ -680,23 +703,27 @@ spec:
- python
- -c
- |
import asyncio
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from enterprise.storage.database import session_maker
from enterprise.storage.database import a_session_maker
from enterprise.server.telemetry.upload_processor import TelemetryUploadProcessor
import os
# Create upload task
processor = TelemetryUploadProcessor(
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
)
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
async def main():
# Create upload task
processor = TelemetryUploadProcessor(
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
)
task = MaintenanceTask()
task.set_processor(processor)
task.status = MaintenanceTaskStatus.PENDING
with session_maker() as session:
session.add(task)
session.commit()
async with a_session_maker() as session:
session.add(task)
await session.commit()
asyncio.run(main())
restartPolicy: OnFailure
```