Feat conversations CRUDS API (#5775)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
tofarr 2025-01-02 16:09:08 -07:00 committed by GitHub
parent 15e0a50ff4
commit 50f821f9b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 497 additions and 96 deletions

View File

@ -20,7 +20,9 @@ from openhands.server.routes.conversation import app as conversation_api_router
from openhands.server.routes.feedback import app as feedback_api_router
from openhands.server.routes.files import app as files_api_router
from openhands.server.routes.github import app as github_api_router
from openhands.server.routes.new_conversation import app as new_conversation_api_router
from openhands.server.routes.manage_conversations import (
app as manage_conversation_api_router,
)
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
from openhands.server.routes.settings import app as settings_router
@ -58,7 +60,7 @@ app.include_router(files_api_router)
app.include_router(security_api_router)
app.include_router(feedback_api_router)
app.include_router(conversation_api_router)
app.include_router(new_conversation_api_router)
app.include_router(manage_conversation_api_router)
app.include_router(settings_router)
app.include_router(github_api_router)

View File

@ -121,7 +121,7 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
if request.url.path.startswith('/api/conversation'):
# FIXME: we should be able to use path_params
path_parts = request.url.path.split('/')
if len(path_parts) > 3:
if len(path_parts) > 4:
conversation_id = request.url.path.split('/')[3]
if not conversation_id:
return False

View File

@ -0,0 +1,224 @@
import uuid
from datetime import datetime
from typing import Callable
from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from github import Github
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStreamSubscriber
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, session_manager
from openhands.storage.data_models.conversation_info import ConversationInfo
from openhands.storage.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import (
GENERAL_TIMEOUT,
call_async_from_sync,
call_sync_from_async,
wait_all,
)
app = APIRouter(prefix='/api')
UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
class InitSessionRequest(BaseModel):
github_token: str | None = None
latest_event_id: int = -1
selected_repository: str | None = None
args: dict | None = None
@app.post('/conversations')
async def new_conversation(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
After successful initialization, the client should connect to the WebSocket
using the returned conversation ID
"""
logger.info('Initializing new conversation')
github_token = data.github_token or ''
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings = await settings_store.load()
logger.info('Settings loaded')
session_init_args: dict = {}
if settings:
session_init_args = {**settings.__dict__, **session_init_args}
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
while await conversation_store.exists(conversation_id):
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
conversation_id = uuid.uuid4().hex
logger.info(f'New conversation ID: {conversation_id}')
user_id = ''
if data.github_token:
logger.info('Fetching Github user ID')
with Github(data.github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
conversation_id=conversation_id,
github_user_id=user_id,
selected_repository=data.selected_repository,
)
)
logger.info(f'Starting agent loop for conversation {conversation_id}')
event_stream = await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data
)
try:
event_stream.subscribe(
EventStreamSubscriber.SERVER,
_create_conversation_update_callback(
data.github_token or '', conversation_id
),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
pass # Already subscribed - take no action
logger.info(f'Finished initializing conversation {conversation_id}')
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})
@app.get('/conversations')
async def search_conversations(
request: Request,
page_id: str | None = None,
limit: int = 20,
) -> ConversationInfoResultSet:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
conversation_ids = set(
conversation.conversation_id
for conversation in conversation_metadata_result_set.results
)
running_conversations = await session_manager.get_agent_loop_running(
set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
_get_conversation_info(
conversation=conversation,
is_running=conversation.conversation_id in running_conversations,
)
for conversation in conversation_metadata_result_set.results
),
next_page_id=conversation_metadata_result_set.next_page_id,
)
return result
@app.get('/conversations/{conversation_id}')
async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
try:
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await session_manager.is_agent_loop_running(conversation_id)
conversation_info = await _get_conversation_info(metadata, is_running)
return conversation_info
except FileNotFoundError:
return None
@app.patch('/conversations/{conversation_id}')
async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
return False
metadata.title = title
await conversation_store.save_metadata(metadata)
return True
@app.delete('/conversations/{conversation_id}')
async def delete_conversation(
conversation_id: str,
request: Request,
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
try:
await conversation_store.get_metadata(conversation_id)
except FileNotFoundError:
return False
is_running = await session_manager.is_agent_loop_running(conversation_id)
if is_running:
return False
await conversation_store.delete_metadata(conversation_id)
return True
async def _get_conversation_info(
conversation: ConversationMetadata,
is_running: bool,
) -> ConversationInfo | None:
try:
title = conversation.title
if not title:
title = f'Conversation {conversation.conversation_id[:5]}'
return ConversationInfo(
conversation_id=conversation.conversation_id,
title=title,
last_updated_at=conversation.last_updated_at,
selected_repository=conversation.selected_repository,
status=ConversationStatus.RUNNING
if is_running
else ConversationStatus.STOPPED,
)
except Exception: # type: ignore
logger.warning(
f'Error loading conversation: {conversation.conversation_id[:5]}',
exc_info=True,
stack_info=True,
)
return None
def _create_conversation_update_callback(
github_token: str, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
_update_timestamp_for_conversation,
GENERAL_TIMEOUT,
github_token,
conversation_id,
)
return callback
async def _update_timestamp_for_conversation(github_token: str, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now()
await conversation_store.save_metadata(conversation)

View File

@ -1,80 +0,0 @@
import uuid
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from github import Github
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.server.data_models.conversation_metadata import ConversationMetadata
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, session_manager
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api')
class InitSessionRequest(BaseModel):
github_token: str | None = None
latest_event_id: int = -1
selected_repository: str | None = None
args: dict | None = None
@app.post('/conversations')
async def new_conversation(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
After successful initialization, the client should connect to the WebSocket
using the returned conversation ID
"""
logger.info('Initializing new conversation')
github_token = ''
if data.github_token:
github_token = data.github_token
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings = await settings_store.load()
logger.info('Settings loaded')
session_init_args: dict = {}
if settings:
session_init_args = {**settings.__dict__, **session_init_args}
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
while await conversation_store.exists(conversation_id):
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
conversation_id = uuid.uuid4().hex
logger.info(f'New conversation ID: {conversation_id}')
user_id = ''
if data.github_token:
logger.info('Fetching Github user ID')
with Github(data.github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
conversation_id=conversation_id,
github_user_id=user_id,
selected_repository=data.selected_repository,
)
)
logger.info(f'Starting agent loop for conversation {conversation_id}')
await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data
)
logger.info(f'Finished initializing conversation {conversation_id}')
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})

