mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Fix V1 callbacks (#11654)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
b678d548c2
commit
ddf58da995
@ -0,0 +1,71 @@
|
||||
"""add status and updated_at to callback
|
||||
|
||||
Revision ID: 080
|
||||
Revises: 079
|
||||
Create Date: 2025-11-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '080'
|
||||
down_revision: Union[str, None] = '079'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
class EventCallbackStatus(Enum):
|
||||
ACTIVE = 'ACTIVE'
|
||||
DISABLED = 'DISABLED'
|
||||
COMPLETED = 'COMPLETED'
|
||||
ERROR = 'ERROR'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
status = sa.Enum(EventCallbackStatus, name='eventcallbackstatus')
|
||||
status.create(op.get_bind(), checkfirst=True)
|
||||
op.add_column(
|
||||
'event_callback',
|
||||
sa.Column('status', status, nullable=False, server_default='ACTIVE'),
|
||||
)
|
||||
op.add_column(
|
||||
'event_callback',
|
||||
sa.Column(
|
||||
'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
op.drop_index('ix_event_callback_result_event_id')
|
||||
op.drop_column('event_callback_result', 'event_id')
|
||||
op.add_column(
|
||||
'event_callback_result', sa.Column('event_id', sa.String, nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_event_callback_result_event_id'),
|
||||
'event_callback_result',
|
||||
['event_id'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_column('event_callback', 'status')
|
||||
op.drop_column('event_callback', 'updated_at')
|
||||
op.drop_index('ix_event_callback_result_event_id')
|
||||
op.drop_column('event_callback_result', 'event_id')
|
||||
op.add_column(
|
||||
'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_event_callback_result_event_id'),
|
||||
'event_callback_result',
|
||||
['event_id'],
|
||||
unique=False,
|
||||
)
|
||||
op.execute('DROP TYPE eventcallbackstatus')
|
||||
@ -88,7 +88,7 @@ class AppConversationStartRequest(BaseModel):
|
||||
|
||||
sandbox_id: str | None = Field(default=None)
|
||||
initial_message: SendMessageRequest | None = None
|
||||
processors: list[EventCallbackProcessor] = Field(default_factory=list)
|
||||
processors: list[EventCallbackProcessor] | None = Field(default=None)
|
||||
llm_model: str | None = None
|
||||
|
||||
# Git parameters
|
||||
|
||||
@ -42,7 +42,15 @@ from openhands.app_server.app_conversation.git_app_conversation_service import (
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.config import get_event_callback_service
|
||||
from openhands.app_server.errors import SandboxError
|
||||
from openhands.app_server.event_callback.event_callback_models import EventCallback
|
||||
from openhands.app_server.event_callback.event_callback_service import (
|
||||
EventCallbackService,
|
||||
)
|
||||
from openhands.app_server.event_callback.set_title_callback_processor import (
|
||||
SetTitleCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
AGENT_SERVER,
|
||||
@ -75,6 +83,7 @@ class LiveStatusAppConversationService(GitAppConversationService):
|
||||
user_context: UserContext
|
||||
app_conversation_info_service: AppConversationInfoService
|
||||
app_conversation_start_task_service: AppConversationStartTaskService
|
||||
event_callback_service: EventCallbackService
|
||||
sandbox_service: SandboxService
|
||||
sandbox_spec_service: SandboxSpecService
|
||||
jwt_service: JwtService
|
||||
@ -221,7 +230,6 @@ class LiveStatusAppConversationService(GitAppConversationService):
|
||||
user_id = await self.user_context.get_user_id()
|
||||
app_conversation_info = AppConversationInfo(
|
||||
id=info.id,
|
||||
# TODO: As of writing, StartConversationRequest from AgentServer does not have a title
|
||||
title=f'Conversation {info.id.hex}',
|
||||
sandbox_id=sandbox.id,
|
||||
created_by_user_id=user_id,
|
||||
@ -237,6 +245,24 @@ class LiveStatusAppConversationService(GitAppConversationService):
|
||||
app_conversation_info
|
||||
)
|
||||
|
||||
# Setup default processors
|
||||
processors = request.processors
|
||||
if processors is None:
|
||||
processors = [SetTitleCallbackProcessor()]
|
||||
|
||||
# Save processors
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.event_callback_service.save_event_callback(
|
||||
EventCallback(
|
||||
conversation_id=info.id,
|
||||
processor=processor,
|
||||
)
|
||||
)
|
||||
for processor in processors
|
||||
]
|
||||
)
|
||||
|
||||
# Update the start task
|
||||
task.status = AppConversationStartTaskStatus.READY
|
||||
task.app_conversation_id = info.id
|
||||
@ -673,6 +699,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
get_app_conversation_start_task_service(
|
||||
state, request
|
||||
) as app_conversation_start_task_service,
|
||||
get_event_callback_service(state, request) as event_callback_service,
|
||||
get_jwt_service(state, request) as jwt_service,
|
||||
get_httpx_client(state, request) as httpx_client,
|
||||
):
|
||||
@ -696,6 +723,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
sandbox_spec_service=sandbox_spec_service,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
app_conversation_start_task_service=app_conversation_start_task_service,
|
||||
event_callback_service=event_callback_service,
|
||||
jwt_service=jwt_service,
|
||||
sandbox_startup_timeout=self.sandbox_startup_timeout,
|
||||
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
|
||||
|
||||
73
openhands/app_server/app_lifespan/alembic/versions/002.py
Normal file
73
openhands/app_server/app_lifespan/alembic/versions/002.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Sync DB with Models
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2025-10-05 11:28:41.772294
|
||||
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '002'
|
||||
down_revision: Union[str, None] = '001'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
class EventCallbackStatus(Enum):
|
||||
ACTIVE = 'ACTIVE'
|
||||
DISABLED = 'DISABLED'
|
||||
COMPLETED = 'COMPLETED'
|
||||
ERROR = 'ERROR'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'event_callback',
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.Enum(EventCallbackStatus),
|
||||
nullable=False,
|
||||
server_default='ACTIVE',
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
'event_callback',
|
||||
sa.Column(
|
||||
'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
op.drop_index('ix_event_callback_result_event_id')
|
||||
op.drop_column('event_callback_result', 'event_id')
|
||||
op.add_column(
|
||||
'event_callback_result', sa.Column('event_id', sa.String, nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_event_callback_result_event_id'),
|
||||
'event_callback_result',
|
||||
['event_id'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_column('event_callback', 'status')
|
||||
op.drop_column('event_callback', 'updated_at')
|
||||
op.drop_index('ix_event_callback_result_event_id')
|
||||
op.drop_column('event_callback_result', 'event_id')
|
||||
op.add_column(
|
||||
'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True)
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_event_callback_result_event_id'),
|
||||
'event_callback_result',
|
||||
['event_id'],
|
||||
unique=False,
|
||||
)
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@ -28,6 +29,13 @@ else:
|
||||
EventKind = Literal[tuple(c.__name__ for c in get_known_concrete_subclasses(Event))]
|
||||
|
||||
|
||||
class EventCallbackStatus(Enum):
|
||||
ACTIVE = 'ACTIVE'
|
||||
DISABLED = 'DISABLED'
|
||||
COMPLETED = 'COMPLETED'
|
||||
ERROR = 'ERROR'
|
||||
|
||||
|
||||
class EventCallbackProcessor(DiscriminatedUnionMixin, ABC):
|
||||
@abstractmethod
|
||||
async def __call__(
|
||||
@ -35,7 +43,7 @@ class EventCallbackProcessor(DiscriminatedUnionMixin, ABC):
|
||||
conversation_id: UUID,
|
||||
callback: EventCallback,
|
||||
event: Event,
|
||||
) -> EventCallbackResult:
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process an event."""
|
||||
|
||||
|
||||
@ -75,7 +83,9 @@ class CreateEventCallbackRequest(OpenHandsModel):
|
||||
|
||||
class EventCallback(CreateEventCallbackRequest):
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
status: EventCallbackStatus = Field(default=EventCallbackStatus.ACTIVE)
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class EventCallbackPage(OpenHandsModel):
|
||||
|
||||
@ -53,6 +53,10 @@ class EventCallbackService(ABC):
|
||||
)
|
||||
return results
|
||||
|
||||
@abstractmethod
|
||||
async def save_event_callback(self, event_callback: EventCallback) -> EventCallback:
|
||||
"""Update the event callback given."""
|
||||
|
||||
@abstractmethod
|
||||
async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
|
||||
"""Execute any applicable callbacks for the event and store the results."""
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallback,
|
||||
EventCallbackProcessor,
|
||||
EventCallbackStatus,
|
||||
)
|
||||
from openhands.app_server.event_callback.event_callback_result_models import (
|
||||
EventCallbackResult,
|
||||
EventCallbackResultStatus,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN, USER_CONTEXT_ATTR
|
||||
from openhands.sdk import Event, MessageEvent
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SetTitleCallbackProcessor(EventCallbackProcessor):
|
||||
"""Callback processor which sets conversation titles."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
callback: EventCallback,
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
if not isinstance(event, MessageEvent):
|
||||
return None
|
||||
from openhands.app_server.config import (
|
||||
get_app_conversation_info_service,
|
||||
get_app_conversation_service,
|
||||
get_event_callback_service,
|
||||
get_httpx_client,
|
||||
)
|
||||
|
||||
_logger.info(f'Callback {callback.id} Invoked for event {event}')
|
||||
|
||||
state = InjectorState()
|
||||
setattr(state, USER_CONTEXT_ATTR, ADMIN)
|
||||
async with (
|
||||
get_event_callback_service(state) as event_callback_service,
|
||||
get_app_conversation_service(state) as app_conversation_service,
|
||||
get_app_conversation_info_service(state) as app_conversation_info_service,
|
||||
get_httpx_client(state) as httpx_client,
|
||||
):
|
||||
# Generate a title for the conversation
|
||||
app_conversation = await app_conversation_service.get_app_conversation(
|
||||
conversation_id
|
||||
)
|
||||
assert app_conversation is not None
|
||||
response = await httpx_client.post(
|
||||
f'{app_conversation.conversation_url}/generate_title',
|
||||
headers={
|
||||
'X-Session-API-Key': app_conversation.session_api_key,
|
||||
},
|
||||
content='{}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
title = response.json()['title']
|
||||
|
||||
# Save the conversation info
|
||||
info = AppConversationInfo(
|
||||
**{
|
||||
name: getattr(app_conversation, name)
|
||||
for name in AppConversationInfo.model_fields
|
||||
}
|
||||
)
|
||||
info.title = title
|
||||
await app_conversation_info_service.save_app_conversation_info(info)
|
||||
|
||||
# Disable callback - we have already set the status
|
||||
callback.status = EventCallbackStatus.DISABLED
|
||||
await event_callback_service.save_event_callback(callback)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.SUCCESS,
|
||||
event_callback_id=callback.id,
|
||||
event_id=event.id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
@ -19,6 +20,7 @@ from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallback,
|
||||
EventCallbackPage,
|
||||
EventCallbackProcessor,
|
||||
EventCallbackStatus,
|
||||
EventKind,
|
||||
)
|
||||
from openhands.app_server.event_callback.event_callback_result_models import (
|
||||
@ -46,9 +48,13 @@ class StoredEventCallback(Base): # type: ignore
|
||||
__tablename__ = 'event_callback'
|
||||
id = Column(SQLUUID, primary_key=True)
|
||||
conversation_id = Column(SQLUUID, nullable=True)
|
||||
status = Column(
|
||||
Enum(EventCallbackStatus), nullable=False, default=EventCallbackStatus.ACTIVE
|
||||
)
|
||||
processor = Column(create_json_type_decorator(EventCallbackProcessor))
|
||||
event_kind = Column(String, nullable=True)
|
||||
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||
updated_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||
|
||||
|
||||
class StoredEventCallbackResult(Base): # type: ignore
|
||||
@ -56,7 +62,7 @@ class StoredEventCallbackResult(Base): # type: ignore
|
||||
id = Column(SQLUUID, primary_key=True)
|
||||
status = Column(Enum(EventCallbackResultStatus), nullable=True)
|
||||
event_callback_id = Column(SQLUUID, index=True)
|
||||
event_id = Column(SQLUUID, index=True)
|
||||
event_id = Column(String, index=True)
|
||||
conversation_id = Column(SQLUUID, index=True)
|
||||
detail = Column(String, nullable=True)
|
||||
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||
@ -170,9 +176,16 @@ class SQLEventCallbackService(EventCallbackService):
|
||||
callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks]
|
||||
return EventCallbackPage(items=callbacks, next_page_id=next_page_id)
|
||||
|
||||
async def save_event_callback(self, event_callback: EventCallback) -> EventCallback:
|
||||
event_callback.updated_at = datetime.now()
|
||||
stored_callback = StoredEventCallback(**event_callback.model_dump())
|
||||
await self.db_session.merge(stored_callback)
|
||||
return event_callback
|
||||
|
||||
async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None:
|
||||
query = (
|
||||
select(StoredEventCallback)
|
||||
.where(StoredEventCallback.status == EventCallbackStatus.ACTIVE)
|
||||
.where(
|
||||
or_(
|
||||
StoredEventCallback.event_kind == event.kind,
|
||||
@ -203,7 +216,9 @@ class SQLEventCallbackService(EventCallbackService):
|
||||
):
|
||||
try:
|
||||
result = await callback.processor(conversation_id, callback, event)
|
||||
stored_result = StoredEventCallbackResult(**row2dict(result))
|
||||
if result is None:
|
||||
return
|
||||
stored_result = StoredEventCallbackResult(**result.model_dump())
|
||||
except Exception as exc:
|
||||
_logger.exception(f'Exception in callback {callback.id}', stack_info=True)
|
||||
stored_result = StoredEventCallbackResult(
|
||||
|
||||
@ -372,3 +372,284 @@ class TestSQLEventCallbackService:
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == callback2.id
|
||||
assert result.items[1].id == callback1.id
|
||||
|
||||
async def test_save_event_callback_new(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_callback: EventCallback,
|
||||
):
|
||||
"""Test saving a new event callback (insert scenario)."""
|
||||
# Save the callback
|
||||
original_updated_at = sample_callback.updated_at
|
||||
saved_callback = await service.save_event_callback(sample_callback)
|
||||
|
||||
# Verify the returned callback
|
||||
assert saved_callback.id == sample_callback.id
|
||||
assert saved_callback.conversation_id == sample_callback.conversation_id
|
||||
assert saved_callback.processor == sample_callback.processor
|
||||
assert saved_callback.event_kind == sample_callback.event_kind
|
||||
assert saved_callback.status == sample_callback.status
|
||||
|
||||
# Verify updated_at was changed (handle timezone differences)
|
||||
# Convert both to UTC for comparison if needed
|
||||
original_utc = (
|
||||
original_updated_at.replace(tzinfo=timezone.utc)
|
||||
if original_updated_at.tzinfo is None
|
||||
else original_updated_at
|
||||
)
|
||||
saved_utc = (
|
||||
saved_callback.updated_at.replace(tzinfo=timezone.utc)
|
||||
if saved_callback.updated_at.tzinfo is None
|
||||
else saved_callback.updated_at
|
||||
)
|
||||
assert saved_utc >= original_utc
|
||||
|
||||
# Commit the transaction to persist changes
|
||||
await service.db_session.commit()
|
||||
|
||||
# Verify the callback can be retrieved
|
||||
retrieved_callback = await service.get_event_callback(sample_callback.id)
|
||||
assert retrieved_callback is not None
|
||||
assert retrieved_callback.id == sample_callback.id
|
||||
assert retrieved_callback.conversation_id == sample_callback.conversation_id
|
||||
assert retrieved_callback.event_kind == sample_callback.event_kind
|
||||
|
||||
async def test_save_event_callback_update_existing(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_request: CreateEventCallbackRequest,
|
||||
):
|
||||
"""Test saving an existing event callback (update scenario)."""
|
||||
# First create a callback through the service
|
||||
created_callback = await service.create_event_callback(sample_request)
|
||||
original_updated_at = created_callback.updated_at
|
||||
|
||||
# Modify the callback
|
||||
created_callback.event_kind = 'ObservationEvent'
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallbackStatus,
|
||||
)
|
||||
|
||||
created_callback.status = EventCallbackStatus.DISABLED
|
||||
|
||||
# Save the modified callback
|
||||
saved_callback = await service.save_event_callback(created_callback)
|
||||
|
||||
# Verify the returned callback has the modifications
|
||||
assert saved_callback.id == created_callback.id
|
||||
assert saved_callback.event_kind == 'ObservationEvent'
|
||||
assert saved_callback.status == EventCallbackStatus.DISABLED
|
||||
|
||||
# Verify updated_at was changed (handle timezone differences)
|
||||
original_utc = (
|
||||
original_updated_at.replace(tzinfo=timezone.utc)
|
||||
if original_updated_at.tzinfo is None
|
||||
else original_updated_at
|
||||
)
|
||||
saved_utc = (
|
||||
saved_callback.updated_at.replace(tzinfo=timezone.utc)
|
||||
if saved_callback.updated_at.tzinfo is None
|
||||
else saved_callback.updated_at
|
||||
)
|
||||
assert saved_utc >= original_utc
|
||||
|
||||
# Commit the transaction to persist changes
|
||||
await service.db_session.commit()
|
||||
|
||||
# Verify the changes were persisted
|
||||
retrieved_callback = await service.get_event_callback(created_callback.id)
|
||||
assert retrieved_callback is not None
|
||||
assert retrieved_callback.event_kind == 'ObservationEvent'
|
||||
assert retrieved_callback.status == EventCallbackStatus.DISABLED
|
||||
|
||||
async def test_save_event_callback_timestamp_update(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_callback: EventCallback,
|
||||
):
|
||||
"""Test that save_event_callback properly updates the timestamp."""
|
||||
# Record the original timestamp
|
||||
original_updated_at = sample_callback.updated_at
|
||||
|
||||
# Wait a small amount to ensure timestamp difference
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save the callback
|
||||
saved_callback = await service.save_event_callback(sample_callback)
|
||||
|
||||
# Verify updated_at was changed and is more recent (handle timezone differences)
|
||||
original_utc = (
|
||||
original_updated_at.replace(tzinfo=timezone.utc)
|
||||
if original_updated_at.tzinfo is None
|
||||
else original_updated_at
|
||||
)
|
||||
saved_utc = (
|
||||
saved_callback.updated_at.replace(tzinfo=timezone.utc)
|
||||
if saved_callback.updated_at.tzinfo is None
|
||||
else saved_callback.updated_at
|
||||
)
|
||||
assert saved_utc >= original_utc
|
||||
assert isinstance(saved_callback.updated_at, datetime)
|
||||
|
||||
# Verify the timestamp is recent (within last minute)
|
||||
now = datetime.now(timezone.utc)
|
||||
time_diff = now - saved_utc
|
||||
assert time_diff.total_seconds() < 60
|
||||
|
||||
async def test_save_event_callback_with_null_values(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_processor: EventCallbackProcessor,
|
||||
):
|
||||
"""Test saving a callback with null conversation_id and event_kind."""
|
||||
# Create a callback with null values
|
||||
callback = EventCallback(
|
||||
conversation_id=None,
|
||||
processor=sample_processor,
|
||||
event_kind=None,
|
||||
)
|
||||
|
||||
# Save the callback
|
||||
saved_callback = await service.save_event_callback(callback)
|
||||
|
||||
# Verify the callback was saved correctly
|
||||
assert saved_callback.id == callback.id
|
||||
assert saved_callback.conversation_id is None
|
||||
assert saved_callback.event_kind is None
|
||||
assert saved_callback.processor == sample_processor
|
||||
|
||||
# Commit and verify persistence
|
||||
await service.db_session.commit()
|
||||
retrieved_callback = await service.get_event_callback(callback.id)
|
||||
assert retrieved_callback is not None
|
||||
assert retrieved_callback.conversation_id is None
|
||||
assert retrieved_callback.event_kind is None
|
||||
|
||||
async def test_save_event_callback_preserves_created_at(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_request: CreateEventCallbackRequest,
|
||||
):
|
||||
"""Test that save_event_callback preserves the original created_at timestamp."""
|
||||
# Create a callback through the service
|
||||
created_callback = await service.create_event_callback(sample_request)
|
||||
original_created_at = created_callback.created_at
|
||||
|
||||
# Wait a small amount to ensure timestamp difference
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save the callback again
|
||||
saved_callback = await service.save_event_callback(created_callback)
|
||||
|
||||
# Verify created_at was preserved but updated_at was changed
|
||||
assert saved_callback.created_at == original_created_at
|
||||
# Handle timezone differences for comparison
|
||||
created_utc = (
|
||||
original_created_at.replace(tzinfo=timezone.utc)
|
||||
if original_created_at.tzinfo is None
|
||||
else original_created_at
|
||||
)
|
||||
updated_utc = (
|
||||
saved_callback.updated_at.replace(tzinfo=timezone.utc)
|
||||
if saved_callback.updated_at.tzinfo is None
|
||||
else saved_callback.updated_at
|
||||
)
|
||||
assert updated_utc >= created_utc
|
||||
|
||||
async def test_save_event_callback_different_statuses(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_processor: EventCallbackProcessor,
|
||||
):
|
||||
"""Test saving callbacks with different status values."""
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallbackStatus,
|
||||
)
|
||||
|
||||
# Test each status
|
||||
statuses = [
|
||||
EventCallbackStatus.ACTIVE,
|
||||
EventCallbackStatus.DISABLED,
|
||||
EventCallbackStatus.COMPLETED,
|
||||
EventCallbackStatus.ERROR,
|
||||
]
|
||||
|
||||
for status in statuses:
|
||||
callback = EventCallback(
|
||||
conversation_id=uuid4(),
|
||||
processor=sample_processor,
|
||||
event_kind='ActionEvent',
|
||||
status=status,
|
||||
)
|
||||
|
||||
# Save the callback
|
||||
saved_callback = await service.save_event_callback(callback)
|
||||
|
||||
# Verify the status was preserved
|
||||
assert saved_callback.status == status
|
||||
|
||||
# Commit and verify persistence
|
||||
await service.db_session.commit()
|
||||
retrieved_callback = await service.get_event_callback(callback.id)
|
||||
assert retrieved_callback is not None
|
||||
assert retrieved_callback.status == status
|
||||
|
||||
async def test_save_event_callback_returns_same_object(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_callback: EventCallback,
|
||||
):
|
||||
"""Test that save_event_callback returns the same object instance."""
|
||||
# Save the callback
|
||||
saved_callback = await service.save_event_callback(sample_callback)
|
||||
|
||||
# Verify it's the same object (identity check)
|
||||
assert saved_callback is sample_callback
|
||||
|
||||
# But verify the updated_at was modified on the original object
|
||||
assert sample_callback.updated_at == saved_callback.updated_at
|
||||
|
||||
async def test_save_event_callback_multiple_saves(
|
||||
self,
|
||||
service: SQLEventCallbackService,
|
||||
sample_callback: EventCallback,
|
||||
):
|
||||
"""Test saving the same callback multiple times."""
|
||||
# Save the callback multiple times
|
||||
first_save = await service.save_event_callback(sample_callback)
|
||||
first_updated_at = first_save.updated_at
|
||||
|
||||
# Wait a small amount to ensure timestamp difference
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
second_save = await service.save_event_callback(sample_callback)
|
||||
second_updated_at = second_save.updated_at
|
||||
|
||||
# Verify timestamps are different (handle timezone differences)
|
||||
first_utc = (
|
||||
first_updated_at.replace(tzinfo=timezone.utc)
|
||||
if first_updated_at.tzinfo is None
|
||||
else first_updated_at
|
||||
)
|
||||
second_utc = (
|
||||
second_updated_at.replace(tzinfo=timezone.utc)
|
||||
if second_updated_at.tzinfo is None
|
||||
else second_updated_at
|
||||
)
|
||||
assert second_utc >= first_utc
|
||||
|
||||
# Verify it's still the same callback
|
||||
assert first_save.id == second_save.id
|
||||
assert first_save is second_save # Same object instance
|
||||
|
||||
# Commit and verify only one record exists
|
||||
await service.db_session.commit()
|
||||
retrieved_callback = await service.get_event_callback(sample_callback.id)
|
||||
assert retrieved_callback is not None
|
||||
assert retrieved_callback.id == sample_callback.id
|
||||
|
||||
@ -169,6 +169,7 @@ class TestExperimentManagerIntegration:
|
||||
# The service requires a lot of deps, but for this test we won't exercise them.
|
||||
app_conversation_info_service = Mock()
|
||||
app_conversation_start_task_service = Mock()
|
||||
event_callback_service = Mock()
|
||||
sandbox_service = Mock()
|
||||
sandbox_spec_service = Mock()
|
||||
jwt_service = Mock()
|
||||
@ -179,6 +180,7 @@ class TestExperimentManagerIntegration:
|
||||
user_context=user_context,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
app_conversation_start_task_service=app_conversation_start_task_service,
|
||||
event_callback_service=event_callback_service,
|
||||
sandbox_service=sandbox_service,
|
||||
sandbox_spec_service=sandbox_spec_service,
|
||||
jwt_service=jwt_service,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user