From c82ee4c7db611c9317e53e0d71de89c0604c3b0e Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Mon, 2 Mar 2026 15:17:44 -0500 Subject: [PATCH] refactor(enterprise): use async database sessions in feedback routes (#13137) Co-authored-by: openhands --- enterprise/server/routes/feedback.py | 81 +++++++++++--------------- enterprise/tests/unit/test_feedback.py | 36 +++++------- 2 files changed, 51 insertions(+), 66 deletions(-) diff --git a/enterprise/server/routes/feedback.py b/enterprise/server/routes/feedback.py index 7cfbf05fac..f8ce971808 100644 --- a/enterprise/server/routes/feedback.py +++ b/enterprise/server/routes/feedback.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field from sqlalchemy.future import select -from storage.database import session_maker +from storage.database import a_session_maker from storage.feedback import ConversationFeedback from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas @@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore from openhands.server.dependencies import get_dependencies from openhands.server.shared import file_store from openhands.server.user_auth import get_user_id -from openhands.utils.async_utils import call_sync_from_async # We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint # is protected. The actual protection is provided by SetAuthCookieMiddleware @@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]: """ # Verify the conversation belongs to the user - def _verify_conversation(): - with session_maker() as session: - metadata = ( - session.query(StoredConversationMetadataSaas) - .filter( - StoredConversationMetadataSaas.conversation_id == conversation_id, - StoredConversationMetadataSaas.user_id == user_id, - ) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(StoredConversationMetadataSaas).where( + StoredConversationMetadataSaas.conversation_id == conversation_id, + StoredConversationMetadataSaas.user_id == user_id, + ) + ) + metadata = result.scalars().first() + if not metadata: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'Conversation {conversation_id} not found', ) - if not metadata: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f'Conversation {conversation_id} not found', - ) - - await call_sync_from_async(_verify_conversation) # Create an event store to access the events directly # This works even when the conversation is not running @@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest): ) # Add to database - def _save_feedback(): - with session_maker() as session: - session.add(new_feedback) - session.commit() - - await call_sync_from_async(_save_feedback) + async with a_session_maker() as session: + session.add(new_feedback) + await session.commit() return {'status': 'success', 'message': 'Feedback submitted successfully'} @@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us return {} # Query for existing feedback for all events - def _check_feedback(): - with session_maker() as session: - result = session.execute( - select(ConversationFeedback).where( - ConversationFeedback.conversation_id == conversation_id, - ConversationFeedback.event_id.in_(event_ids), - ) + async with a_session_maker() as session: + result = await session.execute( + select(ConversationFeedback).where( + ConversationFeedback.conversation_id == conversation_id, + ConversationFeedback.event_id.in_(event_ids), ) + ) - # Create a mapping of event_id to feedback - feedback_map = { - feedback.event_id: { - 'exists': True, - 'rating': feedback.rating, - 'reason': feedback.reason, - } - for feedback in result.scalars() + # Create a mapping of event_id to feedback + feedback_map = { + feedback.event_id: { + 'exists': True, + 'rating': feedback.rating, + 'reason': feedback.reason, } + for feedback in result.scalars() + } - # Build response including all events - response = {} - for event_id in event_ids: - response[str(event_id)] = feedback_map.get(event_id, {'exists': False}) + # Build response including all events + response = {} + for event_id in event_ids: + response[str(event_id)] = feedback_map.get(event_id, {'exists': False}) - return response - - return await call_sync_from_async(_check_feedback) + return response diff --git a/enterprise/tests/unit/test_feedback.py b/enterprise/tests/unit/test_feedback.py index 5c53732e94..d2617cb0a9 100644 --- a/enterprise/tests/unit/test_feedback.py +++ b/enterprise/tests/unit/test_feedback.py @@ -1,5 +1,6 @@ import sys -from unittest.mock import MagicMock, patch +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException @@ -27,6 +28,7 @@ async def test_submit_feedback(): """Test submitting feedback for a conversation.""" # Create a mock database session mock_session = MagicMock() + mock_session.commit = AsyncMock() # Test data feedback_data = FeedbackRequest( @@ -37,19 +39,13 @@ async def test_submit_feedback(): metadata={'browser': 'Chrome', 'os': 'Windows'}, ) - # Mock session_maker and call_sync_from_async - with patch('server.routes.feedback.session_maker') as mock_session_maker, patch( - 'server.routes.feedback.call_sync_from_async' - ) as mock_call_sync: - mock_session_maker.return_value.__enter__.return_value = mock_session - mock_session_maker.return_value.__exit__.return_value = None - - # Mock call_sync_from_async to execute the function - def mock_call_sync_side_effect(func): - return func() - - mock_call_sync.side_effect = mock_call_sync_side_effect + # Create async context manager for a_session_maker + @asynccontextmanager + async def mock_a_session_maker(): + yield mock_session + # Mock a_session_maker + with patch('server.routes.feedback.a_session_maker', mock_a_session_maker): # Call the function result = await submit_conversation_feedback(feedback_data) @@ -78,6 +74,7 @@ async def test_invalid_rating(): """Test submitting feedback with an invalid rating.""" # Create a mock database session mock_session = MagicMock() + mock_session.commit = AsyncMock() # Since Pydantic validation happens before our function is called, # we need to patch the validation to test our function's validation @@ -95,14 +92,13 @@ async def test_invalid_rating(): # Mock the validation to return our object mock_validate.return_value = feedback_data - # Mock session_maker and call_sync_from_async - with patch('server.routes.feedback.session_maker') as mock_session_maker, patch( - 'server.routes.feedback.call_sync_from_async' - ) as mock_call_sync: - mock_session_maker.return_value.__enter__.return_value = mock_session - mock_session_maker.return_value.__exit__.return_value = None - mock_call_sync.return_value = None + # Create async context manager for a_session_maker + @asynccontextmanager + async def mock_a_session_maker(): + yield mock_session + # Mock a_session_maker + with patch('server.routes.feedback.a_session_maker', mock_a_session_maker): # Call the function and expect an exception with pytest.raises(HTTPException) as excinfo: await submit_conversation_feedback(feedback_data)