refactor(enterprise): use async database sessions in feedback routes (#13137)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-02 15:17:44 -05:00
committed by GitHub
parent 7fdb423f99
commit c82ee4c7db
2 changed files with 51 additions and 66 deletions

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.future import select from sqlalchemy.future import select
from storage.database import session_maker from storage.database import a_session_maker
from storage.feedback import ConversationFeedback from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
@@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore
from openhands.server.dependencies import get_dependencies from openhands.server.dependencies import get_dependencies
from openhands.server.shared import file_store from openhands.server.shared import file_store
from openhands.server.user_auth import get_user_id from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint # We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
# is protected. The actual protection is provided by SetAuthCookieMiddleware # is protected. The actual protection is provided by SetAuthCookieMiddleware
@@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
""" """
# Verify the conversation belongs to the user # Verify the conversation belongs to the user
def _verify_conversation(): async with a_session_maker() as session:
with session_maker() as session: result = await session.execute(
metadata = ( select(StoredConversationMetadataSaas).where(
session.query(StoredConversationMetadataSaas) StoredConversationMetadataSaas.conversation_id == conversation_id,
.filter( StoredConversationMetadataSaas.user_id == user_id,
StoredConversationMetadataSaas.conversation_id == conversation_id, )
StoredConversationMetadataSaas.user_id == user_id, )
) metadata = result.scalars().first()
.first() if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
) )
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
await call_sync_from_async(_verify_conversation)
# Create an event store to access the events directly # Create an event store to access the events directly
# This works even when the conversation is not running # This works even when the conversation is not running
@@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
) )
# Add to database # Add to database
def _save_feedback(): async with a_session_maker() as session:
with session_maker() as session: session.add(new_feedback)
session.add(new_feedback) await session.commit()
session.commit()
await call_sync_from_async(_save_feedback)
return {'status': 'success', 'message': 'Feedback submitted successfully'} return {'status': 'success', 'message': 'Feedback submitted successfully'}
@@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
return {} return {}
# Query for existing feedback for all events # Query for existing feedback for all events
def _check_feedback(): async with a_session_maker() as session:
with session_maker() as session: result = await session.execute(
result = session.execute( select(ConversationFeedback).where(
select(ConversationFeedback).where( ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.conversation_id == conversation_id, ConversationFeedback.event_id.in_(event_ids),
ConversationFeedback.event_id.in_(event_ids),
)
) )
)
# Create a mapping of event_id to feedback # Create a mapping of event_id to feedback
feedback_map = { feedback_map = {
feedback.event_id: { feedback.event_id: {
'exists': True, 'exists': True,
'rating': feedback.rating, 'rating': feedback.rating,
'reason': feedback.reason, 'reason': feedback.reason,
}
for feedback in result.scalars()
} }
for feedback in result.scalars()
}
# Build response including all events # Build response including all events
response = {} response = {}
for event_id in event_ids: for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False}) response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
return response return response
return await call_sync_from_async(_check_feedback)

View File

@@ -1,5 +1,6 @@
import sys import sys
from unittest.mock import MagicMock, patch from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
@@ -27,6 +28,7 @@ async def test_submit_feedback():
"""Test submitting feedback for a conversation.""" """Test submitting feedback for a conversation."""
# Create a mock database session # Create a mock database session
mock_session = MagicMock() mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Test data # Test data
feedback_data = FeedbackRequest( feedback_data = FeedbackRequest(
@@ -37,19 +39,13 @@ async def test_submit_feedback():
metadata={'browser': 'Chrome', 'os': 'Windows'}, metadata={'browser': 'Chrome', 'os': 'Windows'},
) )
# Mock session_maker and call_sync_from_async # Create async context manager for a_session_maker
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch( @asynccontextmanager
'server.routes.feedback.call_sync_from_async' async def mock_a_session_maker():
) as mock_call_sync: yield mock_session
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# Mock call_sync_from_async to execute the function
def mock_call_sync_side_effect(func):
return func()
mock_call_sync.side_effect = mock_call_sync_side_effect
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function # Call the function
result = await submit_conversation_feedback(feedback_data) result = await submit_conversation_feedback(feedback_data)
@@ -78,6 +74,7 @@ async def test_invalid_rating():
"""Test submitting feedback with an invalid rating.""" """Test submitting feedback with an invalid rating."""
# Create a mock database session # Create a mock database session
mock_session = MagicMock() mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Since Pydantic validation happens before our function is called, # Since Pydantic validation happens before our function is called,
# we need to patch the validation to test our function's validation # we need to patch the validation to test our function's validation
@@ -95,14 +92,13 @@ async def test_invalid_rating():
# Mock the validation to return our object # Mock the validation to return our object
mock_validate.return_value = feedback_data mock_validate.return_value = feedback_data
# Mock session_maker and call_sync_from_async # Create async context manager for a_session_maker
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch( @asynccontextmanager
'server.routes.feedback.call_sync_from_async' async def mock_a_session_maker():
) as mock_call_sync: yield mock_session
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
mock_call_sync.return_value = None
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function and expect an exception # Call the function and expect an exception
with pytest.raises(HTTPException) as excinfo: with pytest.raises(HTTPException) as excinfo:
await submit_conversation_feedback(feedback_data) await submit_conversation_feedback(feedback_data)