From ddf58da995d9d8567f57f78c7ce05c21451d13fd Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Thu, 6 Nov 2025 16:05:58 -0700 Subject: [PATCH] Fix V1 callbacks (#11654) Co-authored-by: openhands --- ...0_add_status_and_updated_at_to_callback.py | 71 +++++ .../app_conversation_models.py | 2 +- .../live_status_app_conversation_service.py | 30 +- .../app_lifespan/alembic/versions/002.py | 73 +++++ .../event_callback/event_callback_models.py | 12 +- .../event_callback/event_callback_service.py | 4 + .../set_title_callback_processor.py | 85 ++++++ .../sql_event_callback_service.py | 19 +- .../test_sql_event_callback_service.py | 281 ++++++++++++++++++ .../experiments/test_experiment_manager.py | 2 + 10 files changed, 574 insertions(+), 5 deletions(-) create mode 100644 enterprise/migrations/versions/080_add_status_and_updated_at_to_callback.py create mode 100644 openhands/app_server/app_lifespan/alembic/versions/002.py create mode 100644 openhands/app_server/event_callback/set_title_callback_processor.py diff --git a/enterprise/migrations/versions/080_add_status_and_updated_at_to_callback.py b/enterprise/migrations/versions/080_add_status_and_updated_at_to_callback.py new file mode 100644 index 0000000000..4b461b3098 --- /dev/null +++ b/enterprise/migrations/versions/080_add_status_and_updated_at_to_callback.py @@ -0,0 +1,71 @@ +"""add status and updated_at to callback + +Revision ID: 080 +Revises: 079 +Create Date: 2025-11-05 00:00:00.000000 + +""" + +from enum import Enum +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '080' +down_revision: Union[str, None] = '079' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +class EventCallbackStatus(Enum): + ACTIVE = 'ACTIVE' + DISABLED = 'DISABLED' + COMPLETED = 'COMPLETED' + ERROR = 'ERROR' + + +def upgrade() -> None: + """Upgrade schema.""" + status = sa.Enum(EventCallbackStatus, name='eventcallbackstatus') + status.create(op.get_bind(), checkfirst=True) + op.add_column( + 'event_callback', + sa.Column('status', status, nullable=False, server_default='ACTIVE'), + ) + op.add_column( + 'event_callback', + sa.Column( + 'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now() + ), + ) + op.drop_index('ix_event_callback_result_event_id') + op.drop_column('event_callback_result', 'event_id') + op.add_column( + 'event_callback_result', sa.Column('event_id', sa.String, nullable=True) + ) + op.create_index( + op.f('ix_event_callback_result_event_id'), + 'event_callback_result', + ['event_id'], + unique=False, + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('event_callback', 'status') + op.drop_column('event_callback', 'updated_at') + op.drop_index('ix_event_callback_result_event_id') + op.drop_column('event_callback_result', 'event_id') + op.add_column( + 'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True) + ) + op.create_index( + op.f('ix_event_callback_result_event_id'), + 'event_callback_result', + ['event_id'], + unique=False, + ) + op.execute('DROP TYPE eventcallbackstatus') diff --git a/openhands/app_server/app_conversation/app_conversation_models.py b/openhands/app_server/app_conversation/app_conversation_models.py index d4992c7058..1b2f201dcd 100644 --- a/openhands/app_server/app_conversation/app_conversation_models.py +++ b/openhands/app_server/app_conversation/app_conversation_models.py @@ -88,7 +88,7 @@ class AppConversationStartRequest(BaseModel): sandbox_id: str | None = Field(default=None) initial_message: SendMessageRequest | None = None - processors: list[EventCallbackProcessor] = Field(default_factory=list) + processors: list[EventCallbackProcessor] | None = Field(default=None) llm_model: str | None = None # Git parameters diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index bb5040e861..1b2763e279 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -42,7 +42,15 @@ from openhands.app_server.app_conversation.git_app_conversation_service import ( from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( SQLAppConversationInfoService, ) +from openhands.app_server.config import get_event_callback_service from openhands.app_server.errors import SandboxError +from openhands.app_server.event_callback.event_callback_models import EventCallback +from openhands.app_server.event_callback.event_callback_service import ( + EventCallbackService, +) +from openhands.app_server.event_callback.set_title_callback_processor import ( + SetTitleCallbackProcessor, +) from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService from openhands.app_server.sandbox.sandbox_models import ( AGENT_SERVER, @@ -75,6 +83,7 @@ class LiveStatusAppConversationService(GitAppConversationService): user_context: UserContext app_conversation_info_service: AppConversationInfoService app_conversation_start_task_service: AppConversationStartTaskService + event_callback_service: EventCallbackService sandbox_service: SandboxService sandbox_spec_service: SandboxSpecService jwt_service: JwtService @@ -221,7 +230,6 @@ class LiveStatusAppConversationService(GitAppConversationService): user_id = await self.user_context.get_user_id() app_conversation_info = AppConversationInfo( id=info.id, - # TODO: As of writing, StartConversationRequest from AgentServer does not have a title title=f'Conversation {info.id.hex}', sandbox_id=sandbox.id, created_by_user_id=user_id, @@ -237,6 +245,24 @@ class LiveStatusAppConversationService(GitAppConversationService): app_conversation_info ) + # Setup default processors + processors = request.processors + if processors is None: + processors = [SetTitleCallbackProcessor()] + + # Save processors + await asyncio.gather( + *[ + self.event_callback_service.save_event_callback( + EventCallback( + conversation_id=info.id, + processor=processor, + ) + ) + for processor in processors + ] + ) + # Update the start task task.status = AppConversationStartTaskStatus.READY task.app_conversation_id = info.id @@ -673,6 +699,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): get_app_conversation_start_task_service( state, request ) as app_conversation_start_task_service, + get_event_callback_service(state, request) as event_callback_service, get_jwt_service(state, request) as jwt_service, get_httpx_client(state, request) as httpx_client, ): @@ -696,6 +723,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): sandbox_spec_service=sandbox_spec_service, app_conversation_info_service=app_conversation_info_service, app_conversation_start_task_service=app_conversation_start_task_service, + event_callback_service=event_callback_service, jwt_service=jwt_service, sandbox_startup_timeout=self.sandbox_startup_timeout, sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency, diff --git a/openhands/app_server/app_lifespan/alembic/versions/002.py b/openhands/app_server/app_lifespan/alembic/versions/002.py new file mode 100644 index 0000000000..cb3ec72db6 --- /dev/null +++ b/openhands/app_server/app_lifespan/alembic/versions/002.py @@ -0,0 +1,73 @@ +"""Sync DB with Models + +Revision ID: 001 +Revises: +Create Date: 2025-10-05 11:28:41.772294 + +""" + +from enum import Enum +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '002' +down_revision: Union[str, None] = '001' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +class EventCallbackStatus(Enum): + ACTIVE = 'ACTIVE' + DISABLED = 'DISABLED' + COMPLETED = 'COMPLETED' + ERROR = 'ERROR' + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'event_callback', + sa.Column( + 'status', + sa.Enum(EventCallbackStatus), + nullable=False, + server_default='ACTIVE', + ), + ) + op.add_column( + 'event_callback', + sa.Column( + 'updated_at', sa.DateTime, nullable=False, server_default=sa.func.now() + ), + ) + op.drop_index('ix_event_callback_result_event_id') + op.drop_column('event_callback_result', 'event_id') + op.add_column( + 'event_callback_result', sa.Column('event_id', sa.String, nullable=True) + ) + op.create_index( + op.f('ix_event_callback_result_event_id'), + 'event_callback_result', + ['event_id'], + unique=False, + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('event_callback', 'status') + op.drop_column('event_callback', 'updated_at') + op.drop_index('ix_event_callback_result_event_id') + op.drop_column('event_callback_result', 'event_id') + op.add_column( + 'event_callback_result', sa.Column('event_id', sa.UUID, nullable=True) + ) + op.create_index( + op.f('ix_event_callback_result_event_id'), + 'event_callback_result', + ['event_id'], + unique=False, + ) diff --git a/openhands/app_server/event_callback/event_callback_models.py b/openhands/app_server/event_callback/event_callback_models.py index 4e39bc6a42..8b1abd6aa5 100644 --- a/openhands/app_server/event_callback/event_callback_models.py +++ b/openhands/app_server/event_callback/event_callback_models.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod from datetime import datetime +from enum import Enum from typing import TYPE_CHECKING, Literal from uuid import UUID, uuid4 @@ -28,6 +29,13 @@ else: EventKind = Literal[tuple(c.__name__ for c in get_known_concrete_subclasses(Event))] +class EventCallbackStatus(Enum): + ACTIVE = 'ACTIVE' + DISABLED = 'DISABLED' + COMPLETED = 'COMPLETED' + ERROR = 'ERROR' + + class EventCallbackProcessor(DiscriminatedUnionMixin, ABC): @abstractmethod async def __call__( @@ -35,7 +43,7 @@ class EventCallbackProcessor(DiscriminatedUnionMixin, ABC): conversation_id: UUID, callback: EventCallback, event: Event, - ) -> EventCallbackResult: + ) -> EventCallbackResult | None: """Process an event.""" @@ -75,7 +83,9 @@ class CreateEventCallbackRequest(OpenHandsModel): class EventCallback(CreateEventCallbackRequest): id: OpenHandsUUID = Field(default_factory=uuid4) + status: EventCallbackStatus = Field(default=EventCallbackStatus.ACTIVE) created_at: datetime = Field(default_factory=utc_now) + updated_at: datetime = Field(default_factory=utc_now) class EventCallbackPage(OpenHandsModel): diff --git a/openhands/app_server/event_callback/event_callback_service.py b/openhands/app_server/event_callback/event_callback_service.py index 825b43051a..2d27884bc9 100644 --- a/openhands/app_server/event_callback/event_callback_service.py +++ b/openhands/app_server/event_callback/event_callback_service.py @@ -53,6 +53,10 @@ class EventCallbackService(ABC): ) return results + @abstractmethod + async def save_event_callback(self, event_callback: EventCallback) -> EventCallback: + """Update the event callback given.""" + @abstractmethod async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None: """Execute any applicable callbacks for the event and store the results.""" diff --git a/openhands/app_server/event_callback/set_title_callback_processor.py b/openhands/app_server/event_callback/set_title_callback_processor.py new file mode 100644 index 0000000000..92373dbff0 --- /dev/null +++ b/openhands/app_server/event_callback/set_title_callback_processor.py @@ -0,0 +1,85 @@ +import logging +from uuid import UUID + +from openhands.app_server.app_conversation.app_conversation_models import ( + AppConversationInfo, +) +from openhands.app_server.event_callback.event_callback_models import ( + EventCallback, + EventCallbackProcessor, + EventCallbackStatus, +) +from openhands.app_server.event_callback.event_callback_result_models import ( + EventCallbackResult, + EventCallbackResultStatus, +) +from openhands.app_server.services.injector import InjectorState +from openhands.app_server.user.specifiy_user_context import ADMIN, USER_CONTEXT_ATTR +from openhands.sdk import Event, MessageEvent + +_logger = logging.getLogger(__name__) + + +class SetTitleCallbackProcessor(EventCallbackProcessor): + """Callback processor which sets conversation titles.""" + + async def __call__( + self, + conversation_id: UUID, + callback: EventCallback, + event: Event, + ) -> EventCallbackResult | None: + if not isinstance(event, MessageEvent): + return None + from openhands.app_server.config import ( + get_app_conversation_info_service, + get_app_conversation_service, + get_event_callback_service, + get_httpx_client, + ) + + _logger.info(f'Callback {callback.id} Invoked for event {event}') + + state = InjectorState() + setattr(state, USER_CONTEXT_ATTR, ADMIN) + async with ( + get_event_callback_service(state) as event_callback_service, + get_app_conversation_service(state) as app_conversation_service, + get_app_conversation_info_service(state) as app_conversation_info_service, + get_httpx_client(state) as httpx_client, + ): + # Generate a title for the conversation + app_conversation = await app_conversation_service.get_app_conversation( + conversation_id + ) + assert app_conversation is not None + response = await httpx_client.post( + f'{app_conversation.conversation_url}/generate_title', + headers={ + 'X-Session-API-Key': app_conversation.session_api_key, + }, + content='{}', + ) + response.raise_for_status() + title = response.json()['title'] + + # Save the conversation info + info = AppConversationInfo( + **{ + name: getattr(app_conversation, name) + for name in AppConversationInfo.model_fields + } + ) + info.title = title + await app_conversation_info_service.save_app_conversation_info(info) + + # Disable callback - we have already set the status + callback.status = EventCallbackStatus.DISABLED + await event_callback_service.save_event_callback(callback) + + return EventCallbackResult( + status=EventCallbackResultStatus.SUCCESS, + event_callback_id=callback.id, + event_id=event.id, + conversation_id=conversation_id, + ) diff --git a/openhands/app_server/event_callback/sql_event_callback_service.py b/openhands/app_server/event_callback/sql_event_callback_service.py index 3309e7154d..37e5bce111 100644 --- a/openhands/app_server/event_callback/sql_event_callback_service.py +++ b/openhands/app_server/event_callback/sql_event_callback_service.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio import logging from dataclasses import dataclass +from datetime import datetime from typing import AsyncGenerator from uuid import UUID @@ -19,6 +20,7 @@ from openhands.app_server.event_callback.event_callback_models import ( EventCallback, EventCallbackPage, EventCallbackProcessor, + EventCallbackStatus, EventKind, ) from openhands.app_server.event_callback.event_callback_result_models import ( @@ -46,9 +48,13 @@ class StoredEventCallback(Base): # type: ignore __tablename__ = 'event_callback' id = Column(SQLUUID, primary_key=True) conversation_id = Column(SQLUUID, nullable=True) + status = Column( + Enum(EventCallbackStatus), nullable=False, default=EventCallbackStatus.ACTIVE + ) processor = Column(create_json_type_decorator(EventCallbackProcessor)) event_kind = Column(String, nullable=True) created_at = Column(UtcDateTime, server_default=func.now(), index=True) + updated_at = Column(UtcDateTime, server_default=func.now(), index=True) class StoredEventCallbackResult(Base): # type: ignore @@ -56,7 +62,7 @@ class StoredEventCallbackResult(Base): # type: ignore id = Column(SQLUUID, primary_key=True) status = Column(Enum(EventCallbackResultStatus), nullable=True) event_callback_id = Column(SQLUUID, index=True) - event_id = Column(SQLUUID, index=True) + event_id = Column(String, index=True) conversation_id = Column(SQLUUID, index=True) detail = Column(String, nullable=True) created_at = Column(UtcDateTime, server_default=func.now(), index=True) @@ -170,9 +176,16 @@ class SQLEventCallbackService(EventCallbackService): callbacks = [EventCallback(**row2dict(cb)) for cb in stored_callbacks] return EventCallbackPage(items=callbacks, next_page_id=next_page_id) + async def save_event_callback(self, event_callback: EventCallback) -> EventCallback: + event_callback.updated_at = datetime.now() + stored_callback = StoredEventCallback(**event_callback.model_dump()) + await self.db_session.merge(stored_callback) + return event_callback + async def execute_callbacks(self, conversation_id: UUID, event: Event) -> None: query = ( select(StoredEventCallback) + .where(StoredEventCallback.status == EventCallbackStatus.ACTIVE) .where( or_( StoredEventCallback.event_kind == event.kind, @@ -203,7 +216,9 @@ class SQLEventCallbackService(EventCallbackService): ): try: result = await callback.processor(conversation_id, callback, event) - stored_result = StoredEventCallbackResult(**row2dict(result)) + if result is None: + return + stored_result = StoredEventCallbackResult(**result.model_dump()) except Exception as exc: _logger.exception(f'Exception in callback {callback.id}', stack_info=True) stored_result = StoredEventCallbackResult( diff --git a/tests/unit/app_server/test_sql_event_callback_service.py b/tests/unit/app_server/test_sql_event_callback_service.py index b69d237f58..47a90c3175 100644 --- a/tests/unit/app_server/test_sql_event_callback_service.py +++ b/tests/unit/app_server/test_sql_event_callback_service.py @@ -372,3 +372,284 @@ class TestSQLEventCallbackService: assert len(result.items) == 2 assert result.items[0].id == callback2.id assert result.items[1].id == callback1.id + + async def test_save_event_callback_new( + self, + service: SQLEventCallbackService, + sample_callback: EventCallback, + ): + """Test saving a new event callback (insert scenario).""" + # Save the callback + original_updated_at = sample_callback.updated_at + saved_callback = await service.save_event_callback(sample_callback) + + # Verify the returned callback + assert saved_callback.id == sample_callback.id + assert saved_callback.conversation_id == sample_callback.conversation_id + assert saved_callback.processor == sample_callback.processor + assert saved_callback.event_kind == sample_callback.event_kind + assert saved_callback.status == sample_callback.status + + # Verify updated_at was changed (handle timezone differences) + # Convert both to UTC for comparison if needed + original_utc = ( + original_updated_at.replace(tzinfo=timezone.utc) + if original_updated_at.tzinfo is None + else original_updated_at + ) + saved_utc = ( + saved_callback.updated_at.replace(tzinfo=timezone.utc) + if saved_callback.updated_at.tzinfo is None + else saved_callback.updated_at + ) + assert saved_utc >= original_utc + + # Commit the transaction to persist changes + await service.db_session.commit() + + # Verify the callback can be retrieved + retrieved_callback = await service.get_event_callback(sample_callback.id) + assert retrieved_callback is not None + assert retrieved_callback.id == sample_callback.id + assert retrieved_callback.conversation_id == sample_callback.conversation_id + assert retrieved_callback.event_kind == sample_callback.event_kind + + async def test_save_event_callback_update_existing( + self, + service: SQLEventCallbackService, + sample_request: CreateEventCallbackRequest, + ): + """Test saving an existing event callback (update scenario).""" + # First create a callback through the service + created_callback = await service.create_event_callback(sample_request) + original_updated_at = created_callback.updated_at + + # Modify the callback + created_callback.event_kind = 'ObservationEvent' + from openhands.app_server.event_callback.event_callback_models import ( + EventCallbackStatus, + ) + + created_callback.status = EventCallbackStatus.DISABLED + + # Save the modified callback + saved_callback = await service.save_event_callback(created_callback) + + # Verify the returned callback has the modifications + assert saved_callback.id == created_callback.id + assert saved_callback.event_kind == 'ObservationEvent' + assert saved_callback.status == EventCallbackStatus.DISABLED + + # Verify updated_at was changed (handle timezone differences) + original_utc = ( + original_updated_at.replace(tzinfo=timezone.utc) + if original_updated_at.tzinfo is None + else original_updated_at + ) + saved_utc = ( + saved_callback.updated_at.replace(tzinfo=timezone.utc) + if saved_callback.updated_at.tzinfo is None + else saved_callback.updated_at + ) + assert saved_utc >= original_utc + + # Commit the transaction to persist changes + await service.db_session.commit() + + # Verify the changes were persisted + retrieved_callback = await service.get_event_callback(created_callback.id) + assert retrieved_callback is not None + assert retrieved_callback.event_kind == 'ObservationEvent' + assert retrieved_callback.status == EventCallbackStatus.DISABLED + + async def test_save_event_callback_timestamp_update( + self, + service: SQLEventCallbackService, + sample_callback: EventCallback, + ): + """Test that save_event_callback properly updates the timestamp.""" + # Record the original timestamp + original_updated_at = sample_callback.updated_at + + # Wait a small amount to ensure timestamp difference + import asyncio + + await asyncio.sleep(0.01) + + # Save the callback + saved_callback = await service.save_event_callback(sample_callback) + + # Verify updated_at was changed and is more recent (handle timezone differences) + original_utc = ( + original_updated_at.replace(tzinfo=timezone.utc) + if original_updated_at.tzinfo is None + else original_updated_at + ) + saved_utc = ( + saved_callback.updated_at.replace(tzinfo=timezone.utc) + if saved_callback.updated_at.tzinfo is None + else saved_callback.updated_at + ) + assert saved_utc >= original_utc + assert isinstance(saved_callback.updated_at, datetime) + + # Verify the timestamp is recent (within last minute) + now = datetime.now(timezone.utc) + time_diff = now - saved_utc + assert time_diff.total_seconds() < 60 + + async def test_save_event_callback_with_null_values( + self, + service: SQLEventCallbackService, + sample_processor: EventCallbackProcessor, + ): + """Test saving a callback with null conversation_id and event_kind.""" + # Create a callback with null values + callback = EventCallback( + conversation_id=None, + processor=sample_processor, + event_kind=None, + ) + + # Save the callback + saved_callback = await service.save_event_callback(callback) + + # Verify the callback was saved correctly + assert saved_callback.id == callback.id + assert saved_callback.conversation_id is None + assert saved_callback.event_kind is None + assert saved_callback.processor == sample_processor + + # Commit and verify persistence + await service.db_session.commit() + retrieved_callback = await service.get_event_callback(callback.id) + assert retrieved_callback is not None + assert retrieved_callback.conversation_id is None + assert retrieved_callback.event_kind is None + + async def test_save_event_callback_preserves_created_at( + self, + service: SQLEventCallbackService, + sample_request: CreateEventCallbackRequest, + ): + """Test that save_event_callback preserves the original created_at timestamp.""" + # Create a callback through the service + created_callback = await service.create_event_callback(sample_request) + original_created_at = created_callback.created_at + + # Wait a small amount to ensure timestamp difference + import asyncio + + await asyncio.sleep(0.01) + + # Save the callback again + saved_callback = await service.save_event_callback(created_callback) + + # Verify created_at was preserved but updated_at was changed + assert saved_callback.created_at == original_created_at + # Handle timezone differences for comparison + created_utc = ( + original_created_at.replace(tzinfo=timezone.utc) + if original_created_at.tzinfo is None + else original_created_at + ) + updated_utc = ( + saved_callback.updated_at.replace(tzinfo=timezone.utc) + if saved_callback.updated_at.tzinfo is None + else saved_callback.updated_at + ) + assert updated_utc >= created_utc + + async def test_save_event_callback_different_statuses( + self, + service: SQLEventCallbackService, + sample_processor: EventCallbackProcessor, + ): + """Test saving callbacks with different status values.""" + from openhands.app_server.event_callback.event_callback_models import ( + EventCallbackStatus, + ) + + # Test each status + statuses = [ + EventCallbackStatus.ACTIVE, + EventCallbackStatus.DISABLED, + EventCallbackStatus.COMPLETED, + EventCallbackStatus.ERROR, + ] + + for status in statuses: + callback = EventCallback( + conversation_id=uuid4(), + processor=sample_processor, + event_kind='ActionEvent', + status=status, + ) + + # Save the callback + saved_callback = await service.save_event_callback(callback) + + # Verify the status was preserved + assert saved_callback.status == status + + # Commit and verify persistence + await service.db_session.commit() + retrieved_callback = await service.get_event_callback(callback.id) + assert retrieved_callback is not None + assert retrieved_callback.status == status + + async def test_save_event_callback_returns_same_object( + self, + service: SQLEventCallbackService, + sample_callback: EventCallback, + ): + """Test that save_event_callback returns the same object instance.""" + # Save the callback + saved_callback = await service.save_event_callback(sample_callback) + + # Verify it's the same object (identity check) + assert saved_callback is sample_callback + + # But verify the updated_at was modified on the original object + assert sample_callback.updated_at == saved_callback.updated_at + + async def test_save_event_callback_multiple_saves( + self, + service: SQLEventCallbackService, + sample_callback: EventCallback, + ): + """Test saving the same callback multiple times.""" + # Save the callback multiple times + first_save = await service.save_event_callback(sample_callback) + first_updated_at = first_save.updated_at + + # Wait a small amount to ensure timestamp difference + import asyncio + + await asyncio.sleep(0.01) + + second_save = await service.save_event_callback(sample_callback) + second_updated_at = second_save.updated_at + + # Verify timestamps are different (handle timezone differences) + first_utc = ( + first_updated_at.replace(tzinfo=timezone.utc) + if first_updated_at.tzinfo is None + else first_updated_at + ) + second_utc = ( + second_updated_at.replace(tzinfo=timezone.utc) + if second_updated_at.tzinfo is None + else second_updated_at + ) + assert second_utc >= first_utc + + # Verify it's still the same callback + assert first_save.id == second_save.id + assert first_save is second_save # Same object instance + + # Commit and verify only one record exists + await service.db_session.commit() + retrieved_callback = await service.get_event_callback(sample_callback.id) + assert retrieved_callback is not None + assert retrieved_callback.id == sample_callback.id diff --git a/tests/unit/experiments/test_experiment_manager.py b/tests/unit/experiments/test_experiment_manager.py index 7a23cf9079..2103e11cb4 100644 --- a/tests/unit/experiments/test_experiment_manager.py +++ b/tests/unit/experiments/test_experiment_manager.py @@ -169,6 +169,7 @@ class TestExperimentManagerIntegration: # The service requires a lot of deps, but for this test we won't exercise them. app_conversation_info_service = Mock() app_conversation_start_task_service = Mock() + event_callback_service = Mock() sandbox_service = Mock() sandbox_spec_service = Mock() jwt_service = Mock() @@ -179,6 +180,7 @@ class TestExperimentManagerIntegration: user_context=user_context, app_conversation_info_service=app_conversation_info_service, app_conversation_start_task_service=app_conversation_start_task_service, + event_callback_service=event_callback_service, sandbox_service=sandbox_service, sandbox_spec_service=sandbox_spec_service, jwt_service=jwt_service,