mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
231 lines
8.2 KiB
Python
231 lines
8.2 KiB
Python
# pyright: reportArgumentType=false
|
|
"""SQL implementation of EventCallbackService."""
|
|
|
|
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID, uuid4

from fastapi import Request
from sqlalchemy import UUID as SQLUUID
from sqlalchemy import Column, Enum, String, and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from openhands.app_server.event_callback.event_callback_models import (
    CreateEventCallbackRequest,
    EventCallback,
    EventCallbackPage,
    EventCallbackProcessor,
    EventKind,
)
from openhands.app_server.event_callback.event_callback_result_models import (
    EventCallbackResultStatus,
)
from openhands.app_server.event_callback.event_callback_service import (
    EventCallbackService,
    EventCallbackServiceInjector,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.utils.sql_utils import (
    Base,
    UtcDateTime,
    create_json_type_decorator,
    row2dict,
)
from openhands.sdk import Event
|
|
|
|
# Module-level logger; the service uses it to report exceptions raised by
# callback processors.
_logger = logging.getLogger(__name__)

# TODO: Add user level filtering to this class
|
|
|
|
|
|
class StoredEventCallback(Base): # type: ignore
    """Database row for a registered event callback (table ``event_callback``)."""

    __tablename__ = 'event_callback'

    # Primary key. No column default is declared, so the value must be
    # supplied when the row is created.
    id = Column(SQLUUID, primary_key=True)

    # Conversation this callback is bound to. NULL acts as a wildcard:
    # SQLEventCallbackService.execute_callbacks matches rows whose
    # conversation_id equals the event's conversation OR is NULL.
    conversation_id = Column(SQLUUID, nullable=True)

    # The processor object invoked for matching events, persisted through the
    # project's JSON type decorator (exact serialized form is defined there).
    processor = Column(create_json_type_decorator(EventCallbackProcessor))

    # Event kind this callback reacts to. NULL acts as a wildcard that matches
    # every event kind in execute_callbacks.
    event_kind = Column(String, nullable=True)

    # Filled in by the database on insert (server_default=now()); indexed
    # because search results are ordered by it.
    created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
|
|
|
|
|
class StoredEventCallbackResult(Base): # type: ignore
    """Database row recording the outcome of running one callback for one event."""

    __tablename__ = 'event_callback_result'

    # Primary key. No column default is declared, so every insert must supply
    # an explicit id.
    id = Column(SQLUUID, primary_key=True)

    # Outcome status; the service writes EventCallbackResultStatus.ERROR when
    # a processor raises.
    status = Column(Enum(EventCallbackResultStatus), nullable=True)

    # Id of the StoredEventCallback that produced this result.
    event_callback_id = Column(SQLUUID, index=True)

    # Id of the event the callback was run for.
    event_id = Column(SQLUUID, index=True)

    # Conversation the event belongs to.
    conversation_id = Column(SQLUUID, index=True)

    # Free-text detail; for ERROR rows created by the service this holds
    # str(exc) of the processor's exception.
    detail = Column(String, nullable=True)

    # Filled in by the database on insert (server_default=now()).
    created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
|
|
|
|
|
@dataclass
class SQLEventCallbackService(EventCallbackService):
    """SQL implementation of EventCallbackService.

    Every operation runs on the single ``AsyncSession`` held in ``db_session``.
    Mutating operations commit that session before returning.
    """

    db_session: AsyncSession

    async def create_event_callback(
        self, request: CreateEventCallbackRequest
    ) -> EventCallback:
        """Create, persist and return a new event callback.

        The ``EventCallback`` domain model is built first so that its own
        field defaults are applied before the row is stored. After the commit,
        the row is refreshed so database-side defaults (such as ``created_at``)
        are present in the returned value.
        """
        event_callback = EventCallback(
            conversation_id=request.conversation_id,
            processor=request.processor,
            event_kind=request.event_kind,
        )

        stored_callback = StoredEventCallback(**event_callback.model_dump())
        self.db_session.add(stored_callback)
        await self.db_session.commit()
        await self.db_session.refresh(stored_callback)
        return EventCallback(**row2dict(stored_callback))

    async def get_event_callback(self, id: UUID) -> EventCallback | None:
        """Return the callback with the given id, or None if it does not exist."""
        stmt = select(StoredEventCallback).where(StoredEventCallback.id == id)
        result = await self.db_session.execute(stmt)
        stored_callback = result.scalar_one_or_none()
        if stored_callback is None:
            return None
        return EventCallback(**row2dict(stored_callback))

    async def delete_event_callback(self, id: UUID) -> bool:
        """Delete an event callback.

        Returns:
            True if a callback was deleted, False if no callback had that id.
        """
        stmt = select(StoredEventCallback).where(StoredEventCallback.id == id)
        result = await self.db_session.execute(stmt)
        stored_callback = result.scalar_one_or_none()

        if stored_callback is None:
            return False

        await self.db_session.delete(stored_callback)
        await self.db_session.commit()
        return True

    async def search_event_callbacks(
        self,
        conversation_id__eq: UUID | None = None,
        event_kind__eq: EventKind | None = None,
        event_id__eq: UUID | None = None,
        page_id: str | None = None,
        limit: int = 100,
    ) -> EventCallbackPage:
        """Return one page of event callbacks, newest first.

        Args:
            conversation_id__eq: If given, only callbacks bound to exactly this
                conversation are returned. Wildcard callbacks (NULL
                conversation) do not compare equal, so they are excluded.
            event_kind__eq: If given, only callbacks for this event kind.
            event_id__eq: Accepted for interface compatibility but ignored.
                The ``event_callback`` table has no event id column.
            page_id: Cursor from a previous page's ``next_page_id``, encoding
                an integer offset. Missing, non-integer or negative values
                start from the first page.
            limit: Maximum number of items in the returned page.

        Returns:
            An ``EventCallbackPage``. Its ``next_page_id`` is None when there
            are no further results.
        """
        conditions = []
        if conversation_id__eq is not None:
            conditions.append(
                StoredEventCallback.conversation_id == conversation_id__eq
            )
        if event_kind__eq is not None:
            conditions.append(StoredEventCallback.event_kind == event_kind__eq)

        stmt = select(StoredEventCallback)
        if conditions:
            stmt = stmt.where(and_(*conditions))

        offset = self._parse_offset(page_id)
        if offset:
            stmt = stmt.offset(offset)

        # Offset pagination needs a total order. created_at alone can tie
        # (e.g. rows inserted in the same transaction under now()), which
        # would let pages skip or repeat rows. id breaks those ties.
        stmt = stmt.order_by(
            StoredEventCallback.created_at.desc(),
            StoredEventCallback.id.desc(),
        ).limit(limit + 1)  # one extra row tells us whether more pages exist

        result = await self.db_session.execute(stmt)
        stored_callbacks = result.scalars().all()

        has_more = len(stored_callbacks) > limit
        if has_more:
            stored_callbacks = stored_callbacks[:limit]
        next_page_id = str(offset + limit) if has_more else None

        callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
        return EventCallbackPage(items=callbacks, next_page_id=next_page_id)

    @staticmethod
    def _parse_offset(page_id: str | None) -> int:
        """Decode an offset page id. Missing or invalid ids mean offset 0."""
        if page_id is None:
            return 0
        try:
            offset = int(page_id)
        except ValueError:
            return 0
        # A negative OFFSET is meaningless, and some databases (e.g.
        # PostgreSQL) reject it outright. Treat it as the first page.
        return max(offset, 0)

    async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
        """Run every callback matching ``event`` and commit their results.

        A callback matches when its event_kind equals ``event.kind`` or is NULL,
        and its conversation_id equals ``conversation_id`` or is NULL. Matching
        callbacks run concurrently via ``asyncio.gather``. Their staged result
        rows are then written with a single commit.
        """
        query = (
            select(StoredEventCallback)
            .where(
                or_(
                    StoredEventCallback.event_kind == event.kind,
                    StoredEventCallback.event_kind.is_(None),
                )
            )
            .where(
                or_(
                    StoredEventCallback.conversation_id == conversation_id,
                    StoredEventCallback.conversation_id.is_(None),
                )
            )
        )
        result = await self.db_session.execute(query)
        stored_callbacks = result.scalars().all()
        if stored_callbacks:
            callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
            await asyncio.gather(
                *[
                    self.execute_callback(conversation_id, callback, event)
                    for callback in callbacks
                ]
            )
            await self.db_session.commit()

    async def execute_callback(
        self, conversation_id: UUID, callback: EventCallback, event: Event
    ):
        """Run one callback's processor and stage its result on the session.

        Nothing is committed here. ``execute_callbacks`` commits once after all
        callbacks have run. If the processor raises, the exception is logged
        and recorded as an ERROR result instead of propagating, so one failing
        callback does not abort its siblings.
        """
        try:
            result = await callback.processor(conversation_id, callback, event)
            if result is None:
                # The processor produced no result, so there is nothing to record.
                return
            # The processor result is expected to be a pydantic model (like
            # EventCallback), not an ORM row. Convert it the same way
            # create_event_callback converts its domain model.
            stored_result = StoredEventCallbackResult(**result.model_dump())
        except Exception as exc:
            _logger.exception(
                'Exception in callback %s', callback.id, stack_info=True
            )
            stored_result = StoredEventCallbackResult(
                # The id column is a primary key without a default. Without an
                # explicit id the flush fails, and every staged result in the
                # surrounding commit is lost.
                id=uuid4(),
                status=EventCallbackResultStatus.ERROR,
                event_callback_id=callback.id,
                event_id=event.id,
                conversation_id=conversation_id,
                detail=str(exc),
            )
        self.db_session.add(stored_result)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Stop using this event callback service."""
        pass
|
|
|
|
|
|
class SQLEventCallbackServiceInjector(EventCallbackServiceInjector):
    """Injector that hands out SQL-backed event callback services."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[EventCallbackService, None]:
        """Yield an SQLEventCallbackService bound to a session from ``state``.

        The session is obtained from ``get_db_session(state)`` and used as an
        async context manager. Its lifetime is therefore governed by that
        context manager and spans the time the yielded service is in use.
        """
        # Deferred import, presumably to avoid an import cycle with the
        # config module. Verify before hoisting it to module level.
        from openhands.app_server.config import get_db_session

        session_scope = get_db_session(state)
        async with session_scope as session:
            service = SQLEventCallbackService(db_session=session)
            yield service
|