OpenHands/enterprise/storage/proactive_conversation_store.py
2025-09-04 15:44:54 -04:00

167 lines
6.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Callable
from integrations.github.github_types import (
WorkflowRun,
WorkflowRunGroup,
WorkflowRunStatus,
)
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
from storage.proactive_convos import ProactiveConversation
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
@dataclass
class ProactiveConversationStore:
    """Persistence layer for proactive-conversation state.

    Rows are keyed by (provider-qualified repo id, PR number, commit).
    Each row tracks the statuses of that commit's workflow runs so a
    conversation starter is sent at most once: only after every run has
    completed and at least one of them failed.
    """

    # Async session factory; defaults to the app-wide maker.
    a_session_maker: sessionmaker = a_session_maker

    def get_repo_id(self, provider: ProviderType, repo_id) -> str:
        """Build the provider-qualified repo key, e.g. ``github##12345``."""
        return f'{provider.value}##{repo_id}'

    async def store_workflow_information(
        self,
        provider: ProviderType,
        repo_id: str,
        incoming_commit: str,
        workflow: WorkflowRun,
        pr_number: int,
        get_all_workflows: Callable,
    ) -> WorkflowRunGroup | None:
        """Record an incoming workflow run and decide whether to notify.

        1. Get the workflow based on repo_id, pr_number, commit
        2. If the field doesn't exist
           - Fetch the workflow statuses and store them
           - Create a new record
        3. Check the incoming workflow run payload, and update statuses
           based on its fields
        4. If all statuses are completed with at least one failure, return
           WorkflowGroup information else None

        This method uses an explicit transaction with row-level locking to
        ensure thread safety when multiple processes access the same
        database rows.

        Args:
            provider: Git provider the event came from.
            repo_id: Provider-native repository identifier.
            incoming_commit: Commit SHA the workflow run belongs to.
            workflow: The workflow run from the incoming webhook payload.
            pr_number: Pull request number the commit belongs to.
            get_all_workflows: Zero-arg callable returning the current
                workflow runs for the commit; only invoked on first sight
                of this (repo, PR, commit) triple.

        Returns:
            The merged WorkflowRunGroup when a conversation starter should
            be sent, otherwise None.
        """
        should_send = False
        provider_repo_id = self.get_repo_id(provider, repo_id)
        final_workflow_group = None
        async with self.a_session_maker() as session:
            # Start an explicit transaction with row-level locking
            async with session.begin():
                # SELECT ... FOR UPDATE gives this transaction exclusive
                # access to the matching row until commit/rollback.
                stmt = (
                    select(ProactiveConversation)
                    .where(
                        and_(
                            ProactiveConversation.repo_id == provider_repo_id,
                            ProactiveConversation.pr_number == pr_number,
                            ProactiveConversation.commit == incoming_commit,
                        )
                    )
                    .with_for_update()  # This adds the row-level lock
                )
                result = await session.execute(stmt)
                commit_entry = result.scalars().first()

                # Interaction is complete, do not duplicate event
                if commit_entry and commit_entry.conversation_starter_sent:
                    return None

                # First sighting: fetch current statuses from the provider;
                # otherwise reuse what we stored previously.
                workflow_runs = (
                    get_all_workflows()
                    if not commit_entry
                    else commit_entry.workflow_runs
                )
                # Stored rows hold a plain dict (model_dump output); a fresh
                # fetch may already be a WorkflowRunGroup.
                workflow_run_group = (
                    workflow_runs
                    if isinstance(workflow_runs, WorkflowRunGroup)
                    else WorkflowRunGroup(**workflow_runs)
                )

                # Merge the incoming run (insert or overwrite by run id).
                workflow_run_group.runs[workflow.id] = workflow

                # Iterate values() directly; the previous items() loop
                # discarded keys and rebound `workflow`, shadowing the
                # method parameter inside the comprehension.
                statuses = [run.status for run in workflow_run_group.runs.values()]
                is_none_pending = all(
                    status != WorkflowRunStatus.PENDING for status in statuses
                )
                if is_none_pending:
                    should_send = WorkflowRunStatus.FAILURE in statuses

                if should_send:
                    final_workflow_group = workflow_run_group

                if commit_entry:
                    # Update existing entry (either with workflow status
                    # updates, or marking as comment sent)
                    await session.execute(
                        update(ProactiveConversation)
                        .where(
                            and_(
                                ProactiveConversation.repo_id
                                == provider_repo_id,
                                ProactiveConversation.pr_number == pr_number,
                                ProactiveConversation.commit == incoming_commit,
                            )
                        )
                        .values(
                            workflow_runs=workflow_run_group.model_dump(),
                            conversation_starter_sent=should_send,
                        )
                    )
                else:
                    convo_record = ProactiveConversation(
                        repo_id=provider_repo_id,
                        pr_number=pr_number,
                        commit=incoming_commit,
                        workflow_runs=workflow_run_group.model_dump(),
                        conversation_starter_sent=should_send,
                    )
                    session.add(convo_record)

        return final_workflow_group

    async def clean_old_convos(self, older_than_minutes: int = 30) -> None:
        """Clean up proactive conversation records older than the cutoff.

        Args:
            older_than_minutes: Number of minutes. Records older than this
                will be deleted. Defaults to 30 minutes.
        """
        # Calculate the cutoff time (current time - older_than_minutes)
        cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
        async with self.a_session_maker() as session:
            async with session.begin():
                # Delete records older than the cutoff time
                delete_stmt = delete(ProactiveConversation).where(
                    ProactiveConversation.last_updated_at < cutoff_time
                )
                result = await session.execute(delete_stmt)
                # Log the number of deleted records
                deleted_count = result.rowcount
                logger.info(
                    f'Deleted {deleted_count} proactive conversation records older than {older_than_minutes} minutes'
                )

    @classmethod
    async def get_instance(cls) -> ProactiveConversationStore:
        """Get an instance of the ProactiveConversationStore.

        Returns:
            An instance of ProactiveConversationStore
        """
        # Use cls so subclasses get an instance of themselves; the previous
        # docstring incorrectly referred to GitlabWebhookStore.
        return cls(a_session_maker)