OpenHands/openhands/app_server/event_callback/sql_event_callback_service.py
Tim O'Farrell f292f3a84d
V1 Integration (#11183)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2025-10-14 02:16:44 +00:00

231 lines
8.2 KiB
Python

# pyright: reportArgumentType=false
"""SQL implementation of EventCallbackService."""
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID, uuid4

from fastapi import Request
from sqlalchemy import UUID as SQLUUID
from sqlalchemy import Column, Enum, String, and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from openhands.app_server.event_callback.event_callback_models import (
    CreateEventCallbackRequest,
    EventCallback,
    EventCallbackPage,
    EventCallbackProcessor,
    EventKind,
)
from openhands.app_server.event_callback.event_callback_result_models import (
    EventCallbackResultStatus,
)
from openhands.app_server.event_callback.event_callback_service import (
    EventCallbackService,
    EventCallbackServiceInjector,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.utils.sql_utils import (
    Base,
    UtcDateTime,
    create_json_type_decorator,
    row2dict,
)
from openhands.sdk import Event
_logger = logging.getLogger(__name__)
# TODO: Add user level filtering to this class
class StoredEventCallback(Base):  # type: ignore
    """ORM row for a registered event callback (table ``event_callback``).

    Rows are created by ``SQLEventCallbackService.create_event_callback`` and
    matched against incoming events in
    ``SQLEventCallbackService.execute_callbacks``.
    """

    __tablename__ = 'event_callback'
    # Primary key. No default is declared on the column, so the value must be
    # supplied by whoever constructs the row.
    id = Column(SQLUUID, primary_key=True)
    # NULL is matched as "any conversation" by execute_callbacks.
    conversation_id = Column(SQLUUID, nullable=True)
    # Stored through the type decorator built for EventCallbackProcessor.
    processor = Column(create_json_type_decorator(EventCallbackProcessor))
    # NULL is matched as "any event kind" by execute_callbacks.
    event_kind = Column(String, nullable=True)
    # Filled in by the database (server default ``now()``); indexed, and used
    # as the sort key by search_event_callbacks.
    created_at = Column(UtcDateTime, server_default=func.now(), index=True)
class StoredEventCallbackResult(Base):  # type: ignore
    """ORM row recording the outcome of one callback run (table ``event_callback_result``).

    Rows are added to the session by
    ``SQLEventCallbackService.execute_callback``.
    """

    __tablename__ = 'event_callback_result'
    # Primary key. No Python-side or server-side default is declared, so every
    # insert must provide an id explicitly.
    id = Column(SQLUUID, primary_key=True)
    # Outcome of the run; execute_callback stores ERROR when the processor raises.
    status = Column(Enum(EventCallbackResultStatus), nullable=True)
    # The three indexed columns below identify which callback ran, for which
    # event, in which conversation. No foreign-key constraints are declared.
    event_callback_id = Column(SQLUUID, index=True)
    event_id = Column(SQLUUID, index=True)
    conversation_id = Column(SQLUUID, index=True)
    # Free-form text; execute_callback stores ``str(exc)`` here on failure.
    detail = Column(String, nullable=True)
    # Filled in by the database (server default ``now()``); indexed.
    created_at = Column(UtcDateTime, server_default=func.now(), index=True)
@dataclass
class SQLEventCallbackService(EventCallbackService):
    """SQL implementation of EventCallbackService.

    All operations run against the injected ``db_session``. Methods that
    write (create, delete, execute_callbacks) commit the session themselves.
    """

    db_session: AsyncSession

    async def create_event_callback(
        self, request: CreateEventCallbackRequest
    ) -> EventCallback:
        """Create and persist a new event callback.

        Args:
            request: Conversation, processor and event kind for the callback.

        Returns:
            The callback as read back from the database after the commit.
        """
        # Build the domain model first, then persist its dumped fields.
        event_callback = EventCallback(
            conversation_id=request.conversation_id,
            processor=request.processor,
            event_kind=request.event_kind,
        )
        stored_callback = StoredEventCallback(**event_callback.model_dump())
        self.db_session.add(stored_callback)
        await self.db_session.commit()
        # Reload the row so database-generated values (created_at has a server
        # default) are present on the returned model.
        await self.db_session.refresh(stored_callback)
        return EventCallback(**row2dict(stored_callback))

    async def get_event_callback(self, id: UUID) -> EventCallback | None:
        """Get a single event callback, returning None if not found."""
        stmt = select(StoredEventCallback).where(StoredEventCallback.id == id)
        result = await self.db_session.execute(stmt)
        stored_callback = result.scalar_one_or_none()
        if stored_callback is None:
            return None
        return EventCallback(**row2dict(stored_callback))

    async def delete_event_callback(self, id: UUID) -> bool:
        """Delete an event callback, returning True if deleted, False if not found."""
        stmt = select(StoredEventCallback).where(StoredEventCallback.id == id)
        result = await self.db_session.execute(stmt)
        stored_callback = result.scalar_one_or_none()
        if stored_callback is None:
            return False
        await self.db_session.delete(stored_callback)
        await self.db_session.commit()
        return True

    async def search_event_callbacks(
        self,
        conversation_id__eq: UUID | None = None,
        event_kind__eq: EventKind | None = None,
        event_id__eq: UUID | None = None,
        page_id: str | None = None,
        limit: int = 100,
    ) -> EventCallbackPage:
        """Search for event callbacks, optionally filtered by parameters.

        Args:
            conversation_id__eq: Only return callbacks for this conversation.
            event_kind__eq: Only return callbacks registered for this kind.
            event_id__eq: Accepted for interface compatibility but currently
                ignored; the event_callback table has no event id column.
            page_id: Opaque page token. It is the stringified row offset; a
                missing, non-numeric or negative value starts from the
                beginning.
            limit: Maximum number of items in the returned page.

        Returns:
            A page of callbacks, newest first, with ``next_page_id`` set when
            more rows are available.
        """
        conditions = []
        if conversation_id__eq is not None:
            conditions.append(
                StoredEventCallback.conversation_id == conversation_id__eq
            )
        if event_kind__eq is not None:
            conditions.append(StoredEventCallback.event_kind == event_kind__eq)

        stmt = select(StoredEventCallback)
        if conditions:
            stmt = stmt.where(and_(*conditions))

        # Decode the page token. Anything that is not a non-negative integer
        # is treated as "first page" instead of reaching the database as a
        # negative OFFSET or corrupting next_page_id.
        offset = 0
        if page_id is not None:
            try:
                offset = max(int(page_id), 0)
            except ValueError:
                offset = 0
        if offset:
            stmt = stmt.offset(offset)

        # created_at is not unique, so id is used as a tie-breaker; without a
        # total order, offset pagination can repeat or skip rows between pages.
        # One extra row is fetched to detect whether another page exists.
        stmt = stmt.order_by(
            StoredEventCallback.created_at.desc(), StoredEventCallback.id
        ).limit(limit + 1)
        result = await self.db_session.execute(stmt)
        stored_callbacks = result.scalars().all()

        next_page_id = None
        if len(stored_callbacks) > limit:
            stored_callbacks = stored_callbacks[:limit]
            next_page_id = str(offset + limit)

        callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
        return EventCallbackPage(items=callbacks, next_page_id=next_page_id)

    async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
        """Run every callback that matches the event and persist the results.

        A callback matches when its event_kind equals ``event.kind`` or is
        NULL, and its conversation_id equals ``conversation_id`` or is NULL.
        Matching callbacks are run concurrently; their results are committed
        together in a single commit.
        """
        query = (
            select(StoredEventCallback)
            .where(
                or_(
                    StoredEventCallback.event_kind == event.kind,
                    StoredEventCallback.event_kind.is_(None),
                )
            )
            .where(
                or_(
                    StoredEventCallback.conversation_id == conversation_id,
                    StoredEventCallback.conversation_id.is_(None),
                )
            )
        )
        result = await self.db_session.execute(query)
        stored_callbacks = result.scalars().all()
        if not stored_callbacks:
            return

        # Convert to domain models before running any processor.
        callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
        await asyncio.gather(
            *(
                self.execute_callback(conversation_id, callback, event)
                for callback in callbacks
            )
        )
        await self.db_session.commit()

    async def execute_callback(
        self, conversation_id: UUID, callback: EventCallback, event: Event
    ) -> None:
        """Run one callback and add its result row to the session.

        Exceptions raised by the processor are logged and recorded as an
        ERROR result instead of propagating. The session is not committed
        here; ``execute_callbacks`` commits after all callbacks have run.
        """
        try:
            result = await callback.processor(conversation_id, callback, event)
            # NOTE(review): row2dict is applied to the processor's return
            # value; confirm it supports that type, since a failure here is
            # also caught below and recorded as an ERROR result.
            stored_result = StoredEventCallbackResult(**row2dict(result))
        except Exception as exc:
            _logger.exception(f'Exception in callback {callback.id}', stack_info=True)
            stored_result = StoredEventCallbackResult(
                # The id column is a primary key with no default, so an
                # explicit value is required for the insert to succeed.
                id=uuid4(),
                status=EventCallbackResultStatus.ERROR,
                event_callback_id=callback.id,
                event_id=event.id,
                conversation_id=conversation_id,
                detail=str(exc),
            )
        self.db_session.add(stored_result)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Stop using this event callback service.

        Nothing is released here. Returning None lets any in-flight
        exception propagate.
        """
class SQLEventCallbackServiceInjector(EventCallbackServiceInjector):
    """Injector that yields a SQLEventCallbackService bound to a database session."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[EventCallbackService, None]:
        """Yield a service whose session stays open for the duration of the yield.

        ``request`` is accepted but not used; only ``state`` is passed to
        ``get_db_session``.
        """
        # Imported at call time rather than module load time
        # (presumably to avoid a circular import; confirm).
        from openhands.app_server.config import get_db_session

        async with get_db_session(state) as session:
            service = SQLEventCallbackService(db_session=session)
            yield service