mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
refactor(enterprise): use async database sessions in feedback routes (#13137)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
@@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
@@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
"""
|
||||
|
||||
# Verify the conversation belongs to the user
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
|
||||
await call_sync_from_async(_verify_conversation)
|
||||
|
||||
# Create an event store to access the events directly
|
||||
# This works even when the conversation is not running
|
||||
@@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
)
|
||||
|
||||
# Add to database
|
||||
def _save_feedback():
|
||||
with session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_feedback)
|
||||
async with a_session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
await session.commit()
|
||||
|
||||
return {'status': 'success', 'message': 'Feedback submitted successfully'}
|
||||
|
||||
@@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
|
||||
return {}
|
||||
|
||||
# Query for existing feedback for all events
|
||||
def _check_feedback():
|
||||
with session_maker() as session:
|
||||
result = session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
)
|
||||
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
}
|
||||
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
|
||||
return response
|
||||
|
||||
return await call_sync_from_async(_check_feedback)
|
||||
return response
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -27,6 +28,7 @@ async def test_submit_feedback():
|
||||
"""Test submitting feedback for a conversation."""
|
||||
# Create a mock database session
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Test data
|
||||
feedback_data = FeedbackRequest(
|
||||
@@ -37,19 +39,13 @@ async def test_submit_feedback():
|
||||
metadata={'browser': 'Chrome', 'os': 'Windows'},
|
||||
)
|
||||
|
||||
# Mock session_maker and call_sync_from_async
|
||||
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
|
||||
'server.routes.feedback.call_sync_from_async'
|
||||
) as mock_call_sync:
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock call_sync_from_async to execute the function
|
||||
def mock_call_sync_side_effect(func):
|
||||
return func()
|
||||
|
||||
mock_call_sync.side_effect = mock_call_sync_side_effect
|
||||
# Create async context manager for a_session_maker
|
||||
@asynccontextmanager
|
||||
async def mock_a_session_maker():
|
||||
yield mock_session
|
||||
|
||||
# Mock a_session_maker
|
||||
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
|
||||
# Call the function
|
||||
result = await submit_conversation_feedback(feedback_data)
|
||||
|
||||
@@ -78,6 +74,7 @@ async def test_invalid_rating():
|
||||
"""Test submitting feedback with an invalid rating."""
|
||||
# Create a mock database session
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Since Pydantic validation happens before our function is called,
|
||||
# we need to patch the validation to test our function's validation
|
||||
@@ -95,14 +92,13 @@ async def test_invalid_rating():
|
||||
# Mock the validation to return our object
|
||||
mock_validate.return_value = feedback_data
|
||||
|
||||
# Mock session_maker and call_sync_from_async
|
||||
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
|
||||
'server.routes.feedback.call_sync_from_async'
|
||||
) as mock_call_sync:
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
mock_call_sync.return_value = None
|
||||
# Create async context manager for a_session_maker
|
||||
@asynccontextmanager
|
||||
async def mock_a_session_maker():
|
||||
yield mock_session
|
||||
|
||||
# Mock a_session_maker
|
||||
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
|
||||
# Call the function and expect an exception
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await submit_conversation_feedback(feedback_data)
|
||||
|
||||
Reference in New Issue
Block a user