From 238cab4d08ebb176bf972e37ef55fa40615dbf82 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Mon, 16 Mar 2026 22:25:44 +0700 Subject: [PATCH] fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380) --- .../101_add_pending_messages_table.py | 39 ++ .../utils/saas_pending_message_injector.py | 172 +++++ .../components/interactive-chat-box.test.tsx | 4 +- .../conversation-local-storage.test.ts | 227 +++++++ .../conversation-websocket-handler.test.tsx | 241 ++++++- .../hooks/use-draft-persistence.test.tsx | 594 ++++++++++++++++++ .../hooks/use-handle-plan-click.test.tsx | 3 + .../pending-message-service.api.ts | 40 ++ .../pending-message-service.types.ts | 22 + .../features/chat/chat-interface.tsx | 10 +- .../features/chat/custom-chat-input.tsx | 2 + .../features/chat/interactive-chat-box.tsx | 3 +- .../conversation-websocket-context.tsx | 50 +- .../src/hooks/chat/use-chat-input-logic.ts | 11 + .../src/hooks/chat/use-draft-persistence.ts | 179 ++++++ frontend/src/hooks/use-send-message.ts | 21 +- .../src/utils/conversation-local-storage.ts | 4 + .../live_status_app_conversation_service.py | 99 +++ .../app_lifespan/alembic/versions/007.py | 39 ++ openhands/app_server/config.py | 26 + .../app_server/pending_messages/__init__.py | 21 + .../pending_message_models.py | 32 + .../pending_message_router.py | 104 +++ .../pending_message_service.py | 200 ++++++ openhands/app_server/v1_router.py | 4 + ...st_live_status_app_conversation_service.py | 4 + .../app_server/test_pending_message_router.py | 227 +++++++ .../test_pending_message_service.py | 309 +++++++++ .../server/data_models/test_conversation.py | 3 + 29 files changed, 2668 insertions(+), 22 deletions(-) create mode 100644 enterprise/migrations/versions/101_add_pending_messages_table.py create mode 100644 enterprise/server/utils/saas_pending_message_injector.py create mode 100644 frontend/__tests__/hooks/use-draft-persistence.test.tsx create mode 100644 frontend/src/api/pending-message-service/pending-message-service.api.ts create mode 100644 frontend/src/api/pending-message-service/pending-message-service.types.ts create mode 100644 frontend/src/hooks/chat/use-draft-persistence.ts create mode 100644 openhands/app_server/app_lifespan/alembic/versions/007.py create mode 100644 openhands/app_server/pending_messages/__init__.py create mode 100644 openhands/app_server/pending_messages/pending_message_models.py create mode 100644 openhands/app_server/pending_messages/pending_message_router.py create mode 100644 openhands/app_server/pending_messages/pending_message_service.py create mode 100644 tests/unit/app_server/test_pending_message_router.py create mode 100644 tests/unit/app_server/test_pending_message_service.py diff --git a/enterprise/migrations/versions/101_add_pending_messages_table.py b/enterprise/migrations/versions/101_add_pending_messages_table.py new file mode 100644 index 0000000000..cbe97a955b --- /dev/null +++ b/enterprise/migrations/versions/101_add_pending_messages_table.py @@ -0,0 +1,39 @@ +"""Add pending_messages table for server-side message queuing + +Revision ID: 101 +Revises: 100 +Create Date: 2025-03-15 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '101' +down_revision: Union[str, None] = '100' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create pending_messages table for storing messages before conversation is ready. + + Messages are stored temporarily until the conversation becomes ready, then + delivered and deleted regardless of success or failure. + """ + op.create_table( + 'pending_messages', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('conversation_id', sa.String(), nullable=False, index=True), + sa.Column('role', sa.String(20), nullable=False, server_default='user'), + sa.Column('content', sa.JSON, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + ) + + +def downgrade() -> None: + """Remove pending_messages table.""" + op.drop_table('pending_messages') diff --git a/enterprise/server/utils/saas_pending_message_injector.py b/enterprise/server/utils/saas_pending_message_injector.py new file mode 100644 index 0000000000..fa47152801 --- /dev/null +++ b/enterprise/server/utils/saas_pending_message_injector.py @@ -0,0 +1,172 @@ +"""Enterprise injector for PendingMessageService with SAAS filtering.""" + +from typing import AsyncGenerator +from uuid import UUID + +from fastapi import Request +from sqlalchemy import select +from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas +from storage.user import User + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.app_server.errors import AuthError +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, + SQLPendingMessageService, +) +from openhands.app_server.services.injector import InjectorState +from openhands.app_server.user.specifiy_user_context import ADMIN +from openhands.app_server.user.user_context import UserContext + + +class SaasSQLPendingMessageService(SQLPendingMessageService): + """Extended SQLPendingMessageService with user and organization-based filtering. + + This enterprise version ensures that: + - Users can only queue messages for conversations they own + - Organization isolation is enforced for multi-tenant deployments + """ + + def __init__(self, db_session, user_context: UserContext): + super().__init__(db_session=db_session) + self.user_context = user_context + + async def _get_current_user(self) -> User | None: + """Get the current user using the existing db_session. + + Returns: + User object or None if no user_id is available + """ + user_id_str = await self.user_context.get_user_id() + if not user_id_str: + return None + + user_id_uuid = UUID(user_id_str) + result = await self.db_session.execute( + select(User).where(User.id == user_id_uuid) + ) + return result.scalars().first() + + async def _validate_conversation_ownership(self, conversation_id: str) -> None: + """Validate that the current user owns the conversation. + + This ensures multi-tenant isolation by checking: + - The conversation belongs to the current user + - The conversation belongs to the user's current organization + + Args: + conversation_id: The conversation ID to validate (can be task-id or UUID) + + Raises: + AuthError: If user doesn't own the conversation or authentication fails + """ + # For internal operations (e.g., processing pending messages during startup) + # we need a mode that bypasses filtering. The ADMIN context enables this. + if self.user_context == ADMIN: + return + + user_id_str = await self.user_context.get_user_id() + if not user_id_str: + raise AuthError('User authentication required') + + user_id_uuid = UUID(user_id_str) + + # Check conversation ownership via SAAS metadata + query = select(StoredConversationMetadataSaas).where( + StoredConversationMetadataSaas.conversation_id == conversation_id + ) + result = await self.db_session.execute(query) + saas_metadata = result.scalar_one_or_none() + + # If no SAAS metadata exists, the conversation might be a new task-id + # that hasn't been linked to a conversation yet. Allow access in this case + # as the message will be validated when the conversation is created. + if saas_metadata is None: + return + + # Verify user ownership + if saas_metadata.user_id != user_id_uuid: + raise AuthError('You do not have access to this conversation') + + # Verify organization ownership if applicable + user = await self._get_current_user() + if user and user.current_org_id is not None: + if saas_metadata.org_id != user.current_org_id: + raise AuthError('Conversation belongs to a different organization') + + async def add_message( + self, + conversation_id: str, + content: list[TextContent | ImageContent], + role: str = 'user', + ) -> PendingMessageResponse: + """Queue a message with ownership validation. + + Args: + conversation_id: The conversation ID to queue the message for + content: Message content + role: Message role (default: 'user') + + Returns: + PendingMessageResponse with the queued message info + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().add_message(conversation_id, content, role) + + async def get_pending_messages(self, conversation_id: str): + """Get pending messages with ownership validation. + + Args: + conversation_id: The conversation ID to get messages for + + Returns: + List of pending messages + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().get_pending_messages(conversation_id) + + async def count_pending_messages(self, conversation_id: str) -> int: + """Count pending messages with ownership validation. + + Args: + conversation_id: The conversation ID to count messages for + + Returns: + Number of pending messages + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().count_pending_messages(conversation_id) + + +class SaasPendingMessageServiceInjector(PendingMessageServiceInjector): + """Enterprise injector for PendingMessageService with SAAS filtering.""" + + async def inject( + self, state: InjectorState, request: Request | None = None + ) -> AsyncGenerator[PendingMessageService, None]: + from openhands.app_server.config import ( + get_db_session, + get_user_context, + ) + + async with ( + get_user_context(state, request) as user_context, + get_db_session(state, request) as db_session, + ): + service = SaasSQLPendingMessageService( + db_session=db_session, user_context=user_context + ) + yield service diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx index cb164123c1..bafa673731 100644 --- a/frontend/__tests__/components/interactive-chat-box.test.tsx +++ b/frontend/__tests__/components/interactive-chat-box.test.tsx @@ -198,9 +198,9 @@ describe("InteractiveChatBox", () => { expect(onSubmitMock).toHaveBeenCalledWith("Hello, world!", [], []); }); - it("should disable the submit button when agent is loading", async () => { + it("should disable the submit button when awaiting user confirmation", async () => { const user = userEvent.setup(); - mockStores(AgentState.LOADING); + mockStores(AgentState.AWAITING_USER_CONFIRMATION); renderInteractiveChatBox({ onSubmit: onSubmitMock, diff --git a/frontend/__tests__/conversation-local-storage.test.ts b/frontend/__tests__/conversation-local-storage.test.ts index a99e5fc005..33e9e12a7e 100644 --- a/frontend/__tests__/conversation-local-storage.test.ts +++ b/frontend/__tests__/conversation-local-storage.test.ts @@ -229,4 +229,231 @@ describe("conversation localStorage utilities", () => { expect(parsed.subConversationTaskId).toBeNull(); }); }); + + describe("draftMessage persistence", () => { + describe("getConversationState", () => { + it("returns default draftMessage as null when no state exists", () => { + // Arrange + const conversationId = "conv-draft-1"; + + // Act + const state = getConversationState(conversationId); + + // Assert + expect(state.draftMessage).toBeNull(); + }); + + it("retrieves draftMessage from localStorage when it exists", () => { + // Arrange + const conversationId = "conv-draft-2"; + const draftText = "This is my saved draft message"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: draftText, + }), + ); + + // Act + const state = getConversationState(conversationId); + + // Assert + expect(state.draftMessage).toBe(draftText); + }); + + it("returns null draftMessage for task conversation IDs (not persisted)", () => { + // Arrange + const taskId = "task-uuid-123"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`; + + // Even if somehow there's data in localStorage for a task ID + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Should not be returned", + }), + ); + + // Act + const state = getConversationState(taskId); + + // Assert - should return default state, not the stored value + expect(state.draftMessage).toBeNull(); + }); + }); + + describe("setConversationState", () => { + it("persists draftMessage to localStorage", () => { + // Arrange + const conversationId = "conv-draft-3"; + const draftText = "New draft message to save"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + // Act + setConversationState(conversationId, { + draftMessage: draftText, + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + expect(stored).not.toBeNull(); + const parsed = JSON.parse(stored!); + expect(parsed.draftMessage).toBe(draftText); + }); + + it("does not persist draftMessage for task conversation IDs", () => { + // Arrange + const taskId = "task-draft-xyz"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`; + + // Act + setConversationState(taskId, { + draftMessage: "Draft for task ID", + }); + + // Assert - nothing should be stored + expect(localStorage.getItem(consolidatedKey)).toBeNull(); + }); + + it("merges draftMessage with existing state without overwriting other fields", () => { + // Arrange + const conversationId = "conv-draft-4"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + selectedTab: "terminal", + rightPanelShown: false, + unpinnedTabs: ["tab-1", "tab-2"], + conversationMode: "plan", + subConversationTaskId: "task-123", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: "Updated draft", + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + + expect(parsed.draftMessage).toBe("Updated draft"); + expect(parsed.selectedTab).toBe("terminal"); + expect(parsed.rightPanelShown).toBe(false); + expect(parsed.unpinnedTabs).toEqual(["tab-1", "tab-2"]); + expect(parsed.conversationMode).toBe("plan"); + expect(parsed.subConversationTaskId).toBe("task-123"); + }); + + it("clears draftMessage when set to null", () => { + // Arrange + const conversationId = "conv-draft-5"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Existing draft", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: null, + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + expect(parsed.draftMessage).toBeNull(); + }); + + it("clears draftMessage when set to empty string (stored as empty string)", () => { + // Arrange + const conversationId = "conv-draft-6"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Existing draft", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: "", + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + expect(parsed.draftMessage).toBe(""); + }); + }); + + describe("conversation-specific draft isolation", () => { + it("stores drafts separately for different conversations", () => { + // Arrange + const convA = "conv-A"; + const convB = "conv-B"; + const draftA = "Draft for conversation A"; + const draftB = "Draft for conversation B"; + + // Act + setConversationState(convA, { draftMessage: draftA }); + setConversationState(convB, { draftMessage: draftB }); + + // Assert + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBe(draftA); + expect(stateB.draftMessage).toBe(draftB); + }); + + it("updating one conversation draft does not affect another", () => { + // Arrange + const convA = "conv-isolated-A"; + const convB = "conv-isolated-B"; + + setConversationState(convA, { draftMessage: "Original draft A" }); + setConversationState(convB, { draftMessage: "Original draft B" }); + + // Act - update only conversation A + setConversationState(convA, { draftMessage: "Updated draft A" }); + + // Assert - conversation B should be unchanged + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBe("Updated draft A"); + expect(stateB.draftMessage).toBe("Original draft B"); + }); + + it("clearing one conversation draft does not affect another", () => { + // Arrange + const convA = "conv-clear-A"; + const convB = "conv-clear-B"; + + setConversationState(convA, { draftMessage: "Draft A" }); + setConversationState(convB, { draftMessage: "Draft B" }); + + // Act - clear draft for conversation A + setConversationState(convA, { draftMessage: null }); + + // Assert + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBeNull(); + expect(stateB.draftMessage).toBe("Draft B"); + }); + }); + }); }); diff --git a/frontend/__tests__/conversation-websocket-handler.test.tsx b/frontend/__tests__/conversation-websocket-handler.test.tsx index 284aaee287..393d6f68f0 100644 --- a/frontend/__tests__/conversation-websocket-handler.test.tsx +++ b/frontend/__tests__/conversation-websocket-handler.test.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { describe, it, @@ -8,7 +9,7 @@ import { afterEach, vi, } from "vitest"; -import { screen, waitFor, render, cleanup } from "@testing-library/react"; +import { screen, waitFor, render, cleanup, act } from "@testing-library/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { http, HttpResponse } from "msw"; import { MemoryRouter, Route, Routes } from "react-router"; @@ -682,8 +683,242 @@ describe("Conversation WebSocket Handler", () => { // 7. Message Sending Tests describe("Message Sending", () => { - it.todo("should send user actions through WebSocket when connected"); - it.todo("should handle send attempts when disconnected"); + it("should send user actions through WebSocket when connected", async () => { + // Arrange + const conversationId = "test-conversation-send"; + let receivedMessage: unknown = null; + + // Set up MSW to capture sent messages + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + + // Capture messages sent from client + client.addEventListener("message", (event) => { + receivedMessage = JSON.parse(event.data as string); + }); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
+
+ {context?.connectionState || "NOT_AVAILABLE"} +
+
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Send a message + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + await act(async () => { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Hello from test" }], + }); + }); + + // Assert - message should have been received by mock server + await waitFor(() => { + expect(receivedMessage).toEqual({ + role: "user", + content: [{ type: "text", text: "Hello from test" }], + }); + }); + }); + + it("should not throw error when sendMessage is called with WebSocket connected", async () => { + // This test verifies that sendMessage doesn't throw an error + // when the WebSocket is connected. + const conversationId = "test-conversation-no-throw"; + let sendError: Error | null = null; + + // Set up MSW to connect and receive messages + mswServer.use( + wsLink.addEventListener("connection", ({ server }) => { + server.connect(); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
+
+ {context?.connectionState || "NOT_AVAILABLE"} +
+
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Wait for the context to be available + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + // Try to send a message + await act(async () => { + try { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Test message" }], + }); + } catch (error) { + sendError = error as Error; + } + }); + + // Assert - should NOT throw an error + expect(sendError).toBeNull(); + }); + + it("should send multiple messages through WebSocket in order", async () => { + // Arrange + const conversationId = "test-conversation-multi"; + const receivedMessages: unknown[] = []; + + // Set up MSW to capture sent messages + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + + // Capture messages sent from client + client.addEventListener("message", (event) => { + receivedMessages.push(JSON.parse(event.data as string)); + }); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
+
+ {context?.connectionState || "NOT_AVAILABLE"} +
+
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + // Send multiple messages + await act(async () => { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Message 1" }], + }); + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Message 2" }], + }); + }); + + // Assert - both messages should have been received in order + await waitFor(() => { + expect(receivedMessages.length).toBe(2); + }); + + expect(receivedMessages[0]).toEqual({ + role: "user", + content: [{ type: "text", text: "Message 1" }], + }); + expect(receivedMessages[1]).toEqual({ + role: "user", + content: [{ type: "text", text: "Message 2" }], + }); + }); }); // 8. History Loading State Tests diff --git a/frontend/__tests__/hooks/use-draft-persistence.test.tsx b/frontend/__tests__/hooks/use-draft-persistence.test.tsx new file mode 100644 index 0000000000..0734470324 --- /dev/null +++ b/frontend/__tests__/hooks/use-draft-persistence.test.tsx @@ -0,0 +1,594 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act } from "@testing-library/react"; +import { useDraftPersistence } from "#/hooks/chat/use-draft-persistence"; +import * as conversationLocalStorage from "#/utils/conversation-local-storage"; + +// Mock the entire module +vi.mock("#/utils/conversation-local-storage", () => ({ + useConversationLocalStorageState: vi.fn(), + getConversationState: vi.fn(), + setConversationState: vi.fn(), +})); + +// Mock the getTextContent utility +vi.mock("#/components/features/chat/utils/chat-input.utils", () => ({ + getTextContent: vi.fn((el: HTMLDivElement | null) => el?.textContent || ""), +})); + +describe("useDraftPersistence", () => { + let mockSetDraftMessage: (message: string | null) => void; + + // Create a mock ref to contentEditable div + const createMockChatInputRef = (initialContent = "") => { + const div = document.createElement("div"); + div.setAttribute("contenteditable", "true"); + div.textContent = initialContent; + return { current: div }; + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.useFakeTimers(); + localStorage.clear(); + + mockSetDraftMessage = vi.fn<(message: string | null) => void>(); + + // Default mock for useConversationLocalStorageState + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + // Default mock for getConversationState + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.clearAllMocks(); + }); + + describe("draft restoration on mount", () => { + it("restores draft from localStorage when mounting with existing draft", () => { + // Arrange + const conversationId = "conv-restore-1"; + const savedDraft = "Previously saved draft message"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: savedDraft, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - draft should be restored to the DOM element + expect(chatInputRef.current?.textContent).toBe(savedDraft); + }); + + it("clears input on mount then restores draft if exists", () => { + // Arrange + const conversationId = "conv-restore-2"; + const existingContent = "Stale content from previous conversation"; + const savedDraft = "Saved draft"; + const chatInputRef = createMockChatInputRef(existingContent); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: savedDraft, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - input cleared then draft restored + expect(chatInputRef.current?.textContent).toBe(savedDraft); + }); + + it("clears input when no draft exists for conversation", () => { + // Arrange + const conversationId = "conv-no-draft"; + const chatInputRef = createMockChatInputRef("Some stale content"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - content should be cleared since there's no draft + expect(chatInputRef.current?.textContent).toBe(""); + }); + }); + + describe("debounced saving", () => { + it("saves draft after debounce period", () => { + // Arrange + const conversationId = "conv-debounce-1"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - simulate user typing + chatInputRef.current!.textContent = "New draft content"; + act(() => { + result.current.saveDraft(); + }); + + // Assert - should not save immediately + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + + // Fast forward past debounce period (500ms) + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should save after debounce + expect(mockSetDraftMessage).toHaveBeenCalledWith("New draft content"); + }); + + it("cancels pending save when new input arrives before debounce", () => { + // Arrange + const conversationId = "conv-debounce-2"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - first input + chatInputRef.current!.textContent = "First"; + act(() => { + result.current.saveDraft(); + }); + + // Wait 200ms (less than debounce) + act(() => { + vi.advanceTimersByTime(200); + }); + + // Second input before debounce completes + chatInputRef.current!.textContent = "First Second"; + act(() => { + result.current.saveDraft(); + }); + + // Complete the second debounce + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should only save the final value once + expect(mockSetDraftMessage).toHaveBeenCalledTimes(1); + expect(mockSetDraftMessage).toHaveBeenCalledWith("First Second"); + }); + + it("does not save if content matches existing draft", () => { + // Arrange + const conversationId = "conv-no-change"; + const existingDraft = "Existing draft"; + const chatInputRef = createMockChatInputRef(existingDraft); + + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: existingDraft, + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: existingDraft, + }); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - try to save same content + act(() => { + result.current.saveDraft(); + }); + + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should not save since content is the same + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); + + describe("clearDraft", () => { + it("clears the draft from localStorage", () => { + // Arrange + const conversationId = "conv-clear-1"; + const chatInputRef = createMockChatInputRef("Some content"); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act + act(() => { + result.current.clearDraft(); + }); + + // Assert + expect(mockSetDraftMessage).toHaveBeenCalledWith(null); + }); + + it("cancels any pending debounced save when clearing", () => { + // Arrange + const conversationId = "conv-clear-2"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Start a save + chatInputRef.current!.textContent = "Pending draft"; + act(() => { + result.current.saveDraft(); + }); + + // Clear before debounce completes + act(() => { + vi.advanceTimersByTime(200); + result.current.clearDraft(); + }); + + // Complete the original debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - only the clear should have been called (the pending save should be cancelled) + expect(mockSetDraftMessage).toHaveBeenCalledTimes(1); + expect(mockSetDraftMessage).toHaveBeenCalledWith(null); + }); + }); + + describe("conversation switching", () => { + it("clears input when switching to a new conversation without a draft", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Draft from conv A"); + + // First conversation has a draft + vi.mocked(conversationLocalStorage.getConversationState) + .mockReturnValueOnce({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Draft from conv A", + }) + .mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - switch to conversation B + rerender({ conversationId: "conv-B" }); + + // Assert - input should be cleared (no draft for conv-B) + expect(chatInputRef.current?.textContent).toBe(""); + }); + + it("restores draft when switching to a conversation with an existing draft", () => { + // Arrange + const chatInputRef = createMockChatInputRef(); + const draftForConvB = "Saved draft for conversation B"; + + vi.mocked(conversationLocalStorage.getConversationState) + .mockReturnValueOnce({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }) + .mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: draftForConvB, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - switch to conversation B + rerender({ conversationId: "conv-B" }); + + // Assert - draft for conv-B should be restored + expect(chatInputRef.current?.textContent).toBe(draftForConvB); + }); + + it("cancels pending save when switching conversations", () => { + // Arrange + const chatInputRef = createMockChatInputRef(); + + const { result, rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Start typing in conv-A + chatInputRef.current!.textContent = "Draft for conv-A"; + act(() => { + result.current.saveDraft(); + }); + + // Switch conversation before debounce completes + act(() => { + vi.advanceTimersByTime(200); + }); + rerender({ conversationId: "conv-B" }); + + // Complete the debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - the save should NOT have happened because conversation changed + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); + + describe("task ID to real conversation ID transition", () => { + it("transfers draft from task ID to real conversation ID during transition", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Draft typed during init"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "task-abc-123" } }, + ); + + // Simulate user typing during task initialization + chatInputRef.current!.textContent = "Draft typed during init"; + + // Act - transition to real conversation ID + rerender({ conversationId: "conv-real-123" }); + + // Assert - draft should be saved to the new real conversation ID + expect(conversationLocalStorage.setConversationState).toHaveBeenCalledWith( + "conv-real-123", + { draftMessage: "Draft typed during init" }, + ); + + // And the draft should remain visible in the input + expect(chatInputRef.current?.textContent).toBe("Draft typed during init"); + }); + + it("does not transfer empty draft during task-to-real transition", () => { + // Arrange + const chatInputRef = createMockChatInputRef(""); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "task-abc-123" } }, + ); + + // Act - transition to real conversation ID with empty input + rerender({ conversationId: "conv-real-123" }); + + // Assert - no draft should be saved (input is cleared, checked by hook) + // The setConversationState should not be called with draftMessage + expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled(); + }); + + it("does not transfer draft for non-task ID transitions", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Some draft"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - normal conversation switch (not task-to-real) + rerender({ conversationId: "conv-B" }); + + // Assert - should not use setConversationState directly + // (the normal path uses setDraftMessage from the hook) + expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled(); + }); + }); + + describe("hasDraft and isRestored state", () => { + it("returns hasDraft true when draft exists in hook state", () => { + // Arrange + const conversationId = "conv-has-draft"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Existing draft", + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.hasDraft).toBe(true); + }); + + it("returns hasDraft false when no draft exists", () => { + // Arrange + const conversationId = "conv-no-draft"; + const chatInputRef = createMockChatInputRef(); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.hasDraft).toBe(false); + }); + + it("sets isRestored to true after restoration completes", () => { + // Arrange + const conversationId = "conv-restored"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Draft to restore", + }); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.isRestored).toBe(true); + }); + }); + + describe("cleanup on unmount", () => { + it("clears pending timeout on unmount", () => { + // Arrange + const conversationId = "conv-unmount"; + const chatInputRef = createMockChatInputRef(); + + const { result, unmount } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Start a save + chatInputRef.current!.textContent = "Draft"; + act(() => { + result.current.saveDraft(); + }); + + // Unmount before debounce completes + unmount(); + + // Complete the debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - save should not have been called after unmount + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-handle-plan-click.test.tsx b/frontend/__tests__/hooks/use-handle-plan-click.test.tsx index 067a208c81..fdaa4c06aa 100644 --- a/frontend/__tests__/hooks/use-handle-plan-click.test.tsx +++ b/frontend/__tests__/hooks/use-handle-plan-click.test.tsx @@ -88,6 +88,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: null, conversationMode: "code", + draftMessage: null, }); }); @@ -117,6 +118,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: storedTaskId, conversationMode: "code", + draftMessage: null, }); renderHook(() => useHandlePlanClick()); @@ -155,6 +157,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: storedTaskId, conversationMode: "code", + draftMessage: null, }); renderHook(() => useHandlePlanClick()); diff --git a/frontend/src/api/pending-message-service/pending-message-service.api.ts b/frontend/src/api/pending-message-service/pending-message-service.api.ts new file mode 100644 index 0000000000..8c7ef73a8a --- /dev/null +++ b/frontend/src/api/pending-message-service/pending-message-service.api.ts @@ -0,0 +1,40 @@ +/** + * Pending Message Service + * + * This service handles server-side message queuing for V1 conversations. + * Messages can be queued when the WebSocket is not connected and will be + * delivered automatically when the conversation becomes ready. + */ + +import { openHands } from "../open-hands-axios"; +import type { + PendingMessageResponse, + QueuePendingMessageRequest, +} from "./pending-message-service.types"; + +class PendingMessageService { + /** + * Queue a message for delivery when conversation becomes ready. + * + * This endpoint allows users to submit messages even when the conversation's + * WebSocket connection is not yet established. Messages are stored server-side + * and delivered automatically when the conversation transitions to READY status. + * + * @param conversationId The conversation ID (can be task ID before conversation is ready) + * @param message The message to queue + * @returns PendingMessageResponse with the message ID and queue position + * @throws Error if too many pending messages (limit: 10 per conversation) + */ + static async queueMessage( + conversationId: string, + message: QueuePendingMessageRequest, + ): Promise { + const { data } = await openHands.post( + `/api/v1/conversations/${conversationId}/pending-messages`, + message, + ); + return data; + } +} + +export default PendingMessageService; diff --git a/frontend/src/api/pending-message-service/pending-message-service.types.ts b/frontend/src/api/pending-message-service/pending-message-service.types.ts new file mode 100644 index 0000000000..cf7b8dbf0e --- /dev/null +++ b/frontend/src/api/pending-message-service/pending-message-service.types.ts @@ -0,0 +1,22 @@ +/** + * Types for the pending message service + */ + +import type { V1MessageContent } from "../conversation-service/v1-conversation-service.types"; + +/** + * Response when queueing a pending message + */ +export interface PendingMessageResponse { + id: string; + queued: boolean; + position: number; +} + +/** + * Request to queue a pending message + */ +export interface QueuePendingMessageRequest { + role?: "user"; + content: V1MessageContent[]; +} diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx index 85a9435678..43218149ae 100644 --- a/frontend/src/components/features/chat/chat-interface.tsx +++ b/frontend/src/components/features/chat/chat-interface.tsx @@ -190,8 +190,14 @@ export function ChatInterface() { const prompt = uploadedFiles.length > 0 ? `${content}\n\n${filePrompt}` : content; - send(createChatMessage(prompt, imageUrls, uploadedFiles, timestamp)); - setOptimisticUserMessage(content); + const result = await send( + createChatMessage(prompt, imageUrls, uploadedFiles, timestamp), + ); + // Only show optimistic UI if message was sent immediately via WebSocket + // If queued for later delivery, the message will appear when actually delivered + if (!result.queued) { + setOptimisticUserMessage(content); + } setMessageToSend(""); }; diff --git a/frontend/src/components/features/chat/custom-chat-input.tsx b/frontend/src/components/features/chat/custom-chat-input.tsx index 5fd92fdcd6..26a0f74ca9 100644 --- a/frontend/src/components/features/chat/custom-chat-input.tsx +++ b/frontend/src/components/features/chat/custom-chat-input.tsx @@ -60,6 +60,7 @@ export function CustomChatInput({ messageToSend, checkIsContentEmpty, clearEmptyContentHandler, + saveDraft, } = useChatInputLogic(); const { @@ -158,6 +159,7 @@ export function CustomChatInput({ onInput={() => { handleInput(); updateSlashMenu(); + saveDraft(); }} onPaste={handlePaste} onKeyDown={(e) => { diff --git a/frontend/src/components/features/chat/interactive-chat-box.tsx b/frontend/src/components/features/chat/interactive-chat-box.tsx index a2f1df8348..74818d1d6c 100644 --- a/frontend/src/components/features/chat/interactive-chat-box.tsx +++ b/frontend/src/components/features/chat/interactive-chat-box.tsx @@ -142,8 +142,9 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { handleSubmit(suggestion); }; + // Allow users to submit messages during LOADING state - they will be + // queued server-side and delivered when the conversation becomes ready const isDisabled = - curAgentState === AgentState.LOADING || curAgentState === AgentState.AWAITING_USER_CONFIRMATION || isTaskPolling(subConversationTaskStatus); diff --git a/frontend/src/contexts/conversation-websocket-context.tsx b/frontend/src/contexts/conversation-websocket-context.tsx index 572ab4fd75..86863734b9 100644 --- a/frontend/src/contexts/conversation-websocket-context.tsx +++ b/frontend/src/contexts/conversation-websocket-context.tsx @@ -40,6 +40,7 @@ import type { V1SendMessageRequest, } from "#/api/conversation-service/v1-conversation-service.types"; import EventService from "#/api/event-service/event-service.api"; +import PendingMessageService from "#/api/pending-message-service/pending-message-service.api"; import { useConversationStore } from "#/stores/conversation-store"; import { isBudgetOrCreditError, trackError } from "#/utils/error-handler"; import { useTracking } from "#/hooks/use-tracking"; @@ -47,6 +48,7 @@ import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation- import useMetricsStore from "#/stores/metrics-store"; import { I18nKey } from "#/i18n/declaration"; import { useConversationHistory } from "#/hooks/query/use-conversation-history"; +import { setConversationState } from "#/utils/conversation-local-storage"; // eslint-disable-next-line @typescript-eslint/naming-convention export type V1_WebSocketConnectionState = @@ -55,9 +57,13 @@ export type V1_WebSocketConnectionState = | "CLOSED" | "CLOSING"; +interface SendMessageResult { + queued: boolean; // true if message was queued for later delivery, false if sent immediately +} + interface ConversationWebSocketContextType { connectionState: V1_WebSocketConnectionState; - sendMessage: (message: V1SendMessageRequest) => Promise; + sendMessage: (message: V1SendMessageRequest) => Promise; isLoadingHistory: boolean; } @@ -397,6 +403,10 @@ export function ConversationWebSocketProvider({ // Clear optimistic user message when a user message is confirmed if (isUserMessageEvent(event)) { removeOptimisticUserMessage(); + // Clear draft from localStorage - message was successfully delivered + if (conversationId) { + setConversationState(conversationId, { draftMessage: null }); + } } // Handle cache invalidation for ActionEvent @@ -556,6 +566,11 @@ export function ConversationWebSocketProvider({ // Clear optimistic user message when a user message is confirmed if (isUserMessageEvent(event)) { removeOptimisticUserMessage(); + // Clear draft from localStorage - message was successfully delivered + // Use main conversationId since user types in main conversation input + if (conversationId) { + setConversationState(conversationId, { draftMessage: null }); + } } // Handle cache invalidation for ActionEvent @@ -810,21 +825,44 @@ export function ConversationWebSocketProvider({ ); // V1 send message function via WebSocket + // Falls back to REST API queue when WebSocket is not connected const sendMessage = useCallback( - async (message: V1SendMessageRequest) => { + async (message: V1SendMessageRequest): Promise => { const currentMode = useConversationStore.getState().conversationMode; const currentSocket = currentMode === "plan" ? planningAgentSocket : mainSocket; if (!currentSocket || currentSocket.readyState !== WebSocket.OPEN) { - const error = "WebSocket is not connected"; - setErrorMessage(error); - throw new Error(error); + // WebSocket not connected - queue message via REST API + // Message will be delivered automatically when conversation becomes ready + if (!conversationId) { + const error = new Error("No conversation ID available"); + setErrorMessage(error.message); + throw error; + } + + try { + await PendingMessageService.queueMessage(conversationId, { + role: "user", + content: message.content, + }); + // Message queued successfully - it will be delivered when ready + // Return queued: true so caller knows not to show optimistic UI + return { queued: true }; + } catch (error) { + const errorMessage = + error instanceof Error + ? error.message + : "Failed to queue message for delivery"; + setErrorMessage(errorMessage); + throw error; + } } try { // Send message through WebSocket as JSON currentSocket.send(JSON.stringify(message)); + return { queued: false }; } catch (error) { const errorMessage = error instanceof Error ? error.message : "Failed to send message"; @@ -832,7 +870,7 @@ export function ConversationWebSocketProvider({ throw error; } }, - [mainSocket, planningAgentSocket, setErrorMessage], + [mainSocket, planningAgentSocket, setErrorMessage, conversationId], ); // Track main socket state changes diff --git a/frontend/src/hooks/chat/use-chat-input-logic.ts b/frontend/src/hooks/chat/use-chat-input-logic.ts index 21dc682fc9..47a6fafacb 100644 --- a/frontend/src/hooks/chat/use-chat-input-logic.ts +++ b/frontend/src/hooks/chat/use-chat-input-logic.ts @@ -5,12 +5,15 @@ import { getTextContent, } from "#/components/features/chat/utils/chat-input.utils"; import { useConversationStore } from "#/stores/conversation-store"; +import { useConversationId } from "#/hooks/use-conversation-id"; +import { useDraftPersistence } from "./use-draft-persistence"; /** * Hook for managing chat input content logic */ export const useChatInputLogic = () => { const chatInputRef = useRef(null); + const { conversationId } = useConversationId(); const { messageToSend, @@ -19,6 +22,12 @@ export const useChatInputLogic = () => { setIsRightPanelShown, } = useConversationStore(); + // Draft persistence - saves to localStorage, restores on mount + const { saveDraft, clearDraft } = useDraftPersistence( + conversationId, + chatInputRef, + ); + // Save current input value when drawer state changes useEffect(() => { if (chatInputRef.current) { @@ -51,5 +60,7 @@ export const useChatInputLogic = () => { checkIsContentEmpty, clearEmptyContentHandler, getCurrentMessage, + saveDraft, + clearDraft, }; }; diff --git a/frontend/src/hooks/chat/use-draft-persistence.ts b/frontend/src/hooks/chat/use-draft-persistence.ts new file mode 100644 index 0000000000..fd958030b1 --- /dev/null +++ b/frontend/src/hooks/chat/use-draft-persistence.ts @@ -0,0 +1,179 @@ +import { useEffect, useRef, useCallback, useState } from "react"; +import { + useConversationLocalStorageState, + getConversationState, + setConversationState, +} from "#/utils/conversation-local-storage"; +import { getTextContent } from "#/components/features/chat/utils/chat-input.utils"; + +/** + * Check if a conversation ID is a temporary task ID. + * Task IDs have the format "task-{uuid}" and are used during V1 conversation initialization. + */ +const isTaskId = (id: string): boolean => id.startsWith("task-"); + +const DRAFT_SAVE_DEBOUNCE_MS = 500; + +/** + * Hook for persisting draft messages to localStorage. + * Handles debounced saving on input, restoration on mount, and clearing on confirmed delivery. + */ +export const useDraftPersistence = ( + conversationId: string, + chatInputRef: React.RefObject, +) => { + const { state, setDraftMessage } = + useConversationLocalStorageState(conversationId); + const saveTimeoutRef = useRef | null>(null); + const hasRestoredRef = useRef(false); + const [isRestored, setIsRestored] = useState(false); + + // Track current conversationId to prevent saving draft to wrong conversation + const currentConversationIdRef = useRef(conversationId); + // Track if this is the first mount to handle initial cleanup + const isFirstMountRef = useRef(true); + + // IMPORTANT: This effect must run FIRST when conversation changes. + // It handles three concerns: + // 1. Cleanup: Cancel pending saves from previous conversation + // 2. Task-to-real transition: Preserve draft typed during initialization + // 3. DOM reset: Clear stale content before restoration effect runs + useEffect(() => { + const previousConversationId = currentConversationIdRef.current; + const isInitialMount = isFirstMountRef.current; + currentConversationIdRef.current = conversationId; + isFirstMountRef.current = false; + + // --- 1. Cancel pending saves from previous conversation --- + // Prevents draft from being saved to wrong conversation if user switched quickly + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + saveTimeoutRef.current = null; + } + + const element = chatInputRef.current; + + // --- 2. Handle task-to-real ID transition (preserve draft during initialization) --- + // When a new V1 conversation initializes, it starts with a temporary "task-xxx" ID + // that transitions to a real conversation ID once ready. Task IDs don't persist + // to localStorage, so any draft typed during this phase would be lost. + // We detect this transition and transfer the draft to the new real ID. + if (!isInitialMount && previousConversationId !== conversationId) { + const wasTaskId = isTaskId(previousConversationId); + const isNowRealId = !isTaskId(conversationId); + + if (wasTaskId && isNowRealId && element) { + const currentText = getTextContent(element).trim(); + if (currentText) { + // Transfer draft to the new (real) conversation ID + setConversationState(conversationId, { draftMessage: currentText }); + // Keep draft visible in DOM and mark as restored to prevent overwrite + hasRestoredRef.current = true; + setIsRestored(true); + return; // Skip normal cleanup - draft is already in correct state + } + } + } + + // --- 3. Clear stale DOM content (will be restored by next effect if draft exists) --- + // This prevents stale drafts from appearing in new conversations due to: + // - Browser form restoration on back/forward navigation + // - React DOM recycling between conversation switches + // The restoration effect will then populate with the correct saved draft + if (element) { + element.textContent = ""; + } + + // Reset restoration flag so the restoration effect will run for new conversation + hasRestoredRef.current = false; + setIsRestored(false); + }, [conversationId, chatInputRef]); + + // Restore draft from localStorage - reads directly to avoid state sync timing issues + useEffect(() => { + if (hasRestoredRef.current) { + return; + } + + const element = chatInputRef.current; + if (!element) { + return; + } + + // Read directly from localStorage to avoid stale state from useConversationLocalStorageState + // The hook's state may not have synced yet after conversationId change + const { draftMessage } = getConversationState(conversationId); + + // Only restore if there's a saved draft and the input is empty + if (draftMessage && getTextContent(element).trim() === "") { + element.textContent = draftMessage; + // Move cursor to end + const selection = window.getSelection(); + const range = document.createRange(); + range.selectNodeContents(element); + range.collapse(false); + selection?.removeAllRanges(); + selection?.addRange(range); + } + + hasRestoredRef.current = true; + setIsRestored(true); + }, [chatInputRef, conversationId]); + + // Debounced save function - called from onInput handler + const saveDraft = useCallback(() => { + // Clear any pending save + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + } + + // Capture the conversationId at the time of input + const capturedConversationId = conversationId; + + saveTimeoutRef.current = setTimeout(() => { + // Verify we're still on the same conversation before saving + // This prevents saving draft to wrong conversation if user switched quickly + if (capturedConversationId !== currentConversationIdRef.current) { + return; + } + + const element = chatInputRef.current; + if (!element) { + return; + } + + const text = getTextContent(element).trim(); + // Only save if content has changed + if (text !== (state.draftMessage || "")) { + setDraftMessage(text || null); + } + }, DRAFT_SAVE_DEBOUNCE_MS); + }, [chatInputRef, state.draftMessage, setDraftMessage, conversationId]); + + // Clear draft - called after message delivery is confirmed + const clearDraft = useCallback(() => { + // Cancel any pending save + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + saveTimeoutRef.current = null; + } + setDraftMessage(null); + }, [setDraftMessage]); + + // Cleanup timeout on unmount + useEffect( + () => () => { + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + } + }, + [], + ); + + return { + saveDraft, + clearDraft, + isRestored, + hasDraft: !!state.draftMessage, + }; +}; diff --git a/frontend/src/hooks/use-send-message.ts b/frontend/src/hooks/use-send-message.ts index 4da5eafc2e..3f641521e6 100644 --- a/frontend/src/hooks/use-send-message.ts +++ b/frontend/src/hooks/use-send-message.ts @@ -5,6 +5,10 @@ import { useConversationWebSocket } from "#/contexts/conversation-websocket-cont import { useConversationId } from "#/hooks/use-conversation-id"; import { V1MessageContent } from "#/api/conversation-service/v1-conversation-service.types"; +interface SendResult { + queued: boolean; // true if message was queued for later delivery +} + /** * Unified hook for sending messages that works with both V0 and V1 conversations * - For V0 conversations: Uses Socket.IO WebSocket via useWsClient @@ -26,7 +30,7 @@ export function useSendMessage() { conversation?.conversation_version === "V1"; const send = useCallback( - async (event: Record) => { + async (event: Record): Promise => { if (isV1Conversation && v1Context) { // V1: Convert V0 event format to V1 message format const { action, args } = event as { @@ -57,19 +61,20 @@ export function useSendMessage() { } // Send via V1 WebSocket context (uses correct host/port) - await v1Context.sendMessage({ + const result = await v1Context.sendMessage({ role: "user", content, }); - } else { - // For non-message events, fall back to V0 send - // (e.g., agent state changes, other control events) - v0Send(event); + return result; } - } else { - // V0: Use Socket.IO + // For non-message events, fall back to V0 send + // (e.g., agent state changes, other control events) v0Send(event); + return { queued: false }; } + // V0: Use Socket.IO + v0Send(event); + return { queued: false }; }, [isV1Conversation, v1Context, v0Send, conversationId], ); diff --git a/frontend/src/utils/conversation-local-storage.ts b/frontend/src/utils/conversation-local-storage.ts index de16da9f55..4beb800b88 100644 --- a/frontend/src/utils/conversation-local-storage.ts +++ b/frontend/src/utils/conversation-local-storage.ts @@ -23,6 +23,7 @@ export interface ConversationState { unpinnedTabs: string[]; conversationMode: ConversationMode; subConversationTaskId: string | null; + draftMessage: string | null; } const DEFAULT_CONVERSATION_STATE: ConversationState = { @@ -31,6 +32,7 @@ const DEFAULT_CONVERSATION_STATE: ConversationState = { unpinnedTabs: [], conversationMode: "code", subConversationTaskId: null, + draftMessage: null, }; /** @@ -121,6 +123,7 @@ export function useConversationLocalStorageState(conversationId: string): { setRightPanelShown: (shown: boolean) => void; setUnpinnedTabs: (tabs: string[]) => void; setConversationMode: (mode: ConversationMode) => void; + setDraftMessage: (message: string | null) => void; } { const [state, setState] = useState(() => getConversationState(conversationId), @@ -178,5 +181,6 @@ export function useConversationLocalStorageState(conversationId: string): { setRightPanelShown: (shown) => updateState({ rightPanelShown: shown }), setUnpinnedTabs: (tabs) => updateState({ unpinnedTabs: tabs }), setConversationMode: (mode) => updateState({ conversationMode: mode }), + setDraftMessage: (message) => updateState({ draftMessage: message }), }; } diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 7bb59fedbf..fe07f205c1 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import ( from openhands.app_server.event_callback.set_title_callback_processor import ( SetTitleCallbackProcessor, ) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, +) from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService from openhands.app_server.sandbox.sandbox_models import ( AGENT_SERVER, @@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): sandbox_service: SandboxService sandbox_spec_service: SandboxSpecService jwt_service: JwtService + pending_message_service: PendingMessageService sandbox_startup_timeout: int sandbox_startup_poll_frequency: int max_num_conversations_per_sandbox: int @@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase): task.app_conversation_id = info.id yield task + # Process any pending messages queued while waiting for conversation + if sandbox.session_api_key: + await self._process_pending_messages( + task_id=task.id, + conversation_id=info.id, + agent_server_url=agent_server_url, + session_api_key=sandbox.session_api_key, + ) + except Exception as exc: _logger.exception('Error starting conversation', stack_info=True) task.status = AppConversationStartTaskStatus.ERROR @@ -1424,6 +1437,89 @@ class LiveStatusAppConversationService(AppConversationServiceBase): plugins=plugins, ) + async def _process_pending_messages( + self, + task_id: UUID, + conversation_id: UUID, + agent_server_url: str, + session_api_key: str, + ) -> None: + """Process pending messages queued before conversation was ready. + + Messages are delivered concurrently to the agent server. After processing, + all messages are deleted from the database regardless of success or failure. + + Args: + task_id: The start task ID (may have been used as conversation_id initially) + conversation_id: The real conversation ID + agent_server_url: URL of the agent server + session_api_key: API key for authenticating with agent server + """ + # Convert UUIDs to strings for the pending message service + # The frontend uses task-{uuid.hex} format (no hyphens), matching OpenHandsUUID serialization + task_id_str = f'task-{task_id.hex}' + # conversation_id uses standard format (with hyphens) for agent server API compatibility + conversation_id_str = str(conversation_id) + + _logger.info(f'task_id={task_id_str} conversation_id={conversation_id_str}') + + # First, update any messages that were queued with the task_id + updated_count = await self.pending_message_service.update_conversation_id( + old_conversation_id=task_id_str, + new_conversation_id=conversation_id_str, + ) + _logger.info(f'updated_count={updated_count} ') + if updated_count > 0: + _logger.info( + f'Updated {updated_count} pending messages from task_id={task_id_str} ' + f'to conversation_id={conversation_id_str}' + ) + + # Get all pending messages for this conversation + pending_messages = await self.pending_message_service.get_pending_messages( + conversation_id_str + ) + + if not pending_messages: + return + + _logger.info( + f'Processing {len(pending_messages)} pending messages for ' + f'conversation {conversation_id_str}' + ) + + # Process messages sequentially to preserve order + for msg in pending_messages: + try: + # Serialize content objects to JSON-compatible dicts + content_json = [item.model_dump() for item in msg.content] + # Use the events endpoint which handles message sending + response = await self.httpx_client.post( + f'{agent_server_url}/api/conversations/{conversation_id_str}/events', + json={ + 'role': msg.role, + 'content': content_json, + 'run': True, + }, + headers={'X-Session-API-Key': session_api_key}, + timeout=30.0, + ) + response.raise_for_status() + _logger.debug(f'Delivered pending message {msg.id}') + except Exception as e: + _logger.warning(f'Failed to deliver pending message {msg.id}: {e}') + + # Delete all pending messages after processing (regardless of success/failure) + deleted_count = ( + await self.pending_message_service.delete_messages_for_conversation( + conversation_id_str + ) + ) + _logger.info( + f'Finished processing pending messages for conversation {conversation_id_str}. ' + f'Deleted {deleted_count} messages.' + ) + async def update_agent_server_conversation_title( self, conversation_id: str, @@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): get_global_config, get_httpx_client, get_jwt_service, + get_pending_message_service, get_sandbox_service, get_sandbox_spec_service, get_user_context, @@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): get_event_service(state, request) as event_service, get_jwt_service(state, request) as jwt_service, get_httpx_client(state, request) as httpx_client, + get_pending_message_service(state, request) as pending_message_service, ): access_token_hard_timeout = None if self.access_token_hard_timeout: @@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): event_callback_service=event_callback_service, event_service=event_service, jwt_service=jwt_service, + pending_message_service=pending_message_service, sandbox_startup_timeout=self.sandbox_startup_timeout, sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency, max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox, diff --git a/openhands/app_server/app_lifespan/alembic/versions/007.py b/openhands/app_server/app_lifespan/alembic/versions/007.py new file mode 100644 index 0000000000..ef0b34b2eb --- /dev/null +++ b/openhands/app_server/app_lifespan/alembic/versions/007.py @@ -0,0 +1,39 @@ +"""Add pending_messages table for server-side message queuing + +Revision ID: 007 +Revises: 006 +Create Date: 2025-03-15 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '007' +down_revision: Union[str, None] = '006' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create pending_messages table for storing messages before conversation is ready. + + Messages are stored temporarily until the conversation becomes ready, then + delivered and deleted regardless of success or failure. + """ + op.create_table( + 'pending_messages', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('conversation_id', sa.String(), nullable=False, index=True), + sa.Column('role', sa.String(20), nullable=False, server_default='user'), + sa.Column('content', sa.JSON, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + ) + + +def downgrade() -> None: + """Remove pending_messages table.""" + op.drop_table('pending_messages') diff --git a/openhands/app_server/config.py b/openhands/app_server/config.py index 8c0ded6d2e..4b7f78e389 100644 --- a/openhands/app_server/config.py +++ b/openhands/app_server/config.py @@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import ( EventCallbackService, EventCallbackServiceInjector, ) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, +) from openhands.app_server.sandbox.sandbox_service import ( SandboxService, SandboxServiceInjector, @@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel): app_conversation_info: AppConversationInfoServiceInjector | None = None app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None app_conversation: AppConversationServiceInjector | None = None + pending_message: PendingMessageServiceInjector | None = None user: UserContextInjector | None = None jwt: JwtServiceInjector | None = None httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector) @@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig: tavily_api_key=tavily_api_key ) + if config.pending_message is None: + from openhands.app_server.pending_messages.pending_message_service import ( + SQLPendingMessageServiceInjector, + ) + + config.pending_message = SQLPendingMessageServiceInjector() + if config.user is None: config.user = AuthUserContextInjector() @@ -358,6 +370,14 @@ def get_app_conversation_service( return injector.context(state, request) +def get_pending_message_service( + state: InjectorState, request: Request | None = None +) -> AsyncContextManager[PendingMessageService]: + injector = get_global_config().pending_message + assert injector is not None + return injector.context(state, request) + + def get_user_context( state: InjectorState, request: Request | None = None ) -> AsyncContextManager[UserContext]: @@ -433,6 +453,12 @@ def depends_app_conversation_service(): return Depends(injector.depends) +def depends_pending_message_service(): + injector = get_global_config().pending_message + assert injector is not None + return Depends(injector.depends) + + def depends_user_context(): injector = get_global_config().user assert injector is not None diff --git a/openhands/app_server/pending_messages/__init__.py b/openhands/app_server/pending_messages/__init__.py new file mode 100644 index 0000000000..5aa37fc675 --- /dev/null +++ b/openhands/app_server/pending_messages/__init__.py @@ -0,0 +1,21 @@ +"""Pending messages module for server-side message queuing.""" + +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessage, + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, + SQLPendingMessageService, + SQLPendingMessageServiceInjector, +) + +__all__ = [ + 'PendingMessage', + 'PendingMessageResponse', + 'PendingMessageService', + 'PendingMessageServiceInjector', + 'SQLPendingMessageService', + 'SQLPendingMessageServiceInjector', +] diff --git a/openhands/app_server/pending_messages/pending_message_models.py b/openhands/app_server/pending_messages/pending_message_models.py new file mode 100644 index 0000000000..9e0062b185 --- /dev/null +++ b/openhands/app_server/pending_messages/pending_message_models.py @@ -0,0 +1,32 @@ +"""Models for pending message queue functionality.""" + +from datetime import datetime +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.agent_server.utils import utc_now + + +class PendingMessage(BaseModel): + """A message queued for delivery when conversation becomes ready. + + Pending messages are stored in the database and delivered to the agent_server + when the conversation transitions to READY status. Messages are deleted after + processing, regardless of success or failure. + """ + + id: str = Field(default_factory=lambda: str(uuid4())) + conversation_id: str # Can be task-{uuid} or real conversation UUID + role: str = 'user' + content: list[TextContent | ImageContent] + created_at: datetime = Field(default_factory=utc_now) + + +class PendingMessageResponse(BaseModel): + """Response when queueing a pending message.""" + + id: str + queued: bool + position: int = Field(description='Position in the queue (1-based)') diff --git a/openhands/app_server/pending_messages/pending_message_router.py b/openhands/app_server/pending_messages/pending_message_router.py new file mode 100644 index 0000000000..7c78e2d6eb --- /dev/null +++ b/openhands/app_server/pending_messages/pending_message_router.py @@ -0,0 +1,104 @@ +"""REST API router for pending messages.""" + +import logging + +from fastapi import APIRouter, HTTPException, Request, status +from pydantic import TypeAdapter, ValidationError + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.app_server.config import depends_pending_message_service +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, +) +from openhands.server.dependencies import get_dependencies + +logger = logging.getLogger(__name__) + +# Type adapter for validating content from request +_content_type_adapter = TypeAdapter(list[TextContent | ImageContent]) + +# Create router with authentication dependencies +router = APIRouter( + prefix='/conversations/{conversation_id}/pending-messages', + tags=['Pending Messages'], + dependencies=get_dependencies(), +) + +# Create dependency at module level +pending_message_service_dependency = depends_pending_message_service() + + +@router.post( + '', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED +) +async def queue_pending_message( + conversation_id: str, + request: Request, + pending_service: PendingMessageService = pending_message_service_dependency, +) -> PendingMessageResponse: + """Queue a message for delivery when conversation becomes ready. + + This endpoint allows users to submit messages even when the conversation's + WebSocket connection is not yet established. Messages are stored server-side + and delivered automatically when the conversation transitions to READY status. + + Args: + conversation_id: The conversation ID (can be task ID before conversation is ready) + request: The FastAPI request containing message content + + Returns: + PendingMessageResponse with the message ID and queue position + + Raises: + HTTPException 400: If the request body is invalid + HTTPException 429: If too many pending messages are queued (limit: 10) + """ + try: + body = await request.json() + except Exception: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='Invalid request body', + ) + + raw_content = body.get('content') + role = body.get('role', 'user') + + if not raw_content or not isinstance(raw_content, list): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='content must be a non-empty list', + ) + + # Validate and parse content into typed objects + try: + content = _content_type_adapter.validate_python(raw_content) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f'Invalid content format: {e}', + ) + + # Rate limit: max 10 pending messages per conversation + pending_count = await pending_service.count_pending_messages(conversation_id) + if pending_count >= 10: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail='Too many pending messages. Maximum 10 messages per conversation.', + ) + + response = await pending_service.add_message( + conversation_id=conversation_id, + content=content, + role=role, + ) + + logger.info( + f'Queued pending message {response.id} for conversation {conversation_id} ' + f'(position: {response.position})' + ) + + return response diff --git a/openhands/app_server/pending_messages/pending_message_service.py b/openhands/app_server/pending_messages/pending_message_service.py new file mode 100644 index 0000000000..44d426c409 --- /dev/null +++ b/openhands/app_server/pending_messages/pending_message_service.py @@ -0,0 +1,200 @@ +"""Service for managing pending messages in SQL database.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import AsyncGenerator + +from fastapi import Request +from pydantic import TypeAdapter +from sqlalchemy import JSON, Column, String, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessage, + PendingMessageResponse, +) +from openhands.app_server.services.injector import Injector, InjectorState +from openhands.app_server.utils.sql_utils import Base, UtcDateTime +from openhands.sdk.utils.models import DiscriminatedUnionMixin + +# Type adapter for deserializing content from JSON +_content_type_adapter = TypeAdapter(list[TextContent | ImageContent]) + + +class StoredPendingMessage(Base): # type: ignore + """SQLAlchemy model for pending messages.""" + + __tablename__ = 'pending_messages' + id = Column(String, primary_key=True) + conversation_id = Column(String, nullable=False, index=True) + role = Column(String(20), nullable=False, default='user') + content = Column(JSON, nullable=False) + created_at = Column(UtcDateTime, server_default=func.now(), index=True) + + +class PendingMessageService(ABC): + """Abstract service for managing pending messages.""" + + @abstractmethod + async def add_message( + self, + conversation_id: str, + content: list[TextContent | ImageContent], + role: str = 'user', + ) -> PendingMessageResponse: + """Queue a message for delivery when conversation becomes ready.""" + + @abstractmethod + async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]: + """Get all pending messages for a conversation, ordered by created_at.""" + + @abstractmethod + async def count_pending_messages(self, conversation_id: str) -> int: + """Count pending messages for a conversation.""" + + @abstractmethod + async def delete_messages_for_conversation(self, conversation_id: str) -> int: + """Delete all pending messages for a conversation, returning count deleted.""" + + @abstractmethod + async def update_conversation_id( + self, old_conversation_id: str, new_conversation_id: str + ) -> int: + """Update conversation_id when task-id transitions to real conversation-id. + + Returns the number of messages updated. + """ + + +@dataclass +class SQLPendingMessageService(PendingMessageService): + """SQL implementation of PendingMessageService.""" + + db_session: AsyncSession + + async def add_message( + self, + conversation_id: str, + content: list[TextContent | ImageContent], + role: str = 'user', + ) -> PendingMessageResponse: + """Queue a message for delivery when conversation becomes ready.""" + # Create the pending message + pending_message = PendingMessage( + conversation_id=conversation_id, + role=role, + content=content, + ) + + # Count existing pending messages for position + count_stmt = select(func.count()).where( + StoredPendingMessage.conversation_id == conversation_id + ) + result = await self.db_session.execute(count_stmt) + position = result.scalar() or 0 + + # Serialize content to JSON-compatible format for storage + content_json = [item.model_dump() for item in content] + + # Store in database + stored_message = StoredPendingMessage( + id=str(pending_message.id), + conversation_id=conversation_id, + role=role, + content=content_json, + created_at=pending_message.created_at, + ) + self.db_session.add(stored_message) + await self.db_session.commit() + + return PendingMessageResponse( + id=pending_message.id, + queued=True, + position=position + 1, + ) + + async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]: + """Get all pending messages for a conversation, ordered by created_at.""" + stmt = ( + select(StoredPendingMessage) + .where(StoredPendingMessage.conversation_id == conversation_id) + .order_by(StoredPendingMessage.created_at.asc()) + ) + result = await self.db_session.execute(stmt) + stored_messages = result.scalars().all() + + return [ + PendingMessage( + id=msg.id, + conversation_id=msg.conversation_id, + role=msg.role, + content=_content_type_adapter.validate_python(msg.content), + created_at=msg.created_at, + ) + for msg in stored_messages + ] + + async def count_pending_messages(self, conversation_id: str) -> int: + """Count pending messages for a conversation.""" + count_stmt = select(func.count()).where( + StoredPendingMessage.conversation_id == conversation_id + ) + result = await self.db_session.execute(count_stmt) + return result.scalar() or 0 + + async def delete_messages_for_conversation(self, conversation_id: str) -> int: + """Delete all pending messages for a conversation, returning count deleted.""" + stmt = select(StoredPendingMessage).where( + StoredPendingMessage.conversation_id == conversation_id + ) + result = await self.db_session.execute(stmt) + stored_messages = result.scalars().all() + + count = len(stored_messages) + for msg in stored_messages: + await self.db_session.delete(msg) + + if count > 0: + await self.db_session.commit() + + return count + + async def update_conversation_id( + self, old_conversation_id: str, new_conversation_id: str + ) -> int: + """Update conversation_id when task-id transitions to real conversation-id.""" + stmt = select(StoredPendingMessage).where( + StoredPendingMessage.conversation_id == old_conversation_id + ) + result = await self.db_session.execute(stmt) + stored_messages = result.scalars().all() + + count = len(stored_messages) + for msg in stored_messages: + msg.conversation_id = new_conversation_id + + if count > 0: + await self.db_session.commit() + + return count + + +class PendingMessageServiceInjector( + DiscriminatedUnionMixin, Injector[PendingMessageService], ABC +): + """Abstract injector for PendingMessageService.""" + + pass + + +class SQLPendingMessageServiceInjector(PendingMessageServiceInjector): + """SQL-based injector for PendingMessageService.""" + + async def inject( + self, state: InjectorState, request: Request | None = None + ) -> AsyncGenerator[PendingMessageService, None]: + from openhands.app_server.config import get_db_session + + async with get_db_session(state) as db_session: + yield SQLPendingMessageService(db_session=db_session) diff --git a/openhands/app_server/v1_router.py b/openhands/app_server/v1_router.py index 2a21c06abd..81823b481c 100644 --- a/openhands/app_server/v1_router.py +++ b/openhands/app_server/v1_router.py @@ -5,6 +5,9 @@ from openhands.app_server.event import event_router from openhands.app_server.event_callback import ( webhook_router, ) +from openhands.app_server.pending_messages.pending_message_router import ( + router as pending_message_router, +) from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router from openhands.app_server.user import user_router from openhands.app_server.web_client import web_client_router @@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router router = APIRouter(prefix='/api/v1') router.include_router(event_router.router) router.include_router(app_conversation_router.router) +router.include_router(pending_message_router) router.include_router(sandbox_router.router) router.include_router(sandbox_spec_router.router) router.include_router(user_router.router) diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py index cf32cfaf05..fcb251797f 100644 --- a/tests/unit/app_server/test_live_status_app_conversation_service.py +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -80,6 +80,7 @@ class TestLiveStatusAppConversationService: self.mock_event_callback_service = Mock() self.mock_event_service = Mock() self.mock_httpx_client = Mock() + self.mock_pending_message_service = Mock() # Create service instance self.service = LiveStatusAppConversationService( @@ -92,6 +93,7 @@ class TestLiveStatusAppConversationService: sandbox_service=self.mock_sandbox_service, sandbox_spec_service=self.mock_sandbox_spec_service, jwt_service=self.mock_jwt_service, + pending_message_service=self.mock_pending_message_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, max_num_conversations_per_sandbox=20, @@ -2329,6 +2331,7 @@ class TestPluginHandling: self.mock_event_callback_service = Mock() self.mock_event_service = Mock() self.mock_httpx_client = Mock() + self.mock_pending_message_service = Mock() # Create service instance self.service = LiveStatusAppConversationService( @@ -2341,6 +2344,7 @@ class TestPluginHandling: sandbox_service=self.mock_sandbox_service, sandbox_spec_service=self.mock_sandbox_spec_service, jwt_service=self.mock_jwt_service, + pending_message_service=self.mock_pending_message_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, max_num_conversations_per_sandbox=20, diff --git a/tests/unit/app_server/test_pending_message_router.py b/tests/unit/app_server/test_pending_message_router.py new file mode 100644 index 0000000000..92dbe2c4a4 --- /dev/null +++ b/tests/unit/app_server/test_pending_message_router.py @@ -0,0 +1,227 @@ +"""Unit tests for the pending_message_router endpoints. + +This module tests the queue_pending_message endpoint, +focusing on request validation and rate limiting. +""" + +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from fastapi import HTTPException, status + +from openhands.agent_server.models import TextContent +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_router import ( + queue_pending_message, +) + + +def _make_mock_service( + add_message_return=None, + count_pending_messages_return=0, +): + """Create a mock PendingMessageService for testing.""" + service = MagicMock() + service.add_message = AsyncMock(return_value=add_message_return) + service.count_pending_messages = AsyncMock( + return_value=count_pending_messages_return + ) + return service + + +def _make_mock_request(body: dict): + """Create a mock FastAPI Request with given JSON body.""" + request = MagicMock() + request.json = AsyncMock(return_value=body) + return request + + +@pytest.mark.asyncio +class TestQueuePendingMessage: + """Test suite for queue_pending_message endpoint.""" + + async def test_queues_message_successfully(self): + """Test that a valid message is queued successfully.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Hello, world!'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=1, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=0, + ) + mock_request = _make_mock_request({'content': raw_content, 'role': 'user'}) + + # Act + result = await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + assert result == expected_response + mock_service.add_message.assert_called_once() + call_kwargs = mock_service.add_message.call_args.kwargs + assert call_kwargs['conversation_id'] == conversation_id + assert call_kwargs['role'] == 'user' + # Content should be parsed into typed objects + assert len(call_kwargs['content']) == 1 + assert isinstance(call_kwargs['content'][0], TextContent) + assert call_kwargs['content'][0].text == 'Hello, world!' + + async def test_uses_default_role_when_not_provided(self): + """Test that 'user' role is used by default.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=1, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=0, + ) + mock_request = _make_mock_request({'content': raw_content}) + + # Act + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + mock_service.add_message.assert_called_once() + call_kwargs = mock_service.add_message.call_args.kwargs + assert call_kwargs['conversation_id'] == conversation_id + assert call_kwargs['role'] == 'user' + assert isinstance(call_kwargs['content'][0], TextContent) + + async def test_returns_400_for_invalid_json_body(self): + """Test that invalid JSON body returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = MagicMock() + mock_request.json = AsyncMock(side_effect=Exception('Invalid JSON')) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'Invalid request body' in exc_info.value.detail + + async def test_returns_400_when_content_is_missing(self): + """Test that missing content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'role': 'user'}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_400_when_content_is_not_a_list(self): + """Test that non-list content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'content': 'not a list'}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_400_when_content_is_empty_list(self): + """Test that empty list content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'content': []}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_429_when_rate_limit_exceeded(self): + """Test that exceeding rate limit returns 429 Too Many Requests.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + mock_service = _make_mock_service(count_pending_messages_return=10) + mock_request = _make_mock_request({'content': raw_content}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert 'Maximum 10 messages' in exc_info.value.detail + + async def test_allows_up_to_10_messages(self): + """Test that 9 existing messages still allows adding one more.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=10, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=9, + ) + mock_request = _make_mock_request({'content': raw_content}) + + # Act + result = await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + assert result == expected_response + mock_service.add_message.assert_called_once() diff --git a/tests/unit/app_server/test_pending_message_service.py b/tests/unit/app_server/test_pending_message_service.py new file mode 100644 index 0000000000..869aae05d0 --- /dev/null +++ b/tests/unit/app_server/test_pending_message_service.py @@ -0,0 +1,309 @@ +"""Tests for SQLPendingMessageService. + +This module tests the SQL implementation of PendingMessageService, +covering message queuing, retrieval, counting, deletion, and +conversation_id updates using SQLite as a mock database. +""" + +from typing import AsyncGenerator +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from openhands.agent_server.models import TextContent +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + SQLPendingMessageService, +) +from openhands.app_server.utils.sql_utils import Base + + +@pytest.fixture +async def async_engine(): + """Create an async SQLite engine for testing.""" + engine = create_async_engine( + 'sqlite+aiosqlite:///:memory:', + poolclass=StaticPool, + connect_args={'check_same_thread': False}, + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + +@pytest.fixture +async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]: + """Create an async session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session_maker() as db_session: + yield db_session + + +@pytest.fixture +def service(async_session) -> SQLPendingMessageService: + """Create a SQLPendingMessageService instance for testing.""" + return SQLPendingMessageService(db_session=async_session) + + +@pytest.fixture +def sample_content() -> list[TextContent]: + """Create sample message content for testing.""" + return [TextContent(text='Hello, this is a test message')] + + +class TestSQLPendingMessageService: + """Test suite for SQLPendingMessageService.""" + + @pytest.mark.asyncio + async def test_add_message_creates_message_with_correct_data( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that add_message creates a message with the expected fields.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act + response = await service.add_message( + conversation_id=conversation_id, + content=sample_content, + role='user', + ) + + # Assert + assert isinstance(response, PendingMessageResponse) + assert response.queued is True + assert response.id is not None + + # Verify the message was stored correctly + messages = await service.get_pending_messages(conversation_id) + assert len(messages) == 1 + assert messages[0].conversation_id == conversation_id + assert len(messages[0].content) == 1 + assert isinstance(messages[0].content[0], TextContent) + assert messages[0].content[0].text == sample_content[0].text + assert messages[0].role == 'user' + assert messages[0].created_at is not None + + @pytest.mark.asyncio + async def test_add_message_returns_correct_queue_position( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that queue position increments correctly for each message.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act - Add three messages + response1 = await service.add_message(conversation_id, sample_content) + response2 = await service.add_message(conversation_id, sample_content) + response3 = await service.add_message(conversation_id, sample_content) + + # Assert + assert response1.position == 1 + assert response2.position == 2 + assert response3.position == 3 + + @pytest.mark.asyncio + async def test_get_pending_messages_returns_messages_ordered_by_created_at( + self, + service: SQLPendingMessageService, + ): + """Test that messages are returned in chronological order.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + contents = [ + [TextContent(text='First message')], + [TextContent(text='Second message')], + [TextContent(text='Third message')], + ] + + for content in contents: + await service.add_message(conversation_id, content) + + # Act + messages = await service.get_pending_messages(conversation_id) + + # Assert + assert len(messages) == 3 + assert messages[0].content[0].text == 'First message' + assert messages[1].content[0].text == 'Second message' + assert messages[2].content[0].text == 'Third message' + + @pytest.mark.asyncio + async def test_get_pending_messages_returns_empty_list_when_none_exist( + self, + service: SQLPendingMessageService, + ): + """Test that an empty list is returned for a conversation with no messages.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act + messages = await service.get_pending_messages(conversation_id) + + # Assert + assert messages == [] + + @pytest.mark.asyncio + async def test_count_pending_messages_returns_correct_count( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that count_pending_messages returns the correct number.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + other_conversation_id = f'task-{uuid4().hex}' + + # Add 3 messages to first conversation + for _ in range(3): + await service.add_message(conversation_id, sample_content) + + # Add 2 messages to second conversation + for _ in range(2): + await service.add_message(other_conversation_id, sample_content) + + # Act + count1 = await service.count_pending_messages(conversation_id) + count2 = await service.count_pending_messages(other_conversation_id) + count_empty = await service.count_pending_messages('nonexistent') + + # Assert + assert count1 == 3 + assert count2 == 2 + assert count_empty == 0 + + @pytest.mark.asyncio + async def test_delete_messages_for_conversation_removes_all_messages( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that delete_messages_for_conversation removes all messages and returns count.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + other_conversation_id = f'task-{uuid4().hex}' + + # Add messages to both conversations + for _ in range(3): + await service.add_message(conversation_id, sample_content) + await service.add_message(other_conversation_id, sample_content) + + # Act + deleted_count = await service.delete_messages_for_conversation(conversation_id) + + # Assert + assert deleted_count == 3 + assert await service.count_pending_messages(conversation_id) == 0 + # Other conversation should be unaffected + assert await service.count_pending_messages(other_conversation_id) == 1 + + @pytest.mark.asyncio + async def test_delete_messages_for_conversation_returns_zero_when_none_exist( + self, + service: SQLPendingMessageService, + ): + """Test that deleting from nonexistent conversation returns 0.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act + deleted_count = await service.delete_messages_for_conversation(conversation_id) + + # Assert + assert deleted_count == 0 + + @pytest.mark.asyncio + async def test_update_conversation_id_updates_all_matching_messages( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that update_conversation_id updates all messages with the old ID.""" + # Arrange + old_conversation_id = f'task-{uuid4().hex}' + new_conversation_id = str(uuid4()) + unrelated_conversation_id = f'task-{uuid4().hex}' + + # Add messages to old conversation + for _ in range(3): + await service.add_message(old_conversation_id, sample_content) + + # Add message to unrelated conversation + await service.add_message(unrelated_conversation_id, sample_content) + + # Act + updated_count = await service.update_conversation_id( + old_conversation_id, new_conversation_id + ) + + # Assert + assert updated_count == 3 + + # Verify old conversation has no messages + assert await service.count_pending_messages(old_conversation_id) == 0 + + # Verify new conversation has all messages + messages = await service.get_pending_messages(new_conversation_id) + assert len(messages) == 3 + for msg in messages: + assert msg.conversation_id == new_conversation_id + + # Verify unrelated conversation is unchanged + assert await service.count_pending_messages(unrelated_conversation_id) == 1 + + @pytest.mark.asyncio + async def test_update_conversation_id_returns_zero_when_no_match( + self, + service: SQLPendingMessageService, + ): + """Test that updating nonexistent conversation_id returns 0.""" + # Arrange + old_conversation_id = f'task-{uuid4().hex}' + new_conversation_id = str(uuid4()) + + # Act + updated_count = await service.update_conversation_id( + old_conversation_id, new_conversation_id + ) + + # Assert + assert updated_count == 0 + + @pytest.mark.asyncio + async def test_messages_are_isolated_between_conversations( + self, + service: SQLPendingMessageService, + ): + """Test that operations on one conversation don't affect others.""" + # Arrange + conv1 = f'task-{uuid4().hex}' + conv2 = f'task-{uuid4().hex}' + + await service.add_message(conv1, [TextContent(text='Conv1 msg')]) + await service.add_message(conv2, [TextContent(text='Conv2 msg')]) + + # Act + messages1 = await service.get_pending_messages(conv1) + messages2 = await service.get_pending_messages(conv2) + + # Assert + assert len(messages1) == 1 + assert len(messages2) == 1 + assert messages1[0].content[0].text == 'Conv1 msg' + assert messages2[0].content[0].text == 'Conv2 msg' diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 7fa64ab12a..99dbdfaacc 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -2187,6 +2187,7 @@ async def test_delete_v1_conversation_with_sub_conversations(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20, @@ -2311,6 +2312,7 @@ async def test_delete_v1_conversation_with_no_sub_conversations(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20, @@ -2465,6 +2467,7 @@ async def test_delete_v1_conversation_sub_conversation_deletion_error(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20,