mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Refactor openhands_pr_store.py to use async db sessions (#13186)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -569,7 +569,7 @@ class GitHubDataCollector:
|
||||
openhands_helped_author = openhands_commit_count > 0
|
||||
|
||||
# Update the PR with OpenHands statistics
|
||||
update_success = store.update_pr_openhands_stats(
|
||||
update_success = await store.update_pr_openhands_stats(
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
original_updated_at=openhands_pr.updated_at,
|
||||
@@ -612,7 +612,7 @@ class GitHubDataCollector:
|
||||
action = payload.get('action', '')
|
||||
return action == 'closed' and 'pull_request' in payload
|
||||
|
||||
def _track_closed_or_merged_pr(self, payload):
|
||||
async def _track_closed_or_merged_pr(self, payload):
|
||||
"""
|
||||
Track PR closed/merged event
|
||||
"""
|
||||
@@ -671,17 +671,17 @@ class GitHubDataCollector:
|
||||
num_general_comments=num_general_comments,
|
||||
)
|
||||
|
||||
store.insert_pr(pr)
|
||||
await store.insert_pr(pr)
|
||||
logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')
|
||||
|
||||
def process_payload(self, message: Message):
|
||||
async def process_payload(self, message: Message):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
if self._is_pr_closed_or_merged(raw_payload):
|
||||
self._track_closed_or_merged_pr(raw_payload)
|
||||
await self._track_closed_or_merged_pr(raw_payload)
|
||||
|
||||
async def save_data(self, github_view: ResolverViewInterface):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
|
||||
@@ -42,7 +42,6 @@ from openhands.server.types import (
|
||||
SessionExpiredError,
|
||||
)
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class GithubManager(Manager[GithubViewType]):
|
||||
@@ -242,7 +241,7 @@ class GithubManager(Manager[GithubViewType]):
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
try:
|
||||
await call_sync_from_async(self.data_collector.process_payload, message)
|
||||
await self.data_collector.process_payload(message)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'[Github]: Error processing payload for gh interaction', exc_info=True
|
||||
|
||||
@@ -1,44 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import and_, desc, select
|
||||
from storage.database import a_session_maker
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenhandsPRStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def insert_pr(self, pr: OpenhandsPR) -> None:
|
||||
async def insert_pr(self, pr: OpenhandsPR) -> None:
|
||||
"""
|
||||
Insert a new PR or delete and recreate if repo_id and pr_number already exist.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Check if PR already exists
|
||||
existing_pr = (
|
||||
session.query(OpenhandsPR)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(OpenhandsPR).filter(
|
||||
OpenhandsPR.repo_id == pr.repo_id,
|
||||
OpenhandsPR.pr_number == pr.pr_number,
|
||||
OpenhandsPR.provider == pr.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
existing_pr = result.scalars().first()
|
||||
|
||||
if existing_pr:
|
||||
# Delete existing PR
|
||||
session.delete(existing_pr)
|
||||
session.flush()
|
||||
await session.delete(existing_pr)
|
||||
await session.flush()
|
||||
|
||||
session.add(pr)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
|
||||
async def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
|
||||
"""
|
||||
Increment the process attempts counter for a PR.
|
||||
|
||||
@@ -49,23 +45,22 @@ class OpenhandsPRStore:
|
||||
Returns:
|
||||
True if PR was found and updated, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
pr = (
|
||||
session.query(OpenhandsPR)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OpenhandsPR).filter(
|
||||
OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number
|
||||
)
|
||||
.first()
|
||||
)
|
||||
pr = result.scalars().first()
|
||||
|
||||
if pr:
|
||||
pr.process_attempts += 1
|
||||
session.merge(pr)
|
||||
session.commit()
|
||||
await session.merge(pr)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_pr_openhands_stats(
|
||||
async def update_pr_openhands_stats(
|
||||
self,
|
||||
repo_id: str,
|
||||
pr_number: int,
|
||||
@@ -90,16 +85,16 @@ class OpenhandsPRStore:
|
||||
Returns:
|
||||
True if PR was found and updated, False if not found or timestamp changed
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Use row-level locking to prevent concurrent modifications
|
||||
pr: OpenhandsPR | None = (
|
||||
session.query(OpenhandsPR)
|
||||
result = await session.execute(
|
||||
select(OpenhandsPR)
|
||||
.filter(
|
||||
OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
pr: OpenhandsPR | None = result.scalars().first()
|
||||
|
||||
if not pr:
|
||||
# Current PR snapshot is stale
|
||||
@@ -109,7 +104,7 @@ class OpenhandsPRStore:
|
||||
# Check if the updated_at timestamp has changed (indicating concurrent modification)
|
||||
if pr.updated_at != original_updated_at:
|
||||
# Abort transaction - the PR was modified by another process
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
return False
|
||||
|
||||
# Update the OpenHands statistics
|
||||
@@ -119,11 +114,11 @@ class OpenhandsPRStore:
|
||||
pr.num_openhands_general_comments = num_openhands_general_comments
|
||||
pr.processed = True
|
||||
|
||||
session.merge(pr)
|
||||
session.commit()
|
||||
await session.merge(pr)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
def get_unprocessed_prs(
|
||||
async def get_unprocessed_prs(
|
||||
self, limit: int = 50, max_retries: int = 3
|
||||
) -> list[OpenhandsPR]:
|
||||
"""
|
||||
@@ -135,9 +130,9 @@ class OpenhandsPRStore:
|
||||
Returns:
|
||||
List of OpenhandsPR objects that need processing
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
unprocessed_prs = (
|
||||
session.query(OpenhandsPR)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OpenhandsPR)
|
||||
.filter(
|
||||
and_(
|
||||
~OpenhandsPR.processed,
|
||||
@@ -147,12 +142,12 @@ class OpenhandsPRStore:
|
||||
)
|
||||
.order_by(desc(OpenhandsPR.updated_at))
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
unprocessed_prs = list(result.scalars().all())
|
||||
|
||||
return unprocessed_prs
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
def get_instance(cls) -> OpenhandsPRStore:
|
||||
"""Get an instance of the OpenhandsPRStore."""
|
||||
return OpenhandsPRStore(session_maker)
|
||||
return OpenhandsPRStore()
|
||||
|
||||
@@ -13,7 +13,7 @@ store = OpenhandsPRStore.get_instance()
|
||||
data_collector = GitHubDataCollector()
|
||||
|
||||
|
||||
def get_unprocessed_prs() -> list[OpenhandsPR]:
|
||||
async def get_unprocessed_prs() -> list[OpenhandsPR]:
|
||||
"""
|
||||
Get unprocessed PR entries from the OpenhandsPR table.
|
||||
|
||||
@@ -23,7 +23,7 @@ def get_unprocessed_prs() -> list[OpenhandsPR]:
|
||||
Returns:
|
||||
List of OpenhandsPR objects that need processing
|
||||
"""
|
||||
unprocessed_prs = store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES)
|
||||
unprocessed_prs = await store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES)
|
||||
logger.info(f'Retrieved {len(unprocessed_prs)} unprocessed PRs for enrichment')
|
||||
return unprocessed_prs
|
||||
|
||||
@@ -35,7 +35,7 @@ async def process_pr(pr: OpenhandsPR):
|
||||
|
||||
logger.info(f'Processing PR #{pr.pr_number} from repo {pr.repo_name}')
|
||||
await data_collector.save_full_pr(pr)
|
||||
store.increment_process_attempts(pr.repo_id, pr.pr_number)
|
||||
await store.increment_process_attempts(pr.repo_id, pr.pr_number)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -45,7 +45,7 @@ async def main():
|
||||
logger.info('Starting PR data enrichment process')
|
||||
|
||||
# Get unprocessed PRs
|
||||
unprocessed_prs = get_unprocessed_prs()
|
||||
unprocessed_prs = await get_unprocessed_prs()
|
||||
logger.info(f'Found {len(unprocessed_prs)} PRs to process')
|
||||
|
||||
# Process each PR
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestGithubManagerUserNotFound:
|
||||
def mock_data_collector(self):
|
||||
"""Create a mock data collector."""
|
||||
data_collector = MagicMock()
|
||||
data_collector.process_payload = MagicMock()
|
||||
data_collector.process_payload = AsyncMock()
|
||||
return data_collector
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user