Fix V1 callbacks (#11654)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell 2025-11-06 16:05:58 -07:00 committed by GitHub
parent b678d548c2
commit ddf58da995
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 574 additions and 5 deletions

View File

@ -0,0 +1,71 @@
"""add status and updated_at to callback
Revision ID: 080
Revises: 079
Create Date: 2025-11-05 00:00:00.000000
"""
from enum import Enum
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '080'
down_revision: Union[str, None] = '079'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
class EventCallbackStatus(Enum):
ACTIVE = 'ACTIVE'
DISABLED = 'DISABLED'
COMPLETED = 'COMPLETED'
ERROR = 'ERROR'
def upgrade() -> None:
"""Upgrade schema."""
status = sa.Enum(EventCallbackStatus, name='eventcallbackstatus')
status.create(op.get_bind(), checkfirst=True)
op.add_column(
'event_callback',
sa.Column('status', status, nullable=False, server_default='ACTIVE'),
)
op.add_column(
'event_callback',
sa.Column(
'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()
),
)
op.drop_index('ix_event_callback_result_event_id')
op.drop_column('event_callback_result', 'event_id')
op.add_column(
'event_callback_result', sa.Column('event_id', sa.String, nullable=True)
)
op.create_index(
op.f('ix_event_callback_result_event_id'),
'event_callback_result',
['event_id'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_column('event_callback', 'status')
op.drop_column('event_callback', 'updated_at')
op.drop_index('ix_event_callback_result_event_id')
op.drop_column('event_callback_result', 'event_id')
op.add_column(
'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True)
)
op.create_index(
op.f('ix_event_callback_result_event_id'),
'event_callback_result',
['event_id'],
unique=False,
)
op.execute('DROP TYPE eventcallbackstatus')

View File

@ -88,7 +88,7 @@ class AppConversationStartRequest(BaseModel):
sandbox_id: str | None = Field(default=None)
initial_message: SendMessageRequest | None = None
processors: list[EventCallbackProcessor] = Field(default_factory=list)
processors: list[EventCallbackProcessor] | None = Field(default=None)
llm_model: str | None = None
# Git parameters

View File

@ -42,7 +42,15 @@ from openhands.app_server.app_conversation.git_app_conversation_service import (
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.config import get_event_callback_service
from openhands.app_server.errors import SandboxError
from openhands.app_server.event_callback.event_callback_models import EventCallback
from openhands.app_server.event_callback.event_callback_service import (
EventCallbackService,
)
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
from openhands.app_server.sandbox.sandbox_models import (
AGENT_SERVER,
@ -75,6 +83,7 @@ class LiveStatusAppConversationService(GitAppConversationService):
user_context: UserContext
app_conversation_info_service: AppConversationInfoService
app_conversation_start_task_service: AppConversationStartTaskService
event_callback_service: EventCallbackService
sandbox_service: SandboxService
sandbox_spec_service: SandboxSpecService
jwt_service: JwtService
@ -221,7 +230,6 @@ class LiveStatusAppConversationService(GitAppConversationService):
user_id = await self.user_context.get_user_id()
app_conversation_info = AppConversationInfo(
id=info.id,
# TODO: As of writing, StartConversationRequest from AgentServer does not have a title
title=f'Conversation {info.id.hex}',
sandbox_id=sandbox.id,
created_by_user_id=user_id,
@ -237,6 +245,24 @@ class LiveStatusAppConversationService(GitAppConversationService):
app_conversation_info
)
# Setup default processors
processors = request.processors
if processors is None:
processors = [SetTitleCallbackProcessor()]
# Save processors
await asyncio.gather(
*[
self.event_callback_service.save_event_callback(
EventCallback(
conversation_id=info.id,
processor=processor,
)
)
for processor in processors
]
)
# Update the start task
task.status = AppConversationStartTaskStatus.READY
task.app_conversation_id = info.id
@ -673,6 +699,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
get_app_conversation_start_task_service(
state, request
) as app_conversation_start_task_service,
get_event_callback_service(state, request) as event_callback_service,
get_jwt_service(state, request) as jwt_service,
get_httpx_client(state, request) as httpx_client,
):
@ -696,6 +723,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
sandbox_spec_service=sandbox_spec_service,
app_conversation_info_service=app_conversation_info_service,
app_conversation_start_task_service=app_conversation_start_task_service,
event_callback_service=event_callback_service,
jwt_service=jwt_service,
sandbox_startup_timeout=self.sandbox_startup_timeout,
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,

View File

@ -0,0 +1,73 @@
"""Sync DB with Models
Revision ID: 001
Revises:
Create Date: 2025-10-05 11:28:41.772294
"""
from enum import Enum
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '002'
down_revision: Union[str, None] = '001'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
class EventCallbackStatus(Enum):
ACTIVE = 'ACTIVE'
DISABLED = 'DISABLED'
COMPLETED = 'COMPLETED'
ERROR = 'ERROR'
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'event_callback',
sa.Column(
'status',
sa.Enum(EventCallbackStatus),
nullable=False,
server_default='ACTIVE',
),
)
op.add_column(
'event_callback',
sa.Column(
'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()
),
)
op.drop_index('ix_event_callback_result_event_id')
op.drop_column('event_callback_result', 'event_id')
op.add_column(
'event_callback_result', sa.Column('event_id', sa.String, nullable=True)
)
op.create_index(
op.f('ix_event_callback_result_event_id'),
'event_callback_result',
['event_id'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_column('event_callback', 'status')
op.drop_column('event_callback', 'updated_at')
op.drop_index('ix_event_callback_result_event_id')
op.drop_column('event_callback_result', 'event_id')
op.add_column(
'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True)
)
op.create_index(
op.f('ix_event_callback_result_event_id'),
'event_callback_result',
['event_id'],
unique=False,
)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Literal
from uuid import UUID, uuid4
@ -28,6 +29,13 @@ else:
EventKind = Literal[tuple(c.__name__ for c in get_known_concrete_subclasses(Event))]
class EventCallbackStatus(Enum):
ACTIVE = 'ACTIVE'
DISABLED = 'DISABLED'
COMPLETED = 'COMPLETED'
ERROR = 'ERROR'
class EventCallbackProcessor(DiscriminatedUnionMixin, ABC):
@abstractmethod
async def __call__(
@ -35,7 +43,7 @@ class EventCallbackProcessor(DiscriminatedUnionMixin, ABC):
conversation_id: UUID,
callback: EventCallback,
event: Event,
) -> EventCallbackResult:
) -> EventCallbackResult | None:
"""Process an event."""
@ -75,7 +83,9 @@ class CreateEventCallbackRequest(OpenHandsModel):
class EventCallback(CreateEventCallbackRequest):
id: OpenHandsUUID = Field(default_factory=uuid4)
status: EventCallbackStatus = Field(default=EventCallbackStatus.ACTIVE)
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class EventCallbackPage(OpenHandsModel):

View File

@ -53,6 +53,10 @@ class EventCallbackService(ABC):
)
return results
@abstractmethod
async def save_event_callback(self, event_callback: EventCallback) -> EventCallback:
"""Update the event callback given."""
@abstractmethod
async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
"""Execute any applicable callbacks for the event and store the results."""

View File

@ -0,0 +1,85 @@
import logging
from uuid import UUID
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.event_callback.event_callback_models import (
EventCallback,
EventCallbackProcessor,
EventCallbackStatus,
)
from openhands.app_server.event_callback.event_callback_result_models import (
EventCallbackResult,
EventCallbackResultStatus,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import ADMIN, USER_CONTEXT_ATTR
from openhands.sdk import Event, MessageEvent
_logger = logging.getLogger(__name__)
class SetTitleCallbackProcessor(EventCallbackProcessor):
"""Callback processor which sets conversation titles."""
async def __call__(
self,
conversation_id: UUID,
callback: EventCallback,
event: Event,
) -> EventCallbackResult | None:
if not isinstance(event, MessageEvent):
return None
from openhands.app_server.config import (
get_app_conversation_info_service,
get_app_conversation_service,
get_event_callback_service,
get_httpx_client,
)
_logger.info(f'Callback {callback.id} Invoked for event {event}')
state = InjectorState()
setattr(state, USER_CONTEXT_ATTR, ADMIN)
async with (
get_event_callback_service(state) as event_callback_service,
get_app_conversation_service(state) as app_conversation_service,
get_app_conversation_info_service(state) as app_conversation_info_service,
get_httpx_client(state) as httpx_client,
):
# Generate a title for the conversation
app_conversation = await app_conversation_service.get_app_conversation(
conversation_id
)
assert app_conversation is not None
response = await httpx_client.post(
f'{app_conversation.conversation_url}/generate_title',
headers={
'X-Session-API-Key': app_conversation.session_api_key,
},
content='{}',
)
response.raise_for_status()
title = response.json()['title']
# Save the conversation info
info = AppConversationInfo(
**{
name: getattr(app_conversation, name)
for name in AppConversationInfo.model_fields
}
)
info.title = title
await app_conversation_info_service.save_app_conversation_info(info)
# Disable callback - we have already set the status
callback.status = EventCallbackStatus.DISABLED
await event_callback_service.save_event_callback(callback)
return EventCallbackResult(
status=EventCallbackResultStatus.SUCCESS,
event_callback_id=callback.id,
event_id=event.id,
conversation_id=conversation_id,
)

View File

@ -6,6 +6,7 @@ from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
@ -19,6 +20,7 @@ from openhands.app_server.event_callback.event_callback_models import (
EventCallback,
EventCallbackPage,
EventCallbackProcessor,
EventCallbackStatus,
EventKind,
)
from openhands.app_server.event_callback.event_callback_result_models import (
@ -46,9 +48,13 @@ class StoredEventCallback(Base): # type: ignore
__tablename__ = 'event_callback'
id = Column(SQLUUID, primary_key=True)
conversation_id = Column(SQLUUID, nullable=True)
status = Column(
Enum(EventCallbackStatus), nullable=False, default=EventCallbackStatus.ACTIVE
)
processor = Column(create_json_type_decorator(EventCallbackProcessor))
event_kind = Column(String, nullable=True)
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
updated_at = Column(UtcDateTime, server_default=func.now(), index=True)
class StoredEventCallbackResult(Base): # type: ignore
@ -56,7 +62,7 @@ class StoredEventCallbackResult(Base): # type: ignore
id = Column(SQLUUID, primary_key=True)
status = Column(Enum(EventCallbackResultStatus), nullable=True)
event_callback_id = Column(SQLUUID, index=True)
event_id = Column(SQLUUID, index=True)
event_id = Column(String, index=True)
conversation_id = Column(SQLUUID, index=True)
detail = Column(String, nullable=True)
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
@ -170,9 +176,16 @@ class SQLEventCallbackService(EventCallbackService):
callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
return EventCallbackPage(items=callbacks, next_page_id=next_page_id)
async def save_event_callback(self, event_callback: EventCallback) -> EventCallback:
event_callback.updated_at = datetime.now()
stored_callback = StoredEventCallback(**event_callback.model_dump())
await self.db_session.merge(stored_callback)
return event_callback
async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
query = (
select(StoredEventCallback)
.where(StoredEventCallback.status == EventCallbackStatus.ACTIVE)
.where(
or_(
StoredEventCallback.event_kind == event.kind,
@ -203,7 +216,9 @@ class SQLEventCallbackService(EventCallbackService):
):
try:
result = await callback.processor(conversation_id, callback, event)
stored_result = StoredEventCallbackResult(**row2dict(result))
if result is None:
return
stored_result = StoredEventCallbackResult(**result.model_dump())
except Exception as exc:
_logger.exception(f'Exception in callback {callback.id}', stack_info=True)
stored_result = StoredEventCallbackResult(

View File

@ -372,3 +372,284 @@ class TestSQLEventCallbackService:
assert len(result.items) == 2
assert result.items[0].id == callback2.id
assert result.items[1].id == callback1.id
async def test_save_event_callback_new(
self,
service: SQLEventCallbackService,
sample_callback: EventCallback,
):
"""Test saving a new event callback (insert scenario)."""
# Save the callback
original_updated_at = sample_callback.updated_at
saved_callback = await service.save_event_callback(sample_callback)
# Verify the returned callback
assert saved_callback.id == sample_callback.id
assert saved_callback.conversation_id == sample_callback.conversation_id
assert saved_callback.processor == sample_callback.processor
assert saved_callback.event_kind == sample_callback.event_kind
assert saved_callback.status == sample_callback.status
# Verify updated_at was changed (handle timezone differences)
# Convert both to UTC for comparison if needed
original_utc = (
original_updated_at.replace(tzinfo=timezone.utc)
if original_updated_at.tzinfo is None
else original_updated_at
)
saved_utc = (
saved_callback.updated_at.replace(tzinfo=timezone.utc)
if saved_callback.updated_at.tzinfo is None
else saved_callback.updated_at
)
assert saved_utc >= original_utc
# Commit the transaction to persist changes
await service.db_session.commit()
# Verify the callback can be retrieved
retrieved_callback = await service.get_event_callback(sample_callback.id)
assert retrieved_callback is not None
assert retrieved_callback.id == sample_callback.id
assert retrieved_callback.conversation_id == sample_callback.conversation_id
assert retrieved_callback.event_kind == sample_callback.event_kind
async def test_save_event_callback_update_existing(
self,
service: SQLEventCallbackService,
sample_request: CreateEventCallbackRequest,
):
"""Test saving an existing event callback (update scenario)."""
# First create a callback through the service
created_callback = await service.create_event_callback(sample_request)
original_updated_at = created_callback.updated_at
# Modify the callback
created_callback.event_kind = 'ObservationEvent'
from openhands.app_server.event_callback.event_callback_models import (
EventCallbackStatus,
)
created_callback.status = EventCallbackStatus.DISABLED
# Save the modified callback
saved_callback = await service.save_event_callback(created_callback)
# Verify the returned callback has the modifications
assert saved_callback.id == created_callback.id
assert saved_callback.event_kind == 'ObservationEvent'
assert saved_callback.status == EventCallbackStatus.DISABLED
# Verify updated_at was changed (handle timezone differences)
original_utc = (
original_updated_at.replace(tzinfo=timezone.utc)
if original_updated_at.tzinfo is None
else original_updated_at
)
saved_utc = (
saved_callback.updated_at.replace(tzinfo=timezone.utc)
if saved_callback.updated_at.tzinfo is None
else saved_callback.updated_at
)
assert saved_utc >= original_utc
# Commit the transaction to persist changes
await service.db_session.commit()
# Verify the changes were persisted
retrieved_callback = await service.get_event_callback(created_callback.id)
assert retrieved_callback is not None
assert retrieved_callback.event_kind == 'ObservationEvent'
assert retrieved_callback.status == EventCallbackStatus.DISABLED
async def test_save_event_callback_timestamp_update(
self,
service: SQLEventCallbackService,
sample_callback: EventCallback,
):
"""Test that save_event_callback properly updates the timestamp."""
# Record the original timestamp
original_updated_at = sample_callback.updated_at
# Wait a small amount to ensure timestamp difference
import asyncio
await asyncio.sleep(0.01)
# Save the callback
saved_callback = await service.save_event_callback(sample_callback)
# Verify updated_at was changed and is more recent (handle timezone differences)
original_utc = (
original_updated_at.replace(tzinfo=timezone.utc)
if original_updated_at.tzinfo is None
else original_updated_at
)
saved_utc = (
saved_callback.updated_at.replace(tzinfo=timezone.utc)
if saved_callback.updated_at.tzinfo is None
else saved_callback.updated_at
)
assert saved_utc >= original_utc
assert isinstance(saved_callback.updated_at, datetime)
# Verify the timestamp is recent (within last minute)
now = datetime.now(timezone.utc)
time_diff = now - saved_utc
assert time_diff.total_seconds() < 60
async def test_save_event_callback_with_null_values(
self,
service: SQLEventCallbackService,
sample_processor: EventCallbackProcessor,
):
"""Test saving a callback with null conversation_id and event_kind."""
# Create a callback with null values
callback = EventCallback(
conversation_id=None,
processor=sample_processor,
event_kind=None,
)
# Save the callback
saved_callback = await service.save_event_callback(callback)
# Verify the callback was saved correctly
assert saved_callback.id == callback.id
assert saved_callback.conversation_id is None
assert saved_callback.event_kind is None
assert saved_callback.processor == sample_processor
# Commit and verify persistence
await service.db_session.commit()
retrieved_callback = await service.get_event_callback(callback.id)
assert retrieved_callback is not None
assert retrieved_callback.conversation_id is None
assert retrieved_callback.event_kind is None
async def test_save_event_callback_preserves_created_at(
self,
service: SQLEventCallbackService,
sample_request: CreateEventCallbackRequest,
):
"""Test that save_event_callback preserves the original created_at timestamp."""
# Create a callback through the service
created_callback = await service.create_event_callback(sample_request)
original_created_at = created_callback.created_at
# Wait a small amount to ensure timestamp difference
import asyncio
await asyncio.sleep(0.01)
# Save the callback again
saved_callback = await service.save_event_callback(created_callback)
# Verify created_at was preserved but updated_at was changed
assert saved_callback.created_at == original_created_at
# Handle timezone differences for comparison
created_utc = (
original_created_at.replace(tzinfo=timezone.utc)
if original_created_at.tzinfo is None
else original_created_at
)
updated_utc = (
saved_callback.updated_at.replace(tzinfo=timezone.utc)
if saved_callback.updated_at.tzinfo is None
else saved_callback.updated_at
)
assert updated_utc >= created_utc
async def test_save_event_callback_different_statuses(
self,
service: SQLEventCallbackService,
sample_processor: EventCallbackProcessor,
):
"""Test saving callbacks with different status values."""
from openhands.app_server.event_callback.event_callback_models import (
EventCallbackStatus,
)
# Test each status
statuses = [
EventCallbackStatus.ACTIVE,
EventCallbackStatus.DISABLED,
EventCallbackStatus.COMPLETED,
EventCallbackStatus.ERROR,
]
for status in statuses:
callback = EventCallback(
conversation_id=uuid4(),
processor=sample_processor,
event_kind='ActionEvent',
status=status,
)
# Save the callback
saved_callback = await service.save_event_callback(callback)
# Verify the status was preserved
assert saved_callback.status == status
# Commit and verify persistence
await service.db_session.commit()
retrieved_callback = await service.get_event_callback(callback.id)
assert retrieved_callback is not None
assert retrieved_callback.status == status
async def test_save_event_callback_returns_same_object(
self,
service: SQLEventCallbackService,
sample_callback: EventCallback,
):
"""Test that save_event_callback returns the same object instance."""
# Save the callback
saved_callback = await service.save_event_callback(sample_callback)
# Verify it's the same object (identity check)
assert saved_callback is sample_callback
# But verify the updated_at was modified on the original object
assert sample_callback.updated_at == saved_callback.updated_at
async def test_save_event_callback_multiple_saves(
self,
service: SQLEventCallbackService,
sample_callback: EventCallback,
):
"""Test saving the same callback multiple times."""
# Save the callback multiple times
first_save = await service.save_event_callback(sample_callback)
first_updated_at = first_save.updated_at
# Wait a small amount to ensure timestamp difference
import asyncio
await asyncio.sleep(0.01)
second_save = await service.save_event_callback(sample_callback)
second_updated_at = second_save.updated_at
# Verify timestamps are different (handle timezone differences)
first_utc = (
first_updated_at.replace(tzinfo=timezone.utc)
if first_updated_at.tzinfo is None
else first_updated_at
)
second_utc = (
second_updated_at.replace(tzinfo=timezone.utc)
if second_updated_at.tzinfo is None
else second_updated_at
)
assert second_utc >= first_utc
# Verify it's still the same callback
assert first_save.id == second_save.id
assert first_save is second_save # Same object instance
# Commit and verify only one record exists
await service.db_session.commit()
retrieved_callback = await service.get_event_callback(sample_callback.id)
assert retrieved_callback is not None
assert retrieved_callback.id == sample_callback.id

View File

@ -169,6 +169,7 @@ class TestExperimentManagerIntegration:
# The service requires a lot of deps, but for this test we won't exercise them.
app_conversation_info_service = Mock()
app_conversation_start_task_service = Mock()
event_callback_service = Mock()
sandbox_service = Mock()
sandbox_spec_service = Mock()
jwt_service = Mock()
@ -179,6 +180,7 @@ class TestExperimentManagerIntegration:
user_context=user_context,
app_conversation_info_service=app_conversation_info_service,
app_conversation_start_task_service=app_conversation_start_task_service,
event_callback_service=event_callback_service,
sandbox_service=sandbox_service,
sandbox_spec_service=sandbox_spec_service,
jwt_service=jwt_service,