View File

@ -18,12 +18,8 @@ ConversationStoreImpl = get_impl(
@app.get('/settings')
async def load_settings(
request: Request,
) -> Settings | None:
github_token = ''
if hasattr(request.state, 'github_token'):
github_token = request.state.github_token
async def load_settings(request: Request) -> Settings | None:
github_token = getattr(request.state, 'github_token', '') or ''
try:
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings = await settings_store.load()

View File

@ -3,7 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from openhands.core.config.app_config import AppConfig
from openhands.server.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_metadata_result_set import (
ConversationMetadataResultSet,
)
class ConversationStore(ABC):
@ -19,10 +22,22 @@ class ConversationStore(ABC):
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
"""Load conversation metadata"""
@abstractmethod
async def delete_metadata(self, conversation_id: str) -> None:
"""delete conversation metadata"""
@abstractmethod
async def exists(self, conversation_id: str) -> bool:
"""Check if conversation exists"""
@abstractmethod
async def search(
self,
page_id: str | None = None,
limit: int = 20,
) -> ConversationMetadataResultSet:
"""Search conversations"""
@classmethod
@abstractmethod
async def get_instance(

View File

@ -1,15 +1,26 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from pydantic import TypeAdapter
from openhands.core.config.app_config import AppConfig
from openhands.server.data_models.conversation_metadata import ConversationMetadata
from openhands.core.logger import openhands_logger as logger
from openhands.storage import get_file_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_metadata_result_set import (
ConversationMetadataResultSet,
)
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_metadata_filename
from openhands.storage.locations import (
CONVERSATION_BASE_DIR,
get_conversation_metadata_filename,
)
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.search_utils import offset_to_page_id, page_id_to_offset
conversation_metadata_type_adapter = TypeAdapter(ConversationMetadata)
@dataclass
@ -17,14 +28,19 @@ class FileConversationStore(ConversationStore):
file_store: FileStore
async def save_metadata(self, metadata: ConversationMetadata):
json_str = json.dumps(metadata.__dict__)
json_str = conversation_metadata_type_adapter.dump_json(metadata)
path = self.get_conversation_metadata_filename(metadata.conversation_id)
await call_sync_from_async(self.file_store.write, path, json_str)
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
path = self.get_conversation_metadata_filename(conversation_id)
json_str = await call_sync_from_async(self.file_store.read, path)
return ConversationMetadata(**json.loads(json_str))
result = conversation_metadata_type_adapter.validate_json(json_str)
return result
async def delete_metadata(self, conversation_id: str) -> None:
path = self.get_conversation_metadata_filename(conversation_id)
await call_sync_from_async(self.file_store.delete, path)
async def exists(self, conversation_id: str) -> bool:
path = self.get_conversation_metadata_filename(conversation_id)
@ -34,6 +50,41 @@ class FileConversationStore(ConversationStore):
except FileNotFoundError:
return False
async def search(
self,
page_id: str | None = None,
limit: int = 20,
) -> ConversationMetadataResultSet:
conversations: list[ConversationMetadata] = []
metadata_dir = self.get_conversation_metadata_dir()
try:
conversation_ids = [
path.split('/')[-2]
for path in self.file_store.list(metadata_dir)
if not path.startswith(f'{metadata_dir}/.')
]
except FileNotFoundError:
return ConversationMetadataResultSet([])
num_conversations = len(conversation_ids)
start = page_id_to_offset(page_id)
end = min(limit + start, num_conversations)
conversation_ids = conversation_ids[start:end]
conversations = []
for conversation_id in conversation_ids:
try:
conversations.append(await self.get_metadata(conversation_id))
except Exception:
logger.warning(
f'Error loading conversation: {conversation_id}',
exc_info=True,
stack_info=True,
)
next_page_id = offset_to_page_id(end, end < num_conversations)
return ConversationMetadataResultSet(conversations, next_page_id)
def get_conversation_metadata_dir(self) -> str:
return CONVERSATION_BASE_DIR
def get_conversation_metadata_filename(self, conversation_id: str) -> str:
return get_conversation_metadata_filename(conversation_id)

View File

@ -0,0 +1,15 @@
from dataclasses import dataclass
from datetime import datetime
from openhands.storage.data_models.conversation_status import ConversationStatus
@dataclass
class ConversationInfo:
"""Information about a conversation"""
conversation_id: str
title: str
last_updated_at: datetime | None = None
status: ConversationStatus = ConversationStatus.STOPPED
selected_repository: str | None = None

View File

@ -0,0 +1,9 @@
from dataclasses import dataclass, field
from openhands.storage.data_models.conversation_info import ConversationInfo
@dataclass
class ConversationInfoResultSet:
results: list[ConversationInfo] = field(default_factory=list)
next_page_id: str | None = None

View File

@ -1,8 +1,11 @@
from dataclasses import dataclass
from datetime import datetime
@dataclass
class ConversationMetadata:
conversation_id: str
github_user_id: str
github_user_id: int | str
selected_repository: str | None
title: str | None = None
last_updated_at: datetime | None = None

View File

@ -0,0 +1,9 @@
from dataclasses import dataclass, field
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
@dataclass
class ConversationMetadataResultSet:
results: list[ConversationMetadata] = field(default_factory=list)
next_page_id: str | None = None

View File

@ -0,0 +1,6 @@
from enum import Enum
class ConversationStatus(Enum):
RUNNING = 'RUNNING'
STOPPED = 'STOPPED'

View File

@ -0,0 +1,15 @@
import base64
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
if not has_next:
return None
next_page_id = base64.b64encode(str(offset).encode()).decode()
return next_page_id
def page_id_to_offset(page_id: str | None) -> int:
if not page_id:
return 0
offset = int(base64.b64decode(page_id).decode())
return offset

View File

@ -0,0 +1,112 @@
import json
from contextlib import contextmanager
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from openhands.server.routes.manage_conversations import (
get_conversation,
search_conversations,
update_conversation,
)
from openhands.storage.data_models.conversation_info import ConversationInfo
from openhands.storage.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.memory import InMemoryFileStore
@contextmanager
def _patch_store():
file_store = InMemoryFileStore()
file_store.write(
'sessions/some_conversation_id/metadata.json',
json.dumps(
{
'title': 'Some Conversation',
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': 'github_user',
'last_updated_at': '2025-01-01T00:00:00',
}
),
)
with patch(
'openhands.storage.conversation.file_conversation_store.get_file_store',
MagicMock(return_value=file_store),
):
with patch(
'openhands.server.routes.manage_conversations.session_manager.file_store',
file_store,
):
yield
@pytest.mark.asyncio
async def test_search_conversations():
with _patch_store():
result_set = await search_conversations(
MagicMock(state=MagicMock(github_token=''))
)
expected = ConversationInfoResultSet(
results=[
ConversationInfo(
conversation_id='some_conversation_id',
title='Some Conversation',
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
]
)
assert result_set == expected
@pytest.mark.asyncio
async def test_get_conversation():
with _patch_store():
conversation = await get_conversation(
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
)
expected = ConversationInfo(
conversation_id='some_conversation_id',
title='Some Conversation',
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
assert conversation == expected
@pytest.mark.asyncio
async def test_get_missing_conversation():
with _patch_store():
assert (
await get_conversation(
'no_such_conversation', MagicMock(state=MagicMock(github_token=''))
)
is None
)
@pytest.mark.asyncio
async def test_update_conversation():
with _patch_store():
await update_conversation(
MagicMock(state=MagicMock(github_token='')),
'some_conversation_id',
'New Title',
)
conversation = await get_conversation(
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
)
expected = ConversationInfo(
conversation_id='some_conversation_id',
title='New Title',
last_updated_at=datetime.fromisoformat('2025-01-01T00:00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
assert conversation == expected

View File

@ -0,0 +1,24 @@
from openhands.utils.search_utils import offset_to_page_id, page_id_to_offset
def test_offset_to_page_id():
# Test with has_next=True
assert bool(offset_to_page_id(10, True))
assert bool(offset_to_page_id(0, True))
# Test with has_next=False should return None
assert offset_to_page_id(10, False) is None
assert offset_to_page_id(0, False) is None
def test_page_id_to_offset():
# Test with None should return 0
assert page_id_to_offset(None) == 0
def test_bidirectional_conversion():
# Test converting offset to page_id and back
test_offsets = [0, 1, 10, 100, 1000]
for offset in test_offsets:
page_id = offset_to_page_id(offset, True)
assert page_id_to_offset(page_id) == offset