Refactor openhands_pr_store.py to use async db sessions (#13186)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-03 16:38:41 -07:00
committed by GitHub
parent a1271dc129
commit 4fc5351ed7
5 changed files with 45 additions and 51 deletions

View File

@@ -569,7 +569,7 @@ class GitHubDataCollector:
openhands_helped_author = openhands_commit_count > 0
# Update the PR with OpenHands statistics
update_success = store.update_pr_openhands_stats(
update_success = await store.update_pr_openhands_stats(
repo_id=repo_id,
pr_number=pr_number,
original_updated_at=openhands_pr.updated_at,
@@ -612,7 +612,7 @@ class GitHubDataCollector:
action = payload.get('action', '')
return action == 'closed' and 'pull_request' in payload
def _track_closed_or_merged_pr(self, payload):
async def _track_closed_or_merged_pr(self, payload):
"""
Track PR closed/merged event
"""
@@ -671,17 +671,17 @@ class GitHubDataCollector:
num_general_comments=num_general_comments,
)
store.insert_pr(pr)
await store.insert_pr(pr)
logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')
def process_payload(self, message: Message):
async def process_payload(self, message: Message):
if not COLLECT_GITHUB_INTERACTIONS:
return
raw_payload = message.message.get('payload', {})
if self._is_pr_closed_or_merged(raw_payload):
self._track_closed_or_merged_pr(raw_payload)
await self._track_closed_or_merged_pr(raw_payload)
async def save_data(self, github_view: ResolverViewInterface):
if not COLLECT_GITHUB_INTERACTIONS:

View File

@@ -42,7 +42,6 @@ from openhands.server.types import (
SessionExpiredError,
)
from openhands.storage.data_models.secrets import Secrets
from openhands.utils.async_utils import call_sync_from_async
class GithubManager(Manager[GithubViewType]):
@@ -242,7 +241,7 @@ class GithubManager(Manager[GithubViewType]):
async def receive_message(self, message: Message):
self._confirm_incoming_source_type(message)
try:
await call_sync_from_async(self.data_collector.process_payload, message)
await self.data_collector.process_payload(message)
except Exception:
logger.warning(
'[Github]: Error processing payload for gh interaction', exc_info=True

View File

@@ -1,44 +1,40 @@
from dataclasses import dataclass
from __future__ import annotations
from datetime import datetime
from sqlalchemy import and_, desc
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import and_, desc, select
from storage.database import a_session_maker
from storage.openhands_pr import OpenhandsPR
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
@dataclass
class OpenhandsPRStore:
session_maker: sessionmaker
def insert_pr(self, pr: OpenhandsPR) -> None:
async def insert_pr(self, pr: OpenhandsPR) -> None:
"""
Insert a new PR or delete and recreate if repo_id and pr_number already exist.
"""
with self.session_maker() as session:
async with a_session_maker() as session:
# Check if PR already exists
existing_pr = (
session.query(OpenhandsPR)
.filter(
result = await session.execute(
select(OpenhandsPR).filter(
OpenhandsPR.repo_id == pr.repo_id,
OpenhandsPR.pr_number == pr.pr_number,
OpenhandsPR.provider == pr.provider,
)
.first()
)
existing_pr = result.scalars().first()
if existing_pr:
# Delete existing PR
session.delete(existing_pr)
session.flush()
await session.delete(existing_pr)
await session.flush()
session.add(pr)
session.commit()
await session.commit()
def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
async def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
"""
Increment the process attempts counter for a PR.
@@ -49,23 +45,22 @@ class OpenhandsPRStore:
Returns:
True if PR was found and updated, False otherwise
"""
with self.session_maker() as session:
pr = (
session.query(OpenhandsPR)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(OpenhandsPR).filter(
OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number
)
.first()
)
pr = result.scalars().first()
if pr:
pr.process_attempts += 1
session.merge(pr)
session.commit()
await session.merge(pr)
await session.commit()
return True
return False
def update_pr_openhands_stats(
async def update_pr_openhands_stats(
self,
repo_id: str,
pr_number: int,
@@ -90,16 +85,16 @@ class OpenhandsPRStore:
Returns:
True if PR was found and updated, False if not found or timestamp changed
"""
with self.session_maker() as session:
async with a_session_maker() as session:
# Use row-level locking to prevent concurrent modifications
pr: OpenhandsPR | None = (
session.query(OpenhandsPR)
result = await session.execute(
select(OpenhandsPR)
.filter(
OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number
)
.with_for_update()
.first()
)
pr: OpenhandsPR | None = result.scalars().first()
if not pr:
# Current PR snapshot is stale
@@ -109,7 +104,7 @@ class OpenhandsPRStore:
# Check if the updated_at timestamp has changed (indicating concurrent modification)
if pr.updated_at != original_updated_at:
# Abort transaction - the PR was modified by another process
session.rollback()
await session.rollback()
return False
# Update the OpenHands statistics
@@ -119,11 +114,11 @@ class OpenhandsPRStore:
pr.num_openhands_general_comments = num_openhands_general_comments
pr.processed = True
session.merge(pr)
session.commit()
await session.merge(pr)
await session.commit()
return True
def get_unprocessed_prs(
async def get_unprocessed_prs(
self, limit: int = 50, max_retries: int = 3
) -> list[OpenhandsPR]:
"""
@@ -135,9 +130,9 @@ class OpenhandsPRStore:
Returns:
List of OpenhandsPR objects that need processing
"""
with self.session_maker() as session:
unprocessed_prs = (
session.query(OpenhandsPR)
async with a_session_maker() as session:
result = await session.execute(
select(OpenhandsPR)
.filter(
and_(
~OpenhandsPR.processed,
@@ -147,12 +142,12 @@ class OpenhandsPRStore:
)
.order_by(desc(OpenhandsPR.updated_at))
.limit(limit)
.all()
)
unprocessed_prs = list(result.scalars().all())
return unprocessed_prs
@classmethod
def get_instance(cls):
def get_instance(cls) -> OpenhandsPRStore:
"""Get an instance of the OpenhandsPRStore."""
return OpenhandsPRStore(session_maker)
return OpenhandsPRStore()

View File

@@ -13,7 +13,7 @@ store = OpenhandsPRStore.get_instance()
data_collector = GitHubDataCollector()
def get_unprocessed_prs() -> list[OpenhandsPR]:
async def get_unprocessed_prs() -> list[OpenhandsPR]:
"""
Get unprocessed PR entries from the OpenhandsPR table.
@@ -23,7 +23,7 @@ def get_unprocessed_prs() -> list[OpenhandsPR]:
Returns:
List of OpenhandsPR objects that need processing
"""
unprocessed_prs = store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES)
unprocessed_prs = await store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES)
logger.info(f'Retrieved {len(unprocessed_prs)} unprocessed PRs for enrichment')
return unprocessed_prs
@@ -35,7 +35,7 @@ async def process_pr(pr: OpenhandsPR):
logger.info(f'Processing PR #{pr.pr_number} from repo {pr.repo_name}')
await data_collector.save_full_pr(pr)
store.increment_process_attempts(pr.repo_id, pr.pr_number)
await store.increment_process_attempts(pr.repo_id, pr.pr_number)
async def main():
@@ -45,7 +45,7 @@ async def main():
logger.info('Starting PR data enrichment process')
# Get unprocessed PRs
unprocessed_prs = get_unprocessed_prs()
unprocessed_prs = await get_unprocessed_prs()
logger.info(f'Found {len(unprocessed_prs)} PRs to process')
# Process each PR

View File

@@ -29,7 +29,7 @@ class TestGithubManagerUserNotFound:
def mock_data_collector(self):
"""Create a mock data collector."""
data_collector = MagicMock()
data_collector.process_payload = MagicMock()
data_collector.process_payload = AsyncMock()
return data_collector
@pytest.fixture