refactor(enterprise): use async database sessions in feedback routes (#13137)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-02 15:17:44 -05:00
committed by GitHub
parent 7fdb423f99
commit c82ee4c7db
2 changed files with 51 additions and 66 deletions

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.database import session_maker
from storage.database import a_session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
@@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore
from openhands.server.dependencies import get_dependencies
from openhands.server.shared import file_store
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
# is protected. The actual protection is provided by SetAuthCookieMiddleware
@@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
"""
# Verify the conversation belongs to the user
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
)
metadata = result.scalars().first()
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
await call_sync_from_async(_verify_conversation)
# Create an event store to access the events directly
# This works even when the conversation is not running
@@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
)
# Add to database
def _save_feedback():
with session_maker() as session:
session.add(new_feedback)
session.commit()
await call_sync_from_async(_save_feedback)
async with a_session_maker() as session:
session.add(new_feedback)
await session.commit()
return {'status': 'success', 'message': 'Feedback submitted successfully'}
@@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
return {}
# Query for existing feedback for all events
def _check_feedback():
with session_maker() as session:
result = session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
)
async with a_session_maker() as session:
result = await session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
)
)
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
}
for feedback in result.scalars()
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
}
for feedback in result.scalars()
}
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
return response
return await call_sync_from_async(_check_feedback)
return response

View File

@@ -1,5 +1,6 @@
import sys
from unittest.mock import MagicMock, patch
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
@@ -27,6 +28,7 @@ async def test_submit_feedback():
"""Test submitting feedback for a conversation."""
# Create a mock database session
mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Test data
feedback_data = FeedbackRequest(
@@ -37,19 +39,13 @@ async def test_submit_feedback():
metadata={'browser': 'Chrome', 'os': 'Windows'},
)
# Mock session_maker and call_sync_from_async
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
'server.routes.feedback.call_sync_from_async'
) as mock_call_sync:
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# Mock call_sync_from_async to execute the function
def mock_call_sync_side_effect(func):
return func()
mock_call_sync.side_effect = mock_call_sync_side_effect
# Create async context manager for a_session_maker
@asynccontextmanager
async def mock_a_session_maker():
yield mock_session
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function
result = await submit_conversation_feedback(feedback_data)
@@ -78,6 +74,7 @@ async def test_invalid_rating():
"""Test submitting feedback with an invalid rating."""
# Create a mock database session
mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Since Pydantic validation happens before our function is called,
# we need to patch the validation to test our function's validation
@@ -95,14 +92,13 @@ async def test_invalid_rating():
# Mock the validation to return our object
mock_validate.return_value = feedback_data
# Mock session_maker and call_sync_from_async
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
'server.routes.feedback.call_sync_from_async'
) as mock_call_sync:
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
mock_call_sync.return_value = None
# Create async context manager for a_session_maker
@asynccontextmanager
async def mock_a_session_maker():
yield mock_session
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function and expect an exception
with pytest.raises(HTTPException) as excinfo:
await submit_conversation_feedback(feedback_data)