From 4fc5351ed7136fda15eff5a055704a5900171a0a Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 3 Mar 2026 16:38:41 -0700 Subject: [PATCH] Refactor openhands_pr_store.py to use async db sessions (#13186) Co-authored-by: openhands --- .../integrations/github/data_collector.py | 10 +-- .../integrations/github/github_manager.py | 3 +- enterprise/storage/openhands_pr_store.py | 73 +++++++++---------- .../sync/enrich_user_interaction_data.py | 8 +- .../github/test_github_manager.py | 2 +- 5 files changed, 45 insertions(+), 51 deletions(-) diff --git a/enterprise/integrations/github/data_collector.py b/enterprise/integrations/github/data_collector.py index c6a9b4d6ae..d0844b814e 100644 --- a/enterprise/integrations/github/data_collector.py +++ b/enterprise/integrations/github/data_collector.py @@ -569,7 +569,7 @@ class GitHubDataCollector: openhands_helped_author = openhands_commit_count > 0 # Update the PR with OpenHands statistics - update_success = store.update_pr_openhands_stats( + update_success = await store.update_pr_openhands_stats( repo_id=repo_id, pr_number=pr_number, original_updated_at=openhands_pr.updated_at, @@ -612,7 +612,7 @@ class GitHubDataCollector: action = payload.get('action', '') return action == 'closed' and 'pull_request' in payload - def _track_closed_or_merged_pr(self, payload): + async def _track_closed_or_merged_pr(self, payload): """ Track PR closed/merged event """ @@ -671,17 +671,17 @@ class GitHubDataCollector: num_general_comments=num_general_comments, ) - store.insert_pr(pr) + await store.insert_pr(pr) logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}') - def process_payload(self, message: Message): + async def process_payload(self, message: Message): if not COLLECT_GITHUB_INTERACTIONS: return raw_payload = message.message.get('payload', {}) if self._is_pr_closed_or_merged(raw_payload): - self._track_closed_or_merged_pr(raw_payload) + await self._track_closed_or_merged_pr(raw_payload) async def save_data(self, github_view: ResolverViewInterface): if not COLLECT_GITHUB_INTERACTIONS: diff --git a/enterprise/integrations/github/github_manager.py b/enterprise/integrations/github/github_manager.py index 2447b12894..37b03a330d 100644 --- a/enterprise/integrations/github/github_manager.py +++ b/enterprise/integrations/github/github_manager.py @@ -42,7 +42,6 @@ from openhands.server.types import ( SessionExpiredError, ) from openhands.storage.data_models.secrets import Secrets -from openhands.utils.async_utils import call_sync_from_async class GithubManager(Manager[GithubViewType]): @@ -242,7 +241,7 @@ class GithubManager(Manager[GithubViewType]): async def receive_message(self, message: Message): self._confirm_incoming_source_type(message) try: - await call_sync_from_async(self.data_collector.process_payload, message) + await self.data_collector.process_payload(message) except Exception: logger.warning( '[Github]: Error processing payload for gh interaction', exc_info=True diff --git a/enterprise/storage/openhands_pr_store.py b/enterprise/storage/openhands_pr_store.py index 7bc52369f4..2bc3ad661f 100644 --- a/enterprise/storage/openhands_pr_store.py +++ b/enterprise/storage/openhands_pr_store.py @@ -1,44 +1,40 @@ -from dataclasses import dataclass +from __future__ import annotations + from datetime import datetime -from sqlalchemy import and_, desc -from sqlalchemy.orm import sessionmaker -from storage.database import session_maker +from sqlalchemy import and_, desc, select +from storage.database import a_session_maker from storage.openhands_pr import OpenhandsPR from openhands.core.logger import openhands_logger as logger from openhands.integrations.service_types import ProviderType -@dataclass class OpenhandsPRStore: - session_maker: sessionmaker - - def insert_pr(self, pr: OpenhandsPR) -> None: + async def insert_pr(self, pr: OpenhandsPR) -> None: """ Insert a new PR or delete and recreate if repo_id and pr_number already exist. """ - with self.session_maker() as session: + async with a_session_maker() as session: # Check if PR already exists - existing_pr = ( - session.query(OpenhandsPR) - .filter( + result = await session.execute( + select(OpenhandsPR).filter( OpenhandsPR.repo_id == pr.repo_id, OpenhandsPR.pr_number == pr.pr_number, OpenhandsPR.provider == pr.provider, ) - .first() ) + existing_pr = result.scalars().first() if existing_pr: # Delete existing PR - session.delete(existing_pr) - session.flush() + await session.delete(existing_pr) + await session.flush() session.add(pr) - session.commit() + await session.commit() - def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool: + async def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool: """ Increment the process attempts counter for a PR. @@ -49,23 +45,22 @@ class OpenhandsPRStore: Returns: True if PR was found and updated, False otherwise """ - with self.session_maker() as session: - pr = ( - session.query(OpenhandsPR) - .filter( + async with a_session_maker() as session: + result = await session.execute( + select(OpenhandsPR).filter( OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number ) - .first() ) + pr = result.scalars().first() if pr: pr.process_attempts += 1 - session.merge(pr) - session.commit() + await session.merge(pr) + await session.commit() return True return False - def update_pr_openhands_stats( + async def update_pr_openhands_stats( self, repo_id: str, pr_number: int, @@ -90,16 +85,16 @@ class OpenhandsPRStore: Returns: True if PR was found and updated, False if not found or timestamp changed """ - with self.session_maker() as session: + async with a_session_maker() as session: # Use row-level locking to prevent concurrent modifications - pr: OpenhandsPR | None = ( - session.query(OpenhandsPR) + result = await session.execute( + select(OpenhandsPR) .filter( OpenhandsPR.repo_id == repo_id, OpenhandsPR.pr_number == pr_number ) .with_for_update() - .first() ) + pr: OpenhandsPR | None = result.scalars().first() if not pr: # Current PR snapshot is stale @@ -109,7 +104,7 @@ class OpenhandsPRStore: # Check if the updated_at timestamp has changed (indicating concurrent modification) if pr.updated_at != original_updated_at: # Abort transaction - the PR was modified by another process - session.rollback() + await session.rollback() return False # Update the OpenHands statistics @@ -119,11 +114,11 @@ class OpenhandsPRStore: pr.num_openhands_general_comments = num_openhands_general_comments pr.processed = True - session.merge(pr) - session.commit() + await session.merge(pr) + await session.commit() return True - def get_unprocessed_prs( + async def get_unprocessed_prs( self, limit: int = 50, max_retries: int = 3 ) -> list[OpenhandsPR]: """ @@ -135,9 +130,9 @@ class OpenhandsPRStore: Returns: List of OpenhandsPR objects that need processing """ - with self.session_maker() as session: - unprocessed_prs = ( - session.query(OpenhandsPR) + async with a_session_maker() as session: + result = await session.execute( + select(OpenhandsPR) .filter( and_( ~OpenhandsPR.processed, @@ -147,12 +142,12 @@ class OpenhandsPRStore: ) .order_by(desc(OpenhandsPR.updated_at)) .limit(limit) - .all() ) + unprocessed_prs = list(result.scalars().all()) return unprocessed_prs @classmethod - def get_instance(cls): + def get_instance(cls) -> OpenhandsPRStore: """Get an instance of the OpenhandsPRStore.""" - return OpenhandsPRStore(session_maker) + return OpenhandsPRStore() diff --git a/enterprise/sync/enrich_user_interaction_data.py b/enterprise/sync/enrich_user_interaction_data.py index 184c1c40cc..611aeabf85 100644 --- a/enterprise/sync/enrich_user_interaction_data.py +++ b/enterprise/sync/enrich_user_interaction_data.py @@ -13,7 +13,7 @@ store = OpenhandsPRStore.get_instance() data_collector = GitHubDataCollector() -def get_unprocessed_prs() -> list[OpenhandsPR]: +async def get_unprocessed_prs() -> list[OpenhandsPR]: """ Get unprocessed PR entries from the OpenhandsPR table. @@ -23,7 +23,7 @@ def get_unprocessed_prs() -> list[OpenhandsPR]: Returns: List of OpenhandsPR objects that need processing """ - unprocessed_prs = store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES) + unprocessed_prs = await store.get_unprocessed_prs(PROCESS_AMOUNT, MAX_RETRIES) logger.info(f'Retrieved {len(unprocessed_prs)} unprocessed PRs for enrichment') return unprocessed_prs @@ -35,7 +35,7 @@ async def process_pr(pr: OpenhandsPR): logger.info(f'Processing PR #{pr.pr_number} from repo {pr.repo_name}') await data_collector.save_full_pr(pr) - store.increment_process_attempts(pr.repo_id, pr.pr_number) + await store.increment_process_attempts(pr.repo_id, pr.pr_number) async def main(): @@ -45,7 +45,7 @@ async def main(): logger.info('Starting PR data enrichment process') # Get unprocessed PRs - unprocessed_prs = get_unprocessed_prs() + unprocessed_prs = await get_unprocessed_prs() logger.info(f'Found {len(unprocessed_prs)} PRs to process') # Process each PR diff --git a/enterprise/tests/unit/integrations/github/test_github_manager.py b/enterprise/tests/unit/integrations/github/test_github_manager.py index 864be4ef38..18276bdaec 100644 --- a/enterprise/tests/unit/integrations/github/test_github_manager.py +++ b/enterprise/tests/unit/integrations/github/test_github_manager.py @@ -29,7 +29,7 @@ class TestGithubManagerUserNotFound: def mock_data_collector(self): """Create a mock data collector.""" data_collector = MagicMock() - data_collector.process_payload = MagicMock() + data_collector.process_payload = AsyncMock() return data_collector @pytest.fixture