mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Feat conversations CRUDS API (#5775)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
15e0a50ff4
commit
50f821f9b9
@ -20,7 +20,9 @@ from openhands.server.routes.conversation import app as conversation_api_router
|
||||
from openhands.server.routes.feedback import app as feedback_api_router
|
||||
from openhands.server.routes.files import app as files_api_router
|
||||
from openhands.server.routes.github import app as github_api_router
|
||||
from openhands.server.routes.new_conversation import app as new_conversation_api_router
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
app as manage_conversation_api_router,
|
||||
)
|
||||
from openhands.server.routes.public import app as public_api_router
|
||||
from openhands.server.routes.security import app as security_api_router
|
||||
from openhands.server.routes.settings import app as settings_router
|
||||
@ -58,7 +60,7 @@ app.include_router(files_api_router)
|
||||
app.include_router(security_api_router)
|
||||
app.include_router(feedback_api_router)
|
||||
app.include_router(conversation_api_router)
|
||||
app.include_router(new_conversation_api_router)
|
||||
app.include_router(manage_conversation_api_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(github_api_router)
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
if request.url.path.startswith('/api/conversation'):
|
||||
# FIXME: we should be able to use path_params
|
||||
path_parts = request.url.path.split('/')
|
||||
if len(path_parts) > 3:
|
||||
if len(path_parts) > 4:
|
||||
conversation_id = request.url.path.split('/')[3]
|
||||
if not conversation_id:
|
||||
return False
|
||||
|
||||
224
openhands/server/routes/manage_conversations.py
Normal file
224
openhands/server/routes/manage_conversations.py
Normal file
@ -0,0 +1,224 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, Body, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from github import Github
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import config, session_manager
|
||||
from openhands.storage.data_models.conversation_info import ConversationInfo
|
||||
from openhands.storage.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import (
|
||||
GENERAL_TIMEOUT,
|
||||
call_async_from_sync,
|
||||
call_sync_from_async,
|
||||
wait_all,
|
||||
)
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
|
||||
|
||||
|
||||
class InitSessionRequest(BaseModel):
|
||||
github_token: str | None = None
|
||||
latest_event_id: int = -1
|
||||
selected_repository: str | None = None
|
||||
args: dict | None = None
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
"""Initialize a new session or join an existing one.
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
using the returned conversation ID
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
github_token = data.github_token or ''
|
||||
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
session_init_args['github_token'] = github_token
|
||||
session_init_args['selected_repository'] = data.selected_repository
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
while await conversation_store.exists(conversation_id):
|
||||
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
|
||||
conversation_id = uuid.uuid4().hex
|
||||
logger.info(f'New conversation ID: {conversation_id}')
|
||||
|
||||
user_id = ''
|
||||
if data.github_token:
|
||||
logger.info('Fetching Github user ID')
|
||||
with Github(data.github_token) as g:
|
||||
gh_user = await call_sync_from_async(g.get_user)
|
||||
user_id = gh_user.id
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
github_user_id=user_id,
|
||||
selected_repository=data.selected_repository,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Starting agent loop for conversation {conversation_id}')
|
||||
event_stream = await session_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data
|
||||
)
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
_create_conversation_update_callback(
|
||||
data.github_token or '', conversation_id
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
pass # Already subscribed - take no action
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})
|
||||
|
||||
|
||||
@app.get('/conversations')
|
||||
async def search_conversations(
|
||||
request: Request,
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ConversationInfoResultSet:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
conversation_ids = set(
|
||||
conversation.conversation_id
|
||||
for conversation in conversation_metadata_result_set.results
|
||||
)
|
||||
running_conversations = await session_manager.get_agent_loop_running(
|
||||
set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
_get_conversation_info(
|
||||
conversation=conversation,
|
||||
is_running=conversation.conversation_id in running_conversations,
|
||||
)
|
||||
for conversation in conversation_metadata_result_set.results
|
||||
),
|
||||
next_page_id=conversation_metadata_result_set.next_page_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@app.get('/conversations/{conversation_id}')
|
||||
async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await session_manager.is_agent_loop_running(conversation_id)
|
||||
conversation_info = await _get_conversation_info(metadata, is_running)
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
return False
|
||||
metadata.title = title
|
||||
await conversation_store.save_metadata(metadata)
|
||||
return True
|
||||
|
||||
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
) -> bool:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
is_running = await session_manager.is_agent_loop_running(conversation_id)
|
||||
if is_running:
|
||||
return False
|
||||
await conversation_store.delete_metadata(conversation_id)
|
||||
return True
|
||||
|
||||
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
title = conversation.title
|
||||
if not title:
|
||||
title = f'Conversation {conversation.conversation_id[:5]}'
|
||||
return ConversationInfo(
|
||||
conversation_id=conversation.conversation_id,
|
||||
title=title,
|
||||
last_updated_at=conversation.last_updated_at,
|
||||
selected_repository=conversation.selected_repository,
|
||||
status=ConversationStatus.RUNNING
|
||||
if is_running
|
||||
else ConversationStatus.STOPPED,
|
||||
)
|
||||
except Exception: # type: ignore
|
||||
logger.warning(
|
||||
f'Error loading conversation: {conversation.conversation_id[:5]}',
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
github_token: str, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
_update_timestamp_for_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
github_token,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
async def _update_timestamp_for_conversation(github_token: str, conversation_id: str):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now()
|
||||
await conversation_store.save_metadata(conversation)
|
||||
@ -1,80 +0,0 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from github import Github
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import config, session_manager
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
class InitSessionRequest(BaseModel):
|
||||
github_token: str | None = None
|
||||
latest_event_id: int = -1
|
||||
selected_repository: str | None = None
|
||||
args: dict | None = None
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
"""Initialize a new session or join an existing one.
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
using the returned conversation ID
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
github_token = ''
|
||||
if data.github_token:
|
||||
github_token = data.github_token
|
||||
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
session_init_args['github_token'] = github_token
|
||||
session_init_args['selected_repository'] = data.selected_repository
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
while await conversation_store.exists(conversation_id):
|
||||
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
|
||||
conversation_id = uuid.uuid4().hex
|
||||
logger.info(f'New conversation ID: {conversation_id}')
|
||||
|
||||
user_id = ''
|
||||
if data.github_token:
|
||||
logger.info('Fetching Github user ID')
|
||||
with Github(data.github_token) as g:
|
||||
gh_user = await call_sync_from_async(g.get_user)
|
||||
user_id = gh_user.id
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
github_user_id=user_id,
|
||||
selected_repository=data.selected_repository,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Starting agent loop for conversation {conversation_id}')
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})
|
||||
@ -18,12 +18,8 @@ ConversationStoreImpl = get_impl(
|
||||
|
||||
|
||||
@app.get('/settings')
|
||||
async def load_settings(
|
||||
request: Request,
|
||||
) -> Settings | None:
|
||||
github_token = ''
|
||||
if hasattr(request.state, 'github_token'):
|
||||
github_token = request.state.github_token
|
||||
async def load_settings(request: Request) -> Settings | None:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings = await settings_store.load()
|
||||
|
||||
@ -3,7 +3,10 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_metadata_result_set import (
|
||||
ConversationMetadataResultSet,
|
||||
)
|
||||
|
||||
|
||||
class ConversationStore(ABC):
|
||||
@ -19,10 +22,22 @@ class ConversationStore(ABC):
|
||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
||||
"""Load conversation metadata"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
"""delete conversation metadata"""
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, conversation_id: str) -> bool:
|
||||
"""Check if conversation exists"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ConversationMetadataResultSet:
|
||||
"""Search conversations"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
|
||||
@ -1,15 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_metadata_result_set import (
|
||||
ConversationMetadataResultSet,
|
||||
)
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_metadata_filename
|
||||
from openhands.storage.locations import (
|
||||
CONVERSATION_BASE_DIR,
|
||||
get_conversation_metadata_filename,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.search_utils import offset_to_page_id, page_id_to_offset
|
||||
|
||||
conversation_metadata_type_adapter = TypeAdapter(ConversationMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -17,14 +28,19 @@ class FileConversationStore(ConversationStore):
|
||||
file_store: FileStore
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
json_str = json.dumps(metadata.__dict__)
|
||||
json_str = conversation_metadata_type_adapter.dump_json(metadata)
|
||||
path = self.get_conversation_metadata_filename(metadata.conversation_id)
|
||||
await call_sync_from_async(self.file_store.write, path, json_str)
|
||||
|
||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
||||
path = self.get_conversation_metadata_filename(conversation_id)
|
||||
json_str = await call_sync_from_async(self.file_store.read, path)
|
||||
return ConversationMetadata(**json.loads(json_str))
|
||||
result = conversation_metadata_type_adapter.validate_json(json_str)
|
||||
return result
|
||||
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
path = self.get_conversation_metadata_filename(conversation_id)
|
||||
await call_sync_from_async(self.file_store.delete, path)
|
||||
|
||||
async def exists(self, conversation_id: str) -> bool:
|
||||
path = self.get_conversation_metadata_filename(conversation_id)
|
||||
@ -34,6 +50,41 @@ class FileConversationStore(ConversationStore):
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
async def search(
|
||||
self,
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ConversationMetadataResultSet:
|
||||
conversations: list[ConversationMetadata] = []
|
||||
metadata_dir = self.get_conversation_metadata_dir()
|
||||
try:
|
||||
conversation_ids = [
|
||||
path.split('/')[-2]
|
||||
for path in self.file_store.list(metadata_dir)
|
||||
if not path.startswith(f'{metadata_dir}/.')
|
||||
]
|
||||
except FileNotFoundError:
|
||||
return ConversationMetadataResultSet([])
|
||||
num_conversations = len(conversation_ids)
|
||||
start = page_id_to_offset(page_id)
|
||||
end = min(limit + start, num_conversations)
|
||||
conversation_ids = conversation_ids[start:end]
|
||||
conversations = []
|
||||
for conversation_id in conversation_ids:
|
||||
try:
|
||||
conversations.append(await self.get_metadata(conversation_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'Error loading conversation: {conversation_id}',
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
)
|
||||
next_page_id = offset_to_page_id(end, end < num_conversations)
|
||||
return ConversationMetadataResultSet(conversations, next_page_id)
|
||||
|
||||
def get_conversation_metadata_dir(self) -> str:
|
||||
return CONVERSATION_BASE_DIR
|
||||
|
||||
def get_conversation_metadata_filename(self, conversation_id: str) -> str:
|
||||
return get_conversation_metadata_filename(conversation_id)
|
||||
|
||||
|
||||
15
openhands/storage/data_models/conversation_info.py
Normal file
15
openhands/storage/data_models/conversation_info.py
Normal file
@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationInfo:
|
||||
"""Information about a conversation"""
|
||||
|
||||
conversation_id: str
|
||||
title: str
|
||||
last_updated_at: datetime | None = None
|
||||
status: ConversationStatus = ConversationStatus.STOPPED
|
||||
selected_repository: str | None = None
|
||||
@ -0,0 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.storage.data_models.conversation_info import ConversationInfo
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationInfoResultSet:
|
||||
results: list[ConversationInfo] = field(default_factory=list)
|
||||
next_page_id: str | None = None
|
||||
@ -1,8 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
conversation_id: str
|
||||
github_user_id: str
|
||||
github_user_id: int | str
|
||||
selected_repository: str | None
|
||||
title: str | None = None
|
||||
last_updated_at: datetime | None = None
|
||||
@ -0,0 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationMetadataResultSet:
|
||||
results: list[ConversationMetadata] = field(default_factory=list)
|
||||
next_page_id: str | None = None
|
||||
6
openhands/storage/data_models/conversation_status.py
Normal file
6
openhands/storage/data_models/conversation_status.py
Normal file
@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConversationStatus(Enum):
|
||||
RUNNING = 'RUNNING'
|
||||
STOPPED = 'STOPPED'
|
||||
15
openhands/utils/search_utils.py
Normal file
15
openhands/utils/search_utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
import base64
|
||||
|
||||
|
||||
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
|
||||
if not has_next:
|
||||
return None
|
||||
next_page_id = base64.b64encode(str(offset).encode()).decode()
|
||||
return next_page_id
|
||||
|
||||
|
||||
def page_id_to_offset(page_id: str | None) -> int:
|
||||
if not page_id:
|
||||
return 0
|
||||
offset = int(base64.b64decode(page_id).decode())
|
||||
return offset
|
||||
112
tests/unit/test_conversation.py
Normal file
112
tests/unit/test_conversation.py
Normal file
@ -0,0 +1,112 @@
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
get_conversation,
|
||||
search_conversations,
|
||||
update_conversation,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_info import ConversationInfo
|
||||
from openhands.storage.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_store():
|
||||
file_store = InMemoryFileStore()
|
||||
file_store.write(
|
||||
'sessions/some_conversation_id/metadata.json',
|
||||
json.dumps(
|
||||
{
|
||||
'title': 'Some Conversation',
|
||||
'selected_repository': 'foobar',
|
||||
'conversation_id': 'some_conversation_id',
|
||||
'github_user_id': 'github_user',
|
||||
'last_updated_at': '2025-01-01T00:00:00',
|
||||
}
|
||||
),
|
||||
)
|
||||
with patch(
|
||||
'openhands.storage.conversation.file_conversation_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
):
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.session_manager.file_store',
|
||||
file_store,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_conversations():
|
||||
with _patch_store():
|
||||
result_set = await search_conversations(
|
||||
MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
expected = ConversationInfoResultSet(
|
||||
results=[
|
||||
ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
]
|
||||
)
|
||||
assert result_set == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation():
|
||||
with _patch_store():
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
expected = ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
assert conversation == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_missing_conversation():
|
||||
with _patch_store():
|
||||
assert (
|
||||
await get_conversation(
|
||||
'no_such_conversation', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation():
|
||||
with _patch_store():
|
||||
await update_conversation(
|
||||
MagicMock(state=MagicMock(github_token='')),
|
||||
'some_conversation_id',
|
||||
'New Title',
|
||||
)
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
expected = ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='New Title',
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
assert conversation == expected
|
||||
24
tests/unit/test_search_utils.py
Normal file
24
tests/unit/test_search_utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
from openhands.utils.search_utils import offset_to_page_id, page_id_to_offset
|
||||
|
||||
|
||||
def test_offset_to_page_id():
|
||||
# Test with has_next=True
|
||||
assert bool(offset_to_page_id(10, True))
|
||||
assert bool(offset_to_page_id(0, True))
|
||||
|
||||
# Test with has_next=False should return None
|
||||
assert offset_to_page_id(10, False) is None
|
||||
assert offset_to_page_id(0, False) is None
|
||||
|
||||
|
||||
def test_page_id_to_offset():
|
||||
# Test with None should return 0
|
||||
assert page_id_to_offset(None) == 0
|
||||
|
||||
|
||||
def test_bidirectional_conversion():
|
||||
# Test converting offset to page_id and back
|
||||
test_offsets = [0, 1, 10, 100, 1000]
|
||||
for offset in test_offsets:
|
||||
page_id = offset_to_page_id(offset, True)
|
||||
assert page_id_to_offset(page_id) == offset
|
||||
Loading…
x
Reference in New Issue
Block a user