mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
167 lines
6.4 KiB
Python
167 lines
6.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Callable
|
|
|
|
from integrations.github.github_types import (
|
|
WorkflowRun,
|
|
WorkflowRunGroup,
|
|
WorkflowRunStatus,
|
|
)
|
|
from sqlalchemy import and_, delete, select, update
|
|
from sqlalchemy.orm import sessionmaker
|
|
from storage.database import a_session_maker
|
|
from storage.proactive_convos import ProactiveConversation
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.integrations.service_types import ProviderType
|
|
|
|
|
|
@dataclass
class ProactiveConversationStore:
    """Persist per-commit workflow-run state for proactive PR conversations.

    Each record is keyed by a provider-qualified repo id, a PR number and a
    commit. It stores the workflow runs seen so far for that commit and
    whether the conversation starter has already been sent.
    """

    # Factory used as ``async with self.a_session_maker() as session``.
    # NOTE(review): it is used as an async session factory, so the
    # ``sessionmaker`` annotation may be better expressed as
    # ``async_sessionmaker`` — confirm against storage.database.
    a_session_maker: sessionmaker = a_session_maker

    def get_repo_id(self, provider: ProviderType, repo_id) -> str:
        """Return the provider-qualified repo key ``'<provider.value>##<repo_id>'``."""
        return f'{provider.value}##{repo_id}'

    async def store_workflow_information(
        self,
        provider: ProviderType,
        repo_id: str,
        incoming_commit: str,
        workflow: WorkflowRun,
        pr_number: int,
        get_all_workflows: Callable,
    ) -> WorkflowRunGroup | None:
        """Record an incoming workflow run and decide whether to start a conversation.

        Steps:
            1. Load the record for (repo_id, pr_number, commit) with a row lock.
            2. If no record exists, fetch all workflow runs via
               ``get_all_workflows`` and create a new record.
            3. Overlay the incoming workflow run onto the stored runs.
            4. If no run is PENDING and at least one is FAILURE, mark the
               record as sent and return the group; otherwise return None.

        The read-modify-write happens inside one explicit transaction using
        ``SELECT ... FOR UPDATE``. Concurrent callers for an *existing* record
        therefore serialize on its row lock.

        Args:
            provider: Provider used to qualify ``repo_id`` (see ``get_repo_id``).
            repo_id: Provider-local repository id.
            incoming_commit: Commit the workflow run belongs to.
            workflow: The incoming run. It replaces any stored run with the
                same ``id``.
            pr_number: Pull request number.
            get_all_workflows: Zero-argument callable, invoked synchronously
                only when no record exists yet. It must return either a
                ``WorkflowRunGroup`` or a mapping of ``WorkflowRunGroup``
                keyword arguments.

        Returns:
            The updated ``WorkflowRunGroup`` when the conversation starter
            should be sent now. Otherwise ``None``, including when it was
            already sent for this commit.
        """
        provider_repo_id = self.get_repo_id(provider, repo_id)

        async with self.a_session_maker() as session:
            # Explicit transaction: the row lock taken below is held until
            # this block exits (commit on success, rollback on error).
            async with session.begin():
                # NOTE(review): when no row exists yet there is nothing to
                # lock. Two concurrent first events for the same commit can
                # therefore both take the insert path below. Whether that is
                # rejected depends on constraints on ProactiveConversation
                # that are not visible here — confirm.
                stmt = (
                    select(ProactiveConversation)
                    .where(
                        and_(
                            ProactiveConversation.repo_id == provider_repo_id,
                            ProactiveConversation.pr_number == pr_number,
                            ProactiveConversation.commit == incoming_commit,
                        )
                    )
                    .with_for_update()  # row-level lock for this transaction
                )
                result = await session.execute(stmt)
                commit_entry = result.scalars().first()

                # Interaction is complete for this commit; do not duplicate it.
                if commit_entry and commit_entry.conversation_starter_sent:
                    return None

                # Seed from the provider on first sight of this commit;
                # otherwise continue from the persisted state.
                workflow_runs = (
                    commit_entry.workflow_runs
                    if commit_entry
                    else get_all_workflows()
                )
                workflow_run_group = (
                    workflow_runs
                    if isinstance(workflow_runs, WorkflowRunGroup)
                    else WorkflowRunGroup(**workflow_runs)
                )

                # Overlay the incoming run so its status is the latest one.
                workflow_run_group.runs[workflow.id] = workflow

                # Use a distinct loop name so the ``workflow`` parameter is
                # not shadowed.
                statuses = [run.status for run in workflow_run_group.runs.values()]
                is_none_pending = all(
                    status != WorkflowRunStatus.PENDING for status in statuses
                )
                # Send only once every run has left PENDING and one failed.
                should_send = is_none_pending and (
                    WorkflowRunStatus.FAILURE in statuses
                )

                serialized_runs = workflow_run_group.model_dump()
                if commit_entry:
                    # Update the existing record: either new workflow status
                    # information, or marking the starter as sent.
                    await session.execute(
                        update(ProactiveConversation)
                        .where(
                            ProactiveConversation.repo_id == provider_repo_id,
                            ProactiveConversation.pr_number == pr_number,
                            ProactiveConversation.commit == incoming_commit,
                        )
                        .values(
                            workflow_runs=serialized_runs,
                            conversation_starter_sent=should_send,
                        )
                    )
                else:
                    session.add(
                        ProactiveConversation(
                            repo_id=provider_repo_id,
                            pr_number=pr_number,
                            commit=incoming_commit,
                            workflow_runs=serialized_runs,
                            conversation_starter_sent=should_send,
                        )
                    )

        # Reached only after the transaction above has committed.
        return workflow_run_group if should_send else None

    async def clean_old_convos(self, older_than_minutes: float = 30) -> None:
        """Delete proactive conversation records older than the given age.

        Args:
            older_than_minutes: Records whose ``last_updated_at`` is earlier
                than ``datetime.now(UTC)`` minus this many minutes are deleted.
                Defaults to 30 minutes.
        """
        # NOTE(review): the cutoff is timezone-aware UTC; this assumes
        # last_updated_at is stored comparably (aware/UTC) — confirm.
        cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)

        async with self.a_session_maker() as session:
            async with session.begin():
                delete_stmt = delete(ProactiveConversation).where(
                    ProactiveConversation.last_updated_at < cutoff_time
                )
                result = await session.execute(delete_stmt)
                deleted_count = result.rowcount

        # Logged after the delete has committed.
        logger.info(
            f'Deleted {deleted_count} proactive conversation records older than {older_than_minutes} minutes'
        )

    @classmethod
    async def get_instance(cls) -> ProactiveConversationStore:
        """Get an instance of the ProactiveConversationStore.

        Returns:
            An instance of ``cls`` that uses the shared ``a_session_maker``.
        """
        return cls(a_session_maker)
|