mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Implement Export feature for V1 conversations with comprehensive unit tests (#12030)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: hieptl <hieptl.developer@gmail.com>
This commit is contained in:
@@ -29,7 +29,7 @@ else:
|
||||
return await async_iterator.__anext__()
|
||||
|
||||
|
||||
from fastapi import APIRouter, Query, Request, status
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -546,6 +546,45 @@ async def get_conversation_skills(
|
||||
)
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/download')
|
||||
async def export_conversation(
|
||||
conversation_id: UUID,
|
||||
app_conversation_service: AppConversationService = (
|
||||
app_conversation_service_dependency
|
||||
),
|
||||
):
|
||||
"""Download a conversation trajectory as a zip file.
|
||||
|
||||
Returns a zip file containing all events and metadata for the conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The UUID of the conversation to download
|
||||
|
||||
Returns:
|
||||
A zip file containing the conversation trajectory
|
||||
"""
|
||||
try:
|
||||
# Get the zip file content
|
||||
zip_content = await app_conversation_service.export_conversation(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
# Return as a downloadable zip file
|
||||
return Response(
|
||||
content=zip_content,
|
||||
media_type='application/zip',
|
||||
headers={
|
||||
'Content-Disposition': f'attachment; filename="conversation_{conversation_id}.zip"'
|
||||
},
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f'Failed to download trajectory: {str(e)}'
|
||||
)
|
||||
|
||||
|
||||
async def _consume_remaining(
|
||||
async_iter, db_session: AsyncSession, httpx_client: httpx.AsyncClient
|
||||
):
|
||||
|
||||
@@ -113,6 +113,23 @@ class AppConversationService(ABC):
|
||||
Returns True if the conversation was deleted successfully, False otherwise.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def export_conversation(self, conversation_id: UUID) -> bytes:
|
||||
"""Download a conversation trajectory as a zip file.
|
||||
|
||||
Args:
|
||||
conversation_id: The UUID of the conversation to download.
|
||||
|
||||
This method should:
|
||||
1. Get all events for the conversation
|
||||
2. Create a temporary directory
|
||||
3. Save each event as a JSON file
|
||||
4. Save conversation metadata as meta.json
|
||||
5. Create and return a zip file containing all the data
|
||||
|
||||
Returns the zip file as bytes.
|
||||
"""
|
||||
|
||||
|
||||
class AppConversationServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[AppConversationService], ABC
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
@@ -44,6 +48,7 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
|
||||
)
|
||||
from openhands.app_server.config import get_event_callback_service
|
||||
from openhands.app_server.errors import SandboxError
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventCallback
|
||||
from openhands.app_server.event_callback.event_callback_service import (
|
||||
EventCallbackService,
|
||||
@@ -71,6 +76,7 @@ from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import Agent, AgentContext, LocalWorkspace
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.secret import LookupSecret, StaticSecret
|
||||
from openhands.sdk.utils.paging import page_iterator
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.tools.preset.default import (
|
||||
@@ -93,6 +99,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
app_conversation_info_service: AppConversationInfoService
|
||||
app_conversation_start_task_service: AppConversationStartTaskService
|
||||
event_callback_service: EventCallbackService
|
||||
event_service: EventService
|
||||
sandbox_service: SandboxService
|
||||
sandbox_spec_service: SandboxSpecService
|
||||
jwt_service: JwtService
|
||||
@@ -1178,6 +1185,61 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
|
||||
return deleted_info or deleted_tasks
|
||||
|
||||
async def export_conversation(self, conversation_id: UUID) -> bytes:
|
||||
"""Download a conversation trajectory as a zip file.
|
||||
|
||||
Args:
|
||||
conversation_id: The UUID of the conversation to download.
|
||||
|
||||
Returns the zip file as bytes.
|
||||
"""
|
||||
# Get the conversation info to verify it exists and user has access
|
||||
conversation_info = (
|
||||
await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if not conversation_info:
|
||||
raise ValueError(f'Conversation not found: {conversation_id}')
|
||||
|
||||
# Create a temporary directory to store files
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Get all events for this conversation
|
||||
i = 0
|
||||
async for event in page_iterator(
|
||||
self.event_service.search_events, conversation_id__eq=conversation_id
|
||||
):
|
||||
event_filename = f'event_{i:06d}_{event.id}.json'
|
||||
event_path = os.path.join(temp_dir, event_filename)
|
||||
|
||||
with open(event_path, 'w') as f:
|
||||
# Use model_dump with mode='json' to handle UUID serialization
|
||||
event_data = event.model_dump(mode='json')
|
||||
json.dump(event_data, f, indent=2)
|
||||
i += 1
|
||||
|
||||
# Create meta.json with conversation info
|
||||
meta_path = os.path.join(temp_dir, 'meta.json')
|
||||
with open(meta_path, 'w') as f:
|
||||
f.write(conversation_info.model_dump_json(indent=2))
|
||||
|
||||
# Create zip file in memory
|
||||
zip_buffer = tempfile.NamedTemporaryFile()
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
# Add all files from temp directory to zip
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(file_path, temp_dir)
|
||||
zipf.write(file_path, arcname)
|
||||
|
||||
# Read the zip file content
|
||||
zip_buffer.seek(0)
|
||||
zip_content = zip_buffer.read()
|
||||
zip_buffer.close()
|
||||
|
||||
return zip_content
|
||||
|
||||
|
||||
class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
sandbox_startup_timeout: int = Field(
|
||||
@@ -1208,6 +1270,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
from openhands.app_server.config import (
|
||||
get_app_conversation_info_service,
|
||||
get_app_conversation_start_task_service,
|
||||
get_event_service,
|
||||
get_global_config,
|
||||
get_httpx_client,
|
||||
get_jwt_service,
|
||||
@@ -1227,6 +1290,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
state, request
|
||||
) as app_conversation_start_task_service,
|
||||
get_event_callback_service(state, request) as event_callback_service,
|
||||
get_event_service(state, request) as event_service,
|
||||
get_jwt_service(state, request) as jwt_service,
|
||||
get_httpx_client(state, request) as httpx_client,
|
||||
):
|
||||
@@ -1274,6 +1338,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
app_conversation_start_task_service=app_conversation_start_task_service,
|
||||
event_callback_service=event_callback_service,
|
||||
event_service=event_service,
|
||||
jwt_service=jwt_service,
|
||||
sandbox_startup_timeout=self.sandbox_startup_timeout,
|
||||
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
|
||||
|
||||
@@ -22,7 +22,10 @@ async def iterate(fn: Callable, **kwargs) -> AsyncIterator:
|
||||
kwargs['page_id'] = None
|
||||
while True:
|
||||
result_set = await fn(**kwargs)
|
||||
for result in result_set.results:
|
||||
items = getattr(result_set, 'items', None)
|
||||
if items is None:
|
||||
items = getattr(result_set, 'results')
|
||||
for result in items:
|
||||
yield result
|
||||
if result_set.next_page_id is None:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user