fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380)

This commit is contained in:
Hiep Le
2026-03-16 22:25:44 +07:00
committed by GitHub
parent aec95ecf3b
commit 238cab4d08
29 changed files with 2668 additions and 22 deletions

View File

@@ -0,0 +1,39 @@
"""Add pending_messages table for server-side message queuing
Revision ID: 101
Revises: 100
Create Date: 2025-03-15 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '101'  # this migration
down_revision: Union[str, None] = '100'  # previous migration in the chain
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create pending_messages table for storing messages before conversation is ready.

    Messages are stored temporarily until the conversation becomes ready, then
    delivered and deleted regardless of success or failure.
    """
    op.create_table(
        'pending_messages',
        # Primary key stored as a plain string (not a DB-native UUID type).
        sa.Column('id', sa.String(), primary_key=True),
        # Indexed: messages are always looked up per conversation.
        sa.Column('conversation_id', sa.String(), nullable=False, index=True),
        sa.Column('role', sa.String(20), nullable=False, server_default='user'),
        # Message content blocks serialized as JSON.
        sa.Column('content', sa.JSON, nullable=False),
        # No server_default — presumably the application supplies the
        # timestamp on insert; confirm against the service layer.
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    )
def downgrade() -> None:
    """Remove pending_messages table (irreversibly drops any queued messages)."""
    op.drop_table('pending_messages')

View File

@@ -0,0 +1,172 @@
"""Enterprise injector for PendingMessageService with SAAS filtering."""
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import select
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.agent_server.models import ImageContent, TextContent
from openhands.app_server.errors import AuthError
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
PendingMessageServiceInjector,
SQLPendingMessageService,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import ADMIN
from openhands.app_server.user.user_context import UserContext
class SaasSQLPendingMessageService(SQLPendingMessageService):
    """Extended SQLPendingMessageService with user and organization-based filtering.

    This enterprise version ensures that:
    - Users can only queue messages for conversations they own
    - Organization isolation is enforced for multi-tenant deployments
    """

    def __init__(self, db_session, user_context: UserContext):
        super().__init__(db_session=db_session)
        self.user_context = user_context

    @staticmethod
    def _parse_user_uuid(user_id_str: str) -> UUID:
        """Parse a user id string into a UUID.

        Raises:
            AuthError: If the string is not a valid UUID. This surfaces
                malformed identifiers as an auth failure instead of letting
                an unhandled ValueError bubble up as a server error.
        """
        try:
            return UUID(user_id_str)
        except ValueError as exc:
            raise AuthError('Invalid user identifier') from exc

    async def _get_current_user(self) -> User | None:
        """Get the current user using the existing db_session.

        Returns:
            User object or None if no user_id is available

        Raises:
            AuthError: If a user id is present but malformed.
        """
        user_id_str = await self.user_context.get_user_id()
        if not user_id_str:
            return None
        user_id_uuid = self._parse_user_uuid(user_id_str)
        result = await self.db_session.execute(
            select(User).where(User.id == user_id_uuid)
        )
        return result.scalars().first()

    async def _validate_conversation_ownership(self, conversation_id: str) -> None:
        """Validate that the current user owns the conversation.

        This ensures multi-tenant isolation by checking:
        - The conversation belongs to the current user
        - The conversation belongs to the user's current organization

        Args:
            conversation_id: The conversation ID to validate (can be task-id or UUID)

        Raises:
            AuthError: If user doesn't own the conversation or authentication fails
        """
        # For internal operations (e.g., processing pending messages during startup)
        # we need a mode that bypasses filtering. The ADMIN context enables this.
        if self.user_context == ADMIN:
            return
        user_id_str = await self.user_context.get_user_id()
        if not user_id_str:
            raise AuthError('User authentication required')
        user_id_uuid = self._parse_user_uuid(user_id_str)
        # Check conversation ownership via SAAS metadata
        query = select(StoredConversationMetadataSaas).where(
            StoredConversationMetadataSaas.conversation_id == conversation_id
        )
        result = await self.db_session.execute(query)
        saas_metadata = result.scalar_one_or_none()
        # If no SAAS metadata exists, the conversation might be a new task-id
        # that hasn't been linked to a conversation yet. Allow access in this case
        # as the message will be validated when the conversation is created.
        if saas_metadata is None:
            return
        # Verify user ownership
        if saas_metadata.user_id != user_id_uuid:
            raise AuthError('You do not have access to this conversation')
        # Verify organization ownership if applicable
        user = await self._get_current_user()
        if user and user.current_org_id is not None:
            if saas_metadata.org_id != user.current_org_id:
                raise AuthError('Conversation belongs to a different organization')

    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message with ownership validation.

        Args:
            conversation_id: The conversation ID to queue the message for
            content: Message content
            role: Message role (default: 'user')

        Returns:
            PendingMessageResponse with the queued message info

        Raises:
            AuthError: If user doesn't own the conversation
        """
        await self._validate_conversation_ownership(conversation_id)
        return await super().add_message(conversation_id, content, role)

    async def get_pending_messages(self, conversation_id: str):
        """Get pending messages with ownership validation.

        Args:
            conversation_id: The conversation ID to get messages for

        Returns:
            List of pending messages

        Raises:
            AuthError: If user doesn't own the conversation
        """
        await self._validate_conversation_ownership(conversation_id)
        return await super().get_pending_messages(conversation_id)

    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages with ownership validation.

        Args:
            conversation_id: The conversation ID to count messages for

        Returns:
            Number of pending messages

        Raises:
            AuthError: If user doesn't own the conversation
        """
        await self._validate_conversation_ownership(conversation_id)
        return await super().count_pending_messages(conversation_id)
class SaasPendingMessageServiceInjector(PendingMessageServiceInjector):
    """Enterprise injector that yields a SAAS-filtered PendingMessageService."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[PendingMessageService, None]:
        """Yield a SaasSQLPendingMessageService bound to the request's user context and DB session."""
        # Local import — presumably avoids an import cycle with the app config
        # module; confirm before moving to module level.
        from openhands.app_server.config import (
            get_db_session,
            get_user_context,
        )

        async with get_user_context(state, request) as user_context:
            async with get_db_session(state, request) as db_session:
                yield SaasSQLPendingMessageService(
                    db_session=db_session, user_context=user_context
                )

View File

@@ -198,9 +198,9 @@ describe("InteractiveChatBox", () => {
expect(onSubmitMock).toHaveBeenCalledWith("Hello, world!", [], []);
});
it("should disable the submit button when agent is loading", async () => {
it("should disable the submit button when awaiting user confirmation", async () => {
const user = userEvent.setup();
mockStores(AgentState.LOADING);
mockStores(AgentState.AWAITING_USER_CONFIRMATION);
renderInteractiveChatBox({
onSubmit: onSubmitMock,

View File

@@ -229,4 +229,231 @@ describe("conversation localStorage utilities", () => {
expect(parsed.subConversationTaskId).toBeNull();
});
});
// Covers persistence of per-conversation draft messages in localStorage via
// getConversationState / setConversationState, including the rule that drafts
// are never persisted for temporary task-* conversation IDs.
describe("draftMessage persistence", () => {
  describe("getConversationState", () => {
    it("returns default draftMessage as null when no state exists", () => {
      // Arrange
      const conversationId = "conv-draft-1";

      // Act
      const state = getConversationState(conversationId);

      // Assert
      expect(state.draftMessage).toBeNull();
    });

    it("retrieves draftMessage from localStorage when it exists", () => {
      // Arrange
      const conversationId = "conv-draft-2";
      const draftText = "This is my saved draft message";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
      localStorage.setItem(
        consolidatedKey,
        JSON.stringify({
          draftMessage: draftText,
        }),
      );

      // Act
      const state = getConversationState(conversationId);

      // Assert
      expect(state.draftMessage).toBe(draftText);
    });

    it("returns null draftMessage for task conversation IDs (not persisted)", () => {
      // Arrange
      const taskId = "task-uuid-123";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`;
      // Even if somehow there's data in localStorage for a task ID
      localStorage.setItem(
        consolidatedKey,
        JSON.stringify({
          draftMessage: "Should not be returned",
        }),
      );

      // Act
      const state = getConversationState(taskId);

      // Assert - should return default state, not the stored value
      expect(state.draftMessage).toBeNull();
    });
  });

  describe("setConversationState", () => {
    it("persists draftMessage to localStorage", () => {
      // Arrange
      const conversationId = "conv-draft-3";
      const draftText = "New draft message to save";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;

      // Act
      setConversationState(conversationId, {
        draftMessage: draftText,
      });

      // Assert
      const stored = localStorage.getItem(consolidatedKey);
      expect(stored).not.toBeNull();
      const parsed = JSON.parse(stored!);
      expect(parsed.draftMessage).toBe(draftText);
    });

    it("does not persist draftMessage for task conversation IDs", () => {
      // Arrange
      const taskId = "task-draft-xyz";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`;

      // Act
      setConversationState(taskId, {
        draftMessage: "Draft for task ID",
      });

      // Assert - nothing should be stored
      expect(localStorage.getItem(consolidatedKey)).toBeNull();
    });

    it("merges draftMessage with existing state without overwriting other fields", () => {
      // Arrange
      const conversationId = "conv-draft-4";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
      localStorage.setItem(
        consolidatedKey,
        JSON.stringify({
          selectedTab: "terminal",
          rightPanelShown: false,
          unpinnedTabs: ["tab-1", "tab-2"],
          conversationMode: "plan",
          subConversationTaskId: "task-123",
        }),
      );

      // Act
      setConversationState(conversationId, {
        draftMessage: "Updated draft",
      });

      // Assert
      const stored = localStorage.getItem(consolidatedKey);
      const parsed = JSON.parse(stored!);
      expect(parsed.draftMessage).toBe("Updated draft");
      expect(parsed.selectedTab).toBe("terminal");
      expect(parsed.rightPanelShown).toBe(false);
      expect(parsed.unpinnedTabs).toEqual(["tab-1", "tab-2"]);
      expect(parsed.conversationMode).toBe("plan");
      expect(parsed.subConversationTaskId).toBe("task-123");
    });

    it("clears draftMessage when set to null", () => {
      // Arrange
      const conversationId = "conv-draft-5";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
      localStorage.setItem(
        consolidatedKey,
        JSON.stringify({
          draftMessage: "Existing draft",
        }),
      );

      // Act
      setConversationState(conversationId, {
        draftMessage: null,
      });

      // Assert
      const stored = localStorage.getItem(consolidatedKey);
      const parsed = JSON.parse(stored!);
      expect(parsed.draftMessage).toBeNull();
    });

    it("clears draftMessage when set to empty string (stored as empty string)", () => {
      // Arrange
      const conversationId = "conv-draft-6";
      const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
      localStorage.setItem(
        consolidatedKey,
        JSON.stringify({
          draftMessage: "Existing draft",
        }),
      );

      // Act
      setConversationState(conversationId, {
        draftMessage: "",
      });

      // Assert
      const stored = localStorage.getItem(consolidatedKey);
      const parsed = JSON.parse(stored!);
      expect(parsed.draftMessage).toBe("");
    });
  });

  // Drafts are keyed per conversation id, so suites below verify isolation.
  describe("conversation-specific draft isolation", () => {
    it("stores drafts separately for different conversations", () => {
      // Arrange
      const convA = "conv-A";
      const convB = "conv-B";
      const draftA = "Draft for conversation A";
      const draftB = "Draft for conversation B";

      // Act
      setConversationState(convA, { draftMessage: draftA });
      setConversationState(convB, { draftMessage: draftB });

      // Assert
      const stateA = getConversationState(convA);
      const stateB = getConversationState(convB);
      expect(stateA.draftMessage).toBe(draftA);
      expect(stateB.draftMessage).toBe(draftB);
    });

    it("updating one conversation draft does not affect another", () => {
      // Arrange
      const convA = "conv-isolated-A";
      const convB = "conv-isolated-B";
      setConversationState(convA, { draftMessage: "Original draft A" });
      setConversationState(convB, { draftMessage: "Original draft B" });

      // Act - update only conversation A
      setConversationState(convA, { draftMessage: "Updated draft A" });

      // Assert - conversation B should be unchanged
      const stateA = getConversationState(convA);
      const stateB = getConversationState(convB);
      expect(stateA.draftMessage).toBe("Updated draft A");
      expect(stateB.draftMessage).toBe("Original draft B");
    });

    it("clearing one conversation draft does not affect another", () => {
      // Arrange
      const convA = "conv-clear-A";
      const convB = "conv-clear-B";
      setConversationState(convA, { draftMessage: "Draft A" });
      setConversationState(convB, { draftMessage: "Draft B" });

      // Act - clear draft for conversation A
      setConversationState(convA, { draftMessage: null });

      // Assert
      const stateA = getConversationState(convA);
      const stateB = getConversationState(convB);
      expect(stateA.draftMessage).toBeNull();
      expect(stateB.draftMessage).toBe("Draft B");
    });
  });
});
});

View File

@@ -1,3 +1,4 @@
import React from "react";
import {
describe,
it,
@@ -8,7 +9,7 @@ import {
afterEach,
vi,
} from "vitest";
import { screen, waitFor, render, cleanup } from "@testing-library/react";
import { screen, waitFor, render, cleanup, act } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { http, HttpResponse } from "msw";
import { MemoryRouter, Route, Routes } from "react-router";
@@ -682,8 +683,242 @@ describe("Conversation WebSocket Handler", () => {
// 7. Message Sending Tests
// Each test mounts a probe component inside the WebSocket provider, waits for
// the connection to reach OPEN, captures `sendMessage` out of the React tree,
// and asserts on messages received by the MSW mock server. The stale
// `it.todo` placeholders that duplicated the implemented test titles below
// have been removed.
describe("Message Sending", () => {
  it("should send user actions through WebSocket when connected", async () => {
    // Arrange
    const conversationId = "test-conversation-send";
    let receivedMessage: unknown = null;

    // Set up MSW to capture sent messages
    mswServer.use(
      wsLink.addEventListener("connection", ({ client, server }) => {
        server.connect();
        // Capture messages sent from client
        client.addEventListener("message", (event) => {
          receivedMessage = JSON.parse(event.data as string);
        });
      }),
    );

    // Create ref to store sendMessage function
    let sendMessageFn: typeof useConversationWebSocket extends () => infer R
      ? R extends { sendMessage: infer S }
        ? S
        : null
      : null = null;

    function TestComponent() {
      const context = useConversationWebSocket();
      React.useEffect(() => {
        if (context?.sendMessage) {
          sendMessageFn = context.sendMessage;
        }
      }, [context?.sendMessage]);
      return (
        <div>
          <div data-testid="connection-state">
            {context?.connectionState || "NOT_AVAILABLE"}
          </div>
        </div>
      );
    }

    // Act
    renderWithWebSocketContext(
      <TestComponent />,
      conversationId,
      `http://localhost:3000/api/conversations/${conversationId}`,
    );

    // Wait for connection
    await waitFor(() => {
      expect(screen.getByTestId("connection-state")).toHaveTextContent(
        "OPEN",
      );
    });

    // Send a message
    await waitFor(() => {
      expect(sendMessageFn).not.toBeNull();
    });
    await act(async () => {
      await sendMessageFn!({
        role: "user",
        content: [{ type: "text", text: "Hello from test" }],
      });
    });

    // Assert - message should have been received by mock server
    await waitFor(() => {
      expect(receivedMessage).toEqual({
        role: "user",
        content: [{ type: "text", text: "Hello from test" }],
      });
    });
  });

  it("should not throw error when sendMessage is called with WebSocket connected", async () => {
    // This test verifies that sendMessage doesn't throw an error
    // when the WebSocket is connected.
    const conversationId = "test-conversation-no-throw";
    let sendError: Error | null = null;

    // Set up MSW to connect and receive messages
    mswServer.use(
      wsLink.addEventListener("connection", ({ server }) => {
        server.connect();
      }),
    );

    // Create ref to store sendMessage function
    let sendMessageFn: typeof useConversationWebSocket extends () => infer R
      ? R extends { sendMessage: infer S }
        ? S
        : null
      : null = null;

    function TestComponent() {
      const context = useConversationWebSocket();
      React.useEffect(() => {
        if (context?.sendMessage) {
          sendMessageFn = context.sendMessage;
        }
      }, [context?.sendMessage]);
      return (
        <div>
          <div data-testid="connection-state">
            {context?.connectionState || "NOT_AVAILABLE"}
          </div>
        </div>
      );
    }

    // Act
    renderWithWebSocketContext(
      <TestComponent />,
      conversationId,
      `http://localhost:3000/api/conversations/${conversationId}`,
    );

    // Wait for connection
    await waitFor(() => {
      expect(screen.getByTestId("connection-state")).toHaveTextContent(
        "OPEN",
      );
    });

    // Wait for the context to be available
    await waitFor(() => {
      expect(sendMessageFn).not.toBeNull();
    });

    // Try to send a message
    await act(async () => {
      try {
        await sendMessageFn!({
          role: "user",
          content: [{ type: "text", text: "Test message" }],
        });
      } catch (error) {
        sendError = error as Error;
      }
    });

    // Assert - should NOT throw an error
    expect(sendError).toBeNull();
  });

  it("should send multiple messages through WebSocket in order", async () => {
    // Arrange
    const conversationId = "test-conversation-multi";
    const receivedMessages: unknown[] = [];

    // Set up MSW to capture sent messages
    mswServer.use(
      wsLink.addEventListener("connection", ({ client, server }) => {
        server.connect();
        // Capture messages sent from client
        client.addEventListener("message", (event) => {
          receivedMessages.push(JSON.parse(event.data as string));
        });
      }),
    );

    // Create ref to store sendMessage function
    let sendMessageFn: typeof useConversationWebSocket extends () => infer R
      ? R extends { sendMessage: infer S }
        ? S
        : null
      : null = null;

    function TestComponent() {
      const context = useConversationWebSocket();
      React.useEffect(() => {
        if (context?.sendMessage) {
          sendMessageFn = context.sendMessage;
        }
      }, [context?.sendMessage]);
      return (
        <div>
          <div data-testid="connection-state">
            {context?.connectionState || "NOT_AVAILABLE"}
          </div>
        </div>
      );
    }

    // Act
    renderWithWebSocketContext(
      <TestComponent />,
      conversationId,
      `http://localhost:3000/api/conversations/${conversationId}`,
    );

    // Wait for connection
    await waitFor(() => {
      expect(screen.getByTestId("connection-state")).toHaveTextContent(
        "OPEN",
      );
    });
    await waitFor(() => {
      expect(sendMessageFn).not.toBeNull();
    });

    // Send multiple messages
    await act(async () => {
      await sendMessageFn!({
        role: "user",
        content: [{ type: "text", text: "Message 1" }],
      });
      await sendMessageFn!({
        role: "user",
        content: [{ type: "text", text: "Message 2" }],
      });
    });

    // Assert - both messages should have been received in order
    await waitFor(() => {
      expect(receivedMessages.length).toBe(2);
    });
    expect(receivedMessages[0]).toEqual({
      role: "user",
      content: [{ type: "text", text: "Message 1" }],
    });
    expect(receivedMessages[1]).toEqual({
      role: "user",
      content: [{ type: "text", text: "Message 2" }],
    });
  });
});
// 8. History Loading State Tests

View File

@@ -0,0 +1,594 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useDraftPersistence } from "#/hooks/chat/use-draft-persistence";
import * as conversationLocalStorage from "#/utils/conversation-local-storage";
// Mock the entire module
// (vitest hoists vi.mock calls above the imports, so the hook under test
// receives these mocks)
vi.mock("#/utils/conversation-local-storage", () => ({
  useConversationLocalStorageState: vi.fn(),
  getConversationState: vi.fn(),
  setConversationState: vi.fn(),
}));

// Mock the getTextContent utility
vi.mock("#/components/features/chat/utils/chat-input.utils", () => ({
  getTextContent: vi.fn((el: HTMLDivElement | null) => el?.textContent || ""),
}));
describe("useDraftPersistence", () => {
// Captured setter mock so tests can assert how the hook persists drafts.
let mockSetDraftMessage: (message: string | null) => void;

// Create a mock ref to contentEditable div
const createMockChatInputRef = (initialContent = "") => {
  const div = document.createElement("div");
  div.setAttribute("contenteditable", "true");
  div.textContent = initialContent;
  return { current: div };
};
beforeEach(() => {
  vi.clearAllMocks();
  // Fake timers: the hook debounces saves, so tests advance time manually.
  vi.useFakeTimers();
  localStorage.clear();
  mockSetDraftMessage = vi.fn<(message: string | null) => void>();

  // Default mock for useConversationLocalStorageState
  vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
    state: {
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: null,
    },
    setSelectedTab: vi.fn(),
    setRightPanelShown: vi.fn(),
    setUnpinnedTabs: vi.fn(),
    setConversationMode: vi.fn(),
    setDraftMessage: mockSetDraftMessage,
  });

  // Default mock for getConversationState
  vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
    selectedTab: "editor",
    rightPanelShown: true,
    unpinnedTabs: [],
    conversationMode: "code",
    subConversationTaskId: null,
    draftMessage: null,
  });
});
afterEach(() => {
  // Restore real timers so later suites aren't affected by fake timers.
  vi.useRealTimers();
  vi.clearAllMocks();
});
// Verifies that mounting the hook clears stale input content and restores any
// saved draft for the current conversation from localStorage.
describe("draft restoration on mount", () => {
  it("restores draft from localStorage when mounting with existing draft", () => {
    // Arrange
    const conversationId = "conv-restore-1";
    const savedDraft = "Previously saved draft message";
    const chatInputRef = createMockChatInputRef();
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: savedDraft,
    });

    // Act
    renderHook(() => useDraftPersistence(conversationId, chatInputRef));

    // Assert - draft should be restored to the DOM element
    expect(chatInputRef.current?.textContent).toBe(savedDraft);
  });

  it("clears input on mount then restores draft if exists", () => {
    // Arrange
    const conversationId = "conv-restore-2";
    const existingContent = "Stale content from previous conversation";
    const savedDraft = "Saved draft";
    const chatInputRef = createMockChatInputRef(existingContent);
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: savedDraft,
    });

    // Act
    renderHook(() => useDraftPersistence(conversationId, chatInputRef));

    // Assert - input cleared then draft restored
    expect(chatInputRef.current?.textContent).toBe(savedDraft);
  });

  it("clears input when no draft exists for conversation", () => {
    // Arrange
    const conversationId = "conv-no-draft";
    const chatInputRef = createMockChatInputRef("Some stale content");
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: null,
    });

    // Act
    renderHook(() => useDraftPersistence(conversationId, chatInputRef));

    // Assert - content should be cleared since there's no draft
    expect(chatInputRef.current?.textContent).toBe("");
  });
});
// Verifies saveDraft's debounce: writes land only after the 500ms window,
// rapid re-entry coalesces to a single write, and unchanged content is a no-op.
describe("debounced saving", () => {
  it("saves draft after debounce period", () => {
    // Arrange
    const conversationId = "conv-debounce-1";
    const chatInputRef = createMockChatInputRef();
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Act - simulate user typing
    chatInputRef.current!.textContent = "New draft content";
    act(() => {
      result.current.saveDraft();
    });

    // Assert - should not save immediately
    expect(mockSetDraftMessage).not.toHaveBeenCalled();

    // Fast forward past debounce period (500ms)
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - should save after debounce
    expect(mockSetDraftMessage).toHaveBeenCalledWith("New draft content");
  });

  it("cancels pending save when new input arrives before debounce", () => {
    // Arrange
    const conversationId = "conv-debounce-2";
    const chatInputRef = createMockChatInputRef();
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Act - first input
    chatInputRef.current!.textContent = "First";
    act(() => {
      result.current.saveDraft();
    });

    // Wait 200ms (less than debounce)
    act(() => {
      vi.advanceTimersByTime(200);
    });

    // Second input before debounce completes
    chatInputRef.current!.textContent = "First Second";
    act(() => {
      result.current.saveDraft();
    });

    // Complete the second debounce
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - should only save the final value once
    expect(mockSetDraftMessage).toHaveBeenCalledTimes(1);
    expect(mockSetDraftMessage).toHaveBeenCalledWith("First Second");
  });

  it("does not save if content matches existing draft", () => {
    // Arrange
    const conversationId = "conv-no-change";
    const existingDraft = "Existing draft";
    const chatInputRef = createMockChatInputRef(existingDraft);
    vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
      state: {
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: existingDraft,
      },
      setSelectedTab: vi.fn(),
      setRightPanelShown: vi.fn(),
      setUnpinnedTabs: vi.fn(),
      setConversationMode: vi.fn(),
      setDraftMessage: mockSetDraftMessage,
    });
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: existingDraft,
    });
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Act - try to save same content
    act(() => {
      result.current.saveDraft();
    });
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - should not save since content is the same
    expect(mockSetDraftMessage).not.toHaveBeenCalled();
  });
});
// Verifies clearDraft removes the stored draft and cancels in-flight
// debounced saves.
describe("clearDraft", () => {
  it("clears the draft from localStorage", () => {
    // Arrange
    const conversationId = "conv-clear-1";
    const chatInputRef = createMockChatInputRef("Some content");
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Act
    act(() => {
      result.current.clearDraft();
    });

    // Assert
    expect(mockSetDraftMessage).toHaveBeenCalledWith(null);
  });

  it("cancels any pending debounced save when clearing", () => {
    // Arrange
    const conversationId = "conv-clear-2";
    const chatInputRef = createMockChatInputRef();
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Start a save
    chatInputRef.current!.textContent = "Pending draft";
    act(() => {
      result.current.saveDraft();
    });

    // Clear before debounce completes
    act(() => {
      vi.advanceTimersByTime(200);
      result.current.clearDraft();
    });

    // Complete the original debounce period
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - only the clear should have been called (the pending save should be cancelled)
    expect(mockSetDraftMessage).toHaveBeenCalledTimes(1);
    expect(mockSetDraftMessage).toHaveBeenCalledWith(null);
  });
});
// Verifies behavior when the hook's conversationId prop changes mid-session:
// the input is repopulated from the target conversation's draft and any
// pending save for the previous conversation is dropped.
describe("conversation switching", () => {
  it("clears input when switching to a new conversation without a draft", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef("Draft from conv A");
    // First conversation has a draft
    vi.mocked(conversationLocalStorage.getConversationState)
      .mockReturnValueOnce({
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: "Draft from conv A",
      })
      .mockReturnValue({
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: null,
      });
    const { rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "conv-A" } },
    );

    // Act - switch to conversation B
    rerender({ conversationId: "conv-B" });

    // Assert - input should be cleared (no draft for conv-B)
    expect(chatInputRef.current?.textContent).toBe("");
  });

  it("restores draft when switching to a conversation with an existing draft", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef();
    const draftForConvB = "Saved draft for conversation B";
    vi.mocked(conversationLocalStorage.getConversationState)
      .mockReturnValueOnce({
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: null,
      })
      .mockReturnValue({
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: draftForConvB,
      });
    const { rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "conv-A" } },
    );

    // Act - switch to conversation B
    rerender({ conversationId: "conv-B" });

    // Assert - draft for conv-B should be restored
    expect(chatInputRef.current?.textContent).toBe(draftForConvB);
  });

  it("cancels pending save when switching conversations", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef();
    const { result, rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "conv-A" } },
    );

    // Start typing in conv-A
    chatInputRef.current!.textContent = "Draft for conv-A";
    act(() => {
      result.current.saveDraft();
    });

    // Switch conversation before debounce completes
    act(() => {
      vi.advanceTimersByTime(200);
    });
    rerender({ conversationId: "conv-B" });

    // Complete the debounce period
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - the save should NOT have happened because conversation changed
    expect(mockSetDraftMessage).not.toHaveBeenCalled();
  });
});
// Verifies the handoff when a temporary task-* ID is replaced by the real
// conversation ID once the backend conversation is created: any in-progress
// draft follows the user into the real conversation.
describe("task ID to real conversation ID transition", () => {
  it("transfers draft from task ID to real conversation ID during transition", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef("Draft typed during init");
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: null,
    });
    const { rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "task-abc-123" } },
    );

    // Simulate user typing during task initialization
    chatInputRef.current!.textContent = "Draft typed during init";

    // Act - transition to real conversation ID
    rerender({ conversationId: "conv-real-123" });

    // Assert - draft should be saved to the new real conversation ID
    expect(conversationLocalStorage.setConversationState).toHaveBeenCalledWith(
      "conv-real-123",
      { draftMessage: "Draft typed during init" },
    );
    // And the draft should remain visible in the input
    expect(chatInputRef.current?.textContent).toBe("Draft typed during init");
  });

  it("does not transfer empty draft during task-to-real transition", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef("");
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: null,
    });
    const { rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "task-abc-123" } },
    );

    // Act - transition to real conversation ID with empty input
    rerender({ conversationId: "conv-real-123" });

    // Assert - no draft should be saved (input is cleared, checked by hook)
    // The setConversationState should not be called with draftMessage
    expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled();
  });

  it("does not transfer draft for non-task ID transitions", () => {
    // Arrange
    const chatInputRef = createMockChatInputRef("Some draft");
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: null,
    });
    const { rerender } = renderHook(
      ({ conversationId }) =>
        useDraftPersistence(conversationId, chatInputRef),
      { initialProps: { conversationId: "conv-A" } },
    );

    // Act - normal conversation switch (not task-to-real)
    rerender({ conversationId: "conv-B" });

    // Assert - should not use setConversationState directly
    // (the normal path uses setDraftMessage from the hook)
    expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled();
  });
});
describe("hasDraft and isRestored state", () => {
  it("returns hasDraft true when draft exists in hook state", () => {
    // Arrange - mock the localStorage-backed hook state to contain a draft
    const conversationId = "conv-has-draft";
    const chatInputRef = createMockChatInputRef();
    vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
      state: {
        selectedTab: "editor",
        rightPanelShown: true,
        unpinnedTabs: [],
        conversationMode: "code",
        subConversationTaskId: null,
        draftMessage: "Existing draft",
      },
      setSelectedTab: vi.fn(),
      setRightPanelShown: vi.fn(),
      setUnpinnedTabs: vi.fn(),
      setConversationMode: vi.fn(),
      setDraftMessage: mockSetDraftMessage,
    });

    // Act
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Assert - hasDraft mirrors state.draftMessage truthiness
    expect(result.current.hasDraft).toBe(true);
  });

  it("returns hasDraft false when no draft exists", () => {
    // Arrange - default mocked state has draftMessage: null
    const conversationId = "conv-no-draft";
    const chatInputRef = createMockChatInputRef();

    // Act
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Assert
    expect(result.current.hasDraft).toBe(false);
  });

  it("sets isRestored to true after restoration completes", () => {
    // Arrange - a saved draft exists in localStorage for this conversation
    const conversationId = "conv-restored";
    const chatInputRef = createMockChatInputRef();
    vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
      selectedTab: "editor",
      rightPanelShown: true,
      unpinnedTabs: [],
      conversationMode: "code",
      subConversationTaskId: null,
      draftMessage: "Draft to restore",
    });

    // Act - the restoration effect runs on mount
    const { result } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Assert
    expect(result.current.isRestored).toBe(true);
  });
});
describe("cleanup on unmount", () => {
  it("clears pending timeout on unmount", () => {
    // Arrange
    const conversationId = "conv-unmount";
    const chatInputRef = createMockChatInputRef();
    const { result, unmount } = renderHook(() =>
      useDraftPersistence(conversationId, chatInputRef),
    );

    // Start a debounced save
    chatInputRef.current!.textContent = "Draft";
    act(() => {
      result.current.saveDraft();
    });

    // Unmount before debounce completes
    unmount();

    // Complete the debounce period (DRAFT_SAVE_DEBOUNCE_MS is 500ms)
    act(() => {
      vi.advanceTimersByTime(500);
    });

    // Assert - save should not have been called after unmount
    expect(mockSetDraftMessage).not.toHaveBeenCalled();
  });
});
});

View File

@@ -88,6 +88,7 @@ describe("useHandlePlanClick", () => {
unpinnedTabs: [],
subConversationTaskId: null,
conversationMode: "code",
draftMessage: null,
});
});
@@ -117,6 +118,7 @@ describe("useHandlePlanClick", () => {
unpinnedTabs: [],
subConversationTaskId: storedTaskId,
conversationMode: "code",
draftMessage: null,
});
renderHook(() => useHandlePlanClick());
@@ -155,6 +157,7 @@ describe("useHandlePlanClick", () => {
unpinnedTabs: [],
subConversationTaskId: storedTaskId,
conversationMode: "code",
draftMessage: null,
});
renderHook(() => useHandlePlanClick());

View File

@@ -0,0 +1,40 @@
/**
* Pending Message Service
*
* This service handles server-side message queuing for V1 conversations.
* Messages can be queued when the WebSocket is not connected and will be
* delivered automatically when the conversation becomes ready.
*/
import { openHands } from "../open-hands-axios";
import type {
PendingMessageResponse,
QueuePendingMessageRequest,
} from "./pending-message-service.types";
class PendingMessageService {
  /**
   * Queue a message for delivery once the conversation becomes ready.
   *
   * Lets users submit messages even while the conversation's WebSocket
   * connection is not yet established: the server stores the message and
   * delivers it automatically when the conversation transitions to READY.
   *
   * @param conversationId The conversation ID (can be a task ID before the conversation is ready)
   * @param message The message to queue
   * @returns PendingMessageResponse with the message ID and queue position
   * @throws Error if too many pending messages (limit: 10 per conversation)
   */
  static async queueMessage(
    conversationId: string,
    message: QueuePendingMessageRequest,
  ): Promise<PendingMessageResponse> {
    const url = `/api/v1/conversations/${conversationId}/pending-messages`;
    const response = await openHands.post<PendingMessageResponse>(url, message);
    return response.data;
  }
}
export default PendingMessageService;

View File

@@ -0,0 +1,22 @@
/**
* Types for the pending message service
*/
import type { V1MessageContent } from "../conversation-service/v1-conversation-service.types";
/**
* Response when queueing a pending message
*/
export interface PendingMessageResponse {
  /** Server-assigned ID of the stored pending message. */
  id: string;
  /** True when the message was accepted into the server-side queue. */
  queued: boolean;
  /** Position of this message in the conversation's queue (1-based). */
  position: number;
}
/**
* Request to queue a pending message
*/
export interface QueuePendingMessageRequest {
  /** Message author role; the server defaults this to "user" when omitted. */
  role?: "user";
  /** Ordered message content blocks (text and/or images). */
  content: V1MessageContent[];
}

View File

@@ -190,8 +190,14 @@ export function ChatInterface() {
const prompt =
uploadedFiles.length > 0 ? `${content}\n\n${filePrompt}` : content;
send(createChatMessage(prompt, imageUrls, uploadedFiles, timestamp));
setOptimisticUserMessage(content);
const result = await send(
createChatMessage(prompt, imageUrls, uploadedFiles, timestamp),
);
// Only show optimistic UI if message was sent immediately via WebSocket
// If queued for later delivery, the message will appear when actually delivered
if (!result.queued) {
setOptimisticUserMessage(content);
}
setMessageToSend("");
};

View File

@@ -60,6 +60,7 @@ export function CustomChatInput({
messageToSend,
checkIsContentEmpty,
clearEmptyContentHandler,
saveDraft,
} = useChatInputLogic();
const {
@@ -158,6 +159,7 @@ export function CustomChatInput({
onInput={() => {
handleInput();
updateSlashMenu();
saveDraft();
}}
onPaste={handlePaste}
onKeyDown={(e) => {

View File

@@ -142,8 +142,9 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
handleSubmit(suggestion);
};
// Allow users to submit messages during LOADING state - they will be
// queued server-side and delivered when the conversation becomes ready
const isDisabled =
curAgentState === AgentState.LOADING ||
curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
isTaskPolling(subConversationTaskStatus);

View File

@@ -40,6 +40,7 @@ import type {
V1SendMessageRequest,
} from "#/api/conversation-service/v1-conversation-service.types";
import EventService from "#/api/event-service/event-service.api";
import PendingMessageService from "#/api/pending-message-service/pending-message-service.api";
import { useConversationStore } from "#/stores/conversation-store";
import { isBudgetOrCreditError, trackError } from "#/utils/error-handler";
import { useTracking } from "#/hooks/use-tracking";
@@ -47,6 +48,7 @@ import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-
import useMetricsStore from "#/stores/metrics-store";
import { I18nKey } from "#/i18n/declaration";
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
import { setConversationState } from "#/utils/conversation-local-storage";
// eslint-disable-next-line @typescript-eslint/naming-convention
export type V1_WebSocketConnectionState =
@@ -55,9 +57,13 @@ export type V1_WebSocketConnectionState =
| "CLOSED"
| "CLOSING";
interface SendMessageResult {
queued: boolean; // true if message was queued for later delivery, false if sent immediately
}
interface ConversationWebSocketContextType {
connectionState: V1_WebSocketConnectionState;
sendMessage: (message: V1SendMessageRequest) => Promise<void>;
sendMessage: (message: V1SendMessageRequest) => Promise<SendMessageResult>;
isLoadingHistory: boolean;
}
@@ -397,6 +403,10 @@ export function ConversationWebSocketProvider({
// Clear optimistic user message when a user message is confirmed
if (isUserMessageEvent(event)) {
removeOptimisticUserMessage();
// Clear draft from localStorage - message was successfully delivered
if (conversationId) {
setConversationState(conversationId, { draftMessage: null });
}
}
// Handle cache invalidation for ActionEvent
@@ -556,6 +566,11 @@ export function ConversationWebSocketProvider({
// Clear optimistic user message when a user message is confirmed
if (isUserMessageEvent(event)) {
removeOptimisticUserMessage();
// Clear draft from localStorage - message was successfully delivered
// Use main conversationId since user types in main conversation input
if (conversationId) {
setConversationState(conversationId, { draftMessage: null });
}
}
// Handle cache invalidation for ActionEvent
@@ -810,21 +825,44 @@ export function ConversationWebSocketProvider({
);
// V1 send message function via WebSocket
// Falls back to REST API queue when WebSocket is not connected
const sendMessage = useCallback(
async (message: V1SendMessageRequest) => {
async (message: V1SendMessageRequest): Promise<SendMessageResult> => {
const currentMode = useConversationStore.getState().conversationMode;
const currentSocket =
currentMode === "plan" ? planningAgentSocket : mainSocket;
if (!currentSocket || currentSocket.readyState !== WebSocket.OPEN) {
const error = "WebSocket is not connected";
setErrorMessage(error);
throw new Error(error);
// WebSocket not connected - queue message via REST API
// Message will be delivered automatically when conversation becomes ready
if (!conversationId) {
const error = new Error("No conversation ID available");
setErrorMessage(error.message);
throw error;
}
try {
await PendingMessageService.queueMessage(conversationId, {
role: "user",
content: message.content,
});
// Message queued successfully - it will be delivered when ready
// Return queued: true so caller knows not to show optimistic UI
return { queued: true };
} catch (error) {
const errorMessage =
error instanceof Error
? error.message
: "Failed to queue message for delivery";
setErrorMessage(errorMessage);
throw error;
}
}
try {
// Send message through WebSocket as JSON
currentSocket.send(JSON.stringify(message));
return { queued: false };
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Failed to send message";
@@ -832,7 +870,7 @@ export function ConversationWebSocketProvider({
throw error;
}
},
[mainSocket, planningAgentSocket, setErrorMessage],
[mainSocket, planningAgentSocket, setErrorMessage, conversationId],
);
// Track main socket state changes

View File

@@ -5,12 +5,15 @@ import {
getTextContent,
} from "#/components/features/chat/utils/chat-input.utils";
import { useConversationStore } from "#/stores/conversation-store";
import { useConversationId } from "#/hooks/use-conversation-id";
import { useDraftPersistence } from "./use-draft-persistence";
/**
* Hook for managing chat input content logic
*/
export const useChatInputLogic = () => {
const chatInputRef = useRef<HTMLDivElement | null>(null);
const { conversationId } = useConversationId();
const {
messageToSend,
@@ -19,6 +22,12 @@ export const useChatInputLogic = () => {
setIsRightPanelShown,
} = useConversationStore();
// Draft persistence - saves to localStorage, restores on mount
const { saveDraft, clearDraft } = useDraftPersistence(
conversationId,
chatInputRef,
);
// Save current input value when drawer state changes
useEffect(() => {
if (chatInputRef.current) {
@@ -51,5 +60,7 @@ export const useChatInputLogic = () => {
checkIsContentEmpty,
clearEmptyContentHandler,
getCurrentMessage,
saveDraft,
clearDraft,
};
};

View File

@@ -0,0 +1,179 @@
import { useEffect, useRef, useCallback, useState } from "react";
import {
useConversationLocalStorageState,
getConversationState,
setConversationState,
} from "#/utils/conversation-local-storage";
import { getTextContent } from "#/components/features/chat/utils/chat-input.utils";
/**
* Check if a conversation ID is a temporary task ID.
* Task IDs have the format "task-{uuid}" and are used during V1 conversation initialization.
*/
const isTaskId = (id: string): boolean => id.slice(0, "task-".length) === "task-";
const DRAFT_SAVE_DEBOUNCE_MS = 500;
/**
 * Hook for persisting draft chat messages to localStorage.
 *
 * Responsibilities: debounced saving on input, restoration on mount,
 * transferring a draft across the task-id → real-id transition, and
 * clearing after confirmed delivery.
 *
 * NOTE: the effects here are order-sensitive; the cleanup/transition effect
 * must run before the restoration effect (see comments below).
 *
 * @param conversationId Current conversation ID (may be a temporary "task-" ID)
 * @param chatInputRef Ref to the contentEditable chat input element
 * @returns `saveDraft` / `clearDraft` callbacks plus `isRestored` / `hasDraft` flags
 */
export const useDraftPersistence = (
  conversationId: string,
  chatInputRef: React.RefObject<HTMLDivElement | null>,
) => {
  const { state, setDraftMessage } =
    useConversationLocalStorageState(conversationId);

  // Pending debounced-save timer, if any.
  const saveTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
  // Guards the restoration effect so it runs at most once per conversation.
  const hasRestoredRef = useRef(false);
  const [isRestored, setIsRestored] = useState(false);

  // Track current conversationId to prevent saving a draft to the wrong
  // conversation after a quick switch.
  const currentConversationIdRef = useRef(conversationId);

  // Track if this is the first mount to handle initial cleanup.
  const isFirstMountRef = useRef(true);

  // IMPORTANT: This effect must run FIRST when conversation changes.
  // It handles three concerns:
  // 1. Cleanup: Cancel pending saves from previous conversation
  // 2. Task-to-real transition: Preserve draft typed during initialization
  // 3. DOM reset: Clear stale content before restoration effect runs
  useEffect(() => {
    const previousConversationId = currentConversationIdRef.current;
    const isInitialMount = isFirstMountRef.current;
    currentConversationIdRef.current = conversationId;
    isFirstMountRef.current = false;

    // --- 1. Cancel pending saves from previous conversation ---
    // Prevents draft from being saved to wrong conversation if user switched quickly
    if (saveTimeoutRef.current) {
      clearTimeout(saveTimeoutRef.current);
      saveTimeoutRef.current = null;
    }

    const element = chatInputRef.current;

    // --- 2. Handle task-to-real ID transition (preserve draft during initialization) ---
    // When a new V1 conversation initializes, it starts with a temporary "task-xxx" ID
    // that transitions to a real conversation ID once ready. Task IDs don't persist
    // to localStorage, so any draft typed during this phase would be lost.
    // We detect this transition and transfer the draft to the new real ID.
    if (!isInitialMount && previousConversationId !== conversationId) {
      const wasTaskId = isTaskId(previousConversationId);
      const isNowRealId = !isTaskId(conversationId);
      if (wasTaskId && isNowRealId && element) {
        const currentText = getTextContent(element).trim();
        if (currentText) {
          // Transfer draft to the new (real) conversation ID
          setConversationState(conversationId, { draftMessage: currentText });
          // Keep draft visible in DOM and mark as restored to prevent overwrite
          hasRestoredRef.current = true;
          setIsRestored(true);
          return; // Skip normal cleanup - draft is already in correct state
        }
      }
    }

    // --- 3. Clear stale DOM content (will be restored by next effect if draft exists) ---
    // This prevents stale drafts from appearing in new conversations due to:
    // - Browser form restoration on back/forward navigation
    // - React DOM recycling between conversation switches
    // The restoration effect will then populate with the correct saved draft
    if (element) {
      element.textContent = "";
    }

    // Reset restoration flag so the restoration effect will run for new conversation
    hasRestoredRef.current = false;
    setIsRestored(false);
  }, [conversationId, chatInputRef]);

  // Restore draft from localStorage - reads directly to avoid state sync timing issues
  useEffect(() => {
    if (hasRestoredRef.current) {
      return;
    }

    const element = chatInputRef.current;
    if (!element) {
      return;
    }

    // Read directly from localStorage to avoid stale state from useConversationLocalStorageState
    // The hook's state may not have synced yet after conversationId change
    const { draftMessage } = getConversationState(conversationId);

    // Only restore if there's a saved draft and the input is empty
    if (draftMessage && getTextContent(element).trim() === "") {
      element.textContent = draftMessage;

      // Move cursor to end so the user can keep typing where they left off
      const selection = window.getSelection();
      const range = document.createRange();
      range.selectNodeContents(element);
      range.collapse(false);
      selection?.removeAllRanges();
      selection?.addRange(range);
    }

    hasRestoredRef.current = true;
    setIsRestored(true);
  }, [chatInputRef, conversationId]);

  // Debounced save function - called from the input's onInput handler.
  const saveDraft = useCallback(() => {
    // Clear any pending save so only the latest input wins
    if (saveTimeoutRef.current) {
      clearTimeout(saveTimeoutRef.current);
    }

    // Capture the conversationId at the time of input
    const capturedConversationId = conversationId;

    saveTimeoutRef.current = setTimeout(() => {
      // Verify we're still on the same conversation before saving
      // This prevents saving draft to wrong conversation if user switched quickly
      if (capturedConversationId !== currentConversationIdRef.current) {
        return;
      }

      const element = chatInputRef.current;
      if (!element) {
        return;
      }

      const text = getTextContent(element).trim();
      // Only save if content has changed (empty text clears the stored draft)
      if (text !== (state.draftMessage || "")) {
        setDraftMessage(text || null);
      }
    }, DRAFT_SAVE_DEBOUNCE_MS);
  }, [chatInputRef, state.draftMessage, setDraftMessage, conversationId]);

  // Clear draft - called after message delivery is confirmed.
  const clearDraft = useCallback(() => {
    // Cancel any pending save so it cannot resurrect the cleared draft
    if (saveTimeoutRef.current) {
      clearTimeout(saveTimeoutRef.current);
      saveTimeoutRef.current = null;
    }
    setDraftMessage(null);
  }, [setDraftMessage]);

  // Cleanup the debounce timeout on unmount.
  useEffect(
    () => () => {
      if (saveTimeoutRef.current) {
        clearTimeout(saveTimeoutRef.current);
      }
    },
    [],
  );

  return {
    saveDraft,
    clearDraft,
    isRestored,
    hasDraft: !!state.draftMessage,
  };
};

View File

@@ -5,6 +5,10 @@ import { useConversationWebSocket } from "#/contexts/conversation-websocket-cont
import { useConversationId } from "#/hooks/use-conversation-id";
import { V1MessageContent } from "#/api/conversation-service/v1-conversation-service.types";
interface SendResult {
queued: boolean; // true if message was queued for later delivery
}
/**
* Unified hook for sending messages that works with both V0 and V1 conversations
* - For V0 conversations: Uses Socket.IO WebSocket via useWsClient
@@ -26,7 +30,7 @@ export function useSendMessage() {
conversation?.conversation_version === "V1";
const send = useCallback(
async (event: Record<string, unknown>) => {
async (event: Record<string, unknown>): Promise<SendResult> => {
if (isV1Conversation && v1Context) {
// V1: Convert V0 event format to V1 message format
const { action, args } = event as {
@@ -57,19 +61,20 @@ export function useSendMessage() {
}
// Send via V1 WebSocket context (uses correct host/port)
await v1Context.sendMessage({
const result = await v1Context.sendMessage({
role: "user",
content,
});
} else {
// For non-message events, fall back to V0 send
// (e.g., agent state changes, other control events)
v0Send(event);
return result;
}
} else {
// V0: Use Socket.IO
// For non-message events, fall back to V0 send
// (e.g., agent state changes, other control events)
v0Send(event);
return { queued: false };
}
// V0: Use Socket.IO
v0Send(event);
return { queued: false };
},
[isV1Conversation, v1Context, v0Send, conversationId],
);

View File

@@ -23,6 +23,7 @@ export interface ConversationState {
unpinnedTabs: string[];
conversationMode: ConversationMode;
subConversationTaskId: string | null;
draftMessage: string | null;
}
const DEFAULT_CONVERSATION_STATE: ConversationState = {
@@ -31,6 +32,7 @@ const DEFAULT_CONVERSATION_STATE: ConversationState = {
unpinnedTabs: [],
conversationMode: "code",
subConversationTaskId: null,
draftMessage: null,
};
/**
@@ -121,6 +123,7 @@ export function useConversationLocalStorageState(conversationId: string): {
setRightPanelShown: (shown: boolean) => void;
setUnpinnedTabs: (tabs: string[]) => void;
setConversationMode: (mode: ConversationMode) => void;
setDraftMessage: (message: string | null) => void;
} {
const [state, setState] = useState<ConversationState>(() =>
getConversationState(conversationId),
@@ -178,5 +181,6 @@ export function useConversationLocalStorageState(conversationId: string): {
setRightPanelShown: (shown) => updateState({ rightPanelShown: shown }),
setUnpinnedTabs: (tabs) => updateState({ unpinnedTabs: tabs }),
setConversationMode: (mode) => updateState({ conversationMode: mode }),
setDraftMessage: (message) => updateState({ draftMessage: message }),
};
}

View File

@@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import (
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
)
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
from openhands.app_server.sandbox.sandbox_models import (
AGENT_SERVER,
@@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
sandbox_service: SandboxService
sandbox_spec_service: SandboxSpecService
jwt_service: JwtService
pending_message_service: PendingMessageService
sandbox_startup_timeout: int
sandbox_startup_poll_frequency: int
max_num_conversations_per_sandbox: int
@@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
task.app_conversation_id = info.id
yield task
# Process any pending messages queued while waiting for conversation
if sandbox.session_api_key:
await self._process_pending_messages(
task_id=task.id,
conversation_id=info.id,
agent_server_url=agent_server_url,
session_api_key=sandbox.session_api_key,
)
except Exception as exc:
_logger.exception('Error starting conversation', stack_info=True)
task.status = AppConversationStartTaskStatus.ERROR
@@ -1424,6 +1437,89 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
plugins=plugins,
)
async def _process_pending_messages(
self,
task_id: UUID,
conversation_id: UUID,
agent_server_url: str,
session_api_key: str,
) -> None:
"""Process pending messages queued before conversation was ready.
Messages are delivered concurrently to the agent server. After processing,
all messages are deleted from the database regardless of success or failure.
Args:
task_id: The start task ID (may have been used as conversation_id initially)
conversation_id: The real conversation ID
agent_server_url: URL of the agent server
session_api_key: API key for authenticating with agent server
"""
# Convert UUIDs to strings for the pending message service
# The frontend uses task-{uuid.hex} format (no hyphens), matching OpenHandsUUID serialization
task_id_str = f'task-{task_id.hex}'
# conversation_id uses standard format (with hyphens) for agent server API compatibility
conversation_id_str = str(conversation_id)
_logger.info(f'task_id={task_id_str} conversation_id={conversation_id_str}')
# First, update any messages that were queued with the task_id
updated_count = await self.pending_message_service.update_conversation_id(
old_conversation_id=task_id_str,
new_conversation_id=conversation_id_str,
)
_logger.info(f'updated_count={updated_count} ')
if updated_count > 0:
_logger.info(
f'Updated {updated_count} pending messages from task_id={task_id_str} '
f'to conversation_id={conversation_id_str}'
)
# Get all pending messages for this conversation
pending_messages = await self.pending_message_service.get_pending_messages(
conversation_id_str
)
if not pending_messages:
return
_logger.info(
f'Processing {len(pending_messages)} pending messages for '
f'conversation {conversation_id_str}'
)
# Process messages sequentially to preserve order
for msg in pending_messages:
try:
# Serialize content objects to JSON-compatible dicts
content_json = [item.model_dump() for item in msg.content]
# Use the events endpoint which handles message sending
response = await self.httpx_client.post(
f'{agent_server_url}/api/conversations/{conversation_id_str}/events',
json={
'role': msg.role,
'content': content_json,
'run': True,
},
headers={'X-Session-API-Key': session_api_key},
timeout=30.0,
)
response.raise_for_status()
_logger.debug(f'Delivered pending message {msg.id}')
except Exception as e:
_logger.warning(f'Failed to deliver pending message {msg.id}: {e}')
# Delete all pending messages after processing (regardless of success/failure)
deleted_count = (
await self.pending_message_service.delete_messages_for_conversation(
conversation_id_str
)
)
_logger.info(
f'Finished processing pending messages for conversation {conversation_id_str}. '
f'Deleted {deleted_count} messages.'
)
async def update_agent_server_conversation_title(
self,
conversation_id: str,
@@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
get_global_config,
get_httpx_client,
get_jwt_service,
get_pending_message_service,
get_sandbox_service,
get_sandbox_spec_service,
get_user_context,
@@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
get_event_service(state, request) as event_service,
get_jwt_service(state, request) as jwt_service,
get_httpx_client(state, request) as httpx_client,
get_pending_message_service(state, request) as pending_message_service,
):
access_token_hard_timeout = None
if self.access_token_hard_timeout:
@@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
event_callback_service=event_callback_service,
event_service=event_service,
jwt_service=jwt_service,
pending_message_service=pending_message_service,
sandbox_startup_timeout=self.sandbox_startup_timeout,
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox,

View File

@@ -0,0 +1,39 @@
"""Add pending_messages table for server-side message queuing
Revision ID: 007
Revises: 006
Create Date: 2025-03-15 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '007'
down_revision: Union[str, None] = '006'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the pending_messages table.

    Rows hold messages queued before a conversation is ready; they are
    delivered and then deleted regardless of success or failure.
    """
    columns = [
        sa.Column('id', sa.String(), primary_key=True),
        sa.Column('conversation_id', sa.String(), nullable=False, index=True),
        sa.Column('role', sa.String(20), nullable=False, server_default='user'),
        sa.Column('content', sa.JSON, nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    ]
    op.create_table('pending_messages', *columns)
def downgrade() -> None:
    """Remove pending_messages table.

    Any queued-but-undelivered messages are dropped along with the table.
    """
    op.drop_table('pending_messages')

View File

@@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import (
EventCallbackService,
EventCallbackServiceInjector,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
PendingMessageServiceInjector,
)
from openhands.app_server.sandbox.sandbox_service import (
SandboxService,
SandboxServiceInjector,
@@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel):
app_conversation_info: AppConversationInfoServiceInjector | None = None
app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None
app_conversation: AppConversationServiceInjector | None = None
pending_message: PendingMessageServiceInjector | None = None
user: UserContextInjector | None = None
jwt: JwtServiceInjector | None = None
httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector)
@@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig:
tavily_api_key=tavily_api_key
)
if config.pending_message is None:
from openhands.app_server.pending_messages.pending_message_service import (
SQLPendingMessageServiceInjector,
)
config.pending_message = SQLPendingMessageServiceInjector()
if config.user is None:
config.user = AuthUserContextInjector()
@@ -358,6 +370,14 @@ def get_app_conversation_service(
return injector.context(state, request)
def get_pending_message_service(
    state: InjectorState, request: Request | None = None
) -> AsyncContextManager[PendingMessageService]:
    """Resolve the configured PendingMessageService as an async context manager.

    Mirrors the other get_* service accessors in this module.
    """
    injector = get_global_config().pending_message
    # config_from_env() installs SQLPendingMessageServiceInjector by default,
    # so this assert should only fire on a misconfigured custom config.
    assert injector is not None
    return injector.context(state, request)
def get_user_context(
state: InjectorState, request: Request | None = None
) -> AsyncContextManager[UserContext]:
@@ -433,6 +453,12 @@ def depends_app_conversation_service():
return Depends(injector.depends)
def depends_pending_message_service():
    """Return a FastAPI Depends() for PendingMessageService (router injection)."""
    injector = get_global_config().pending_message
    # Same convention as the sibling depends_* helpers in this module.
    assert injector is not None
    return Depends(injector.depends)
def depends_user_context():
injector = get_global_config().user
assert injector is not None

View File

@@ -0,0 +1,21 @@
"""Pending messages module for server-side message queuing."""
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessage,
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
PendingMessageServiceInjector,
SQLPendingMessageService,
SQLPendingMessageServiceInjector,
)
__all__ = [
'PendingMessage',
'PendingMessageResponse',
'PendingMessageService',
'PendingMessageServiceInjector',
'SQLPendingMessageService',
'SQLPendingMessageServiceInjector',
]

View File

@@ -0,0 +1,32 @@
"""Models for pending message queue functionality."""
from datetime import datetime
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.models import ImageContent, TextContent
from openhands.agent_server.utils import utc_now
class PendingMessage(BaseModel):
    """A message queued for delivery when conversation becomes ready.

    Pending messages are stored in the database and delivered to the agent_server
    when the conversation transitions to READY status. Messages are deleted after
    processing, regardless of success or failure.
    """

    # Unique message ID (UUID string), generated when the message is queued
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Can be task-{uuid} (before the conversation exists) or the real conversation UUID
    conversation_id: str
    # Author role; defaults to 'user' (the queue endpoint's server default)
    role: str = 'user'
    # Ordered content blocks (text and/or images)
    content: list[TextContent | ImageContent]
    # Queued-at timestamp, produced by utc_now
    created_at: datetime = Field(default_factory=utc_now)
class PendingMessageResponse(BaseModel):
    """Response when queueing a pending message."""

    # ID of the stored pending message
    id: str
    # True when the message was accepted into the queue
    queued: bool
    position: int = Field(description='Position in the queue (1-based)')

View File

@@ -0,0 +1,104 @@
"""REST API router for pending messages."""
import logging
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import TypeAdapter, ValidationError
from openhands.agent_server.models import ImageContent, TextContent
from openhands.app_server.config import depends_pending_message_service
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
)
from openhands.server.dependencies import get_dependencies
logger = logging.getLogger(__name__)
# Type adapter for validating content from request
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
# Create router with authentication dependencies
router = APIRouter(
prefix='/conversations/{conversation_id}/pending-messages',
tags=['Pending Messages'],
dependencies=get_dependencies(),
)
# Create dependency at module level
pending_message_service_dependency = depends_pending_message_service()
# Maximum number of pending messages allowed per conversation.
_MAX_PENDING_MESSAGES = 10


@router.post(
    '', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED
)
async def queue_pending_message(
    conversation_id: str,
    request: Request,
    pending_service: PendingMessageService = pending_message_service_dependency,
) -> PendingMessageResponse:
    """Queue a message for delivery when conversation becomes ready.

    This endpoint allows users to submit messages even when the conversation's
    WebSocket connection is not yet established. Messages are stored server-side
    and delivered automatically when the conversation transitions to READY status.

    Args:
        conversation_id: The conversation ID (can be task ID before conversation is ready)
        request: The FastAPI request containing message content
        pending_service: Injected pending-message storage service

    Returns:
        PendingMessageResponse with the message ID and queue position

    Raises:
        HTTPException 400: If the request body is invalid
        HTTPException 429: If too many pending messages are queued (limit: 10)
    """
    try:
        body = await request.json()
    except Exception:
        # Malformed JSON; the parser's internal error is not useful to clients.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        ) from None

    # Guard against valid-but-non-object JSON (e.g. a bare list or string),
    # which would otherwise crash on .get() below and surface as a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        )

    raw_content = body.get('content')
    role = body.get('role', 'user')

    if not raw_content or not isinstance(raw_content, list):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='content must be a non-empty list',
        )

    # Validate and parse content into typed objects
    try:
        content = _content_type_adapter.validate_python(raw_content)
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'Invalid content format: {e}',
        ) from e

    # Rate limit: cap pending messages per conversation
    pending_count = await pending_service.count_pending_messages(conversation_id)
    if pending_count >= _MAX_PENDING_MESSAGES:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail='Too many pending messages. Maximum 10 messages per conversation.',
        )

    response = await pending_service.add_message(
        conversation_id=conversation_id,
        content=content,
        role=role,
    )
    logger.info(
        f'Queued pending message {response.id} for conversation {conversation_id} '
        f'(position: {response.position})'
    )
    return response

View File

@@ -0,0 +1,200 @@
"""Service for managing pending messages in SQL database."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator
from fastapi import Request
from pydantic import TypeAdapter
from sqlalchemy import JSON, Column, String, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.agent_server.models import ImageContent, TextContent
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessage,
PendingMessageResponse,
)
from openhands.app_server.services.injector import Injector, InjectorState
from openhands.app_server.utils.sql_utils import Base, UtcDateTime
from openhands.sdk.utils.models import DiscriminatedUnionMixin
# Type adapter for deserializing content from JSON; module-level so it is
# constructed once and shared by all service instances.
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
class StoredPendingMessage(Base):  # type: ignore
    """SQLAlchemy model for pending messages.

    Mirrors the ``pending_messages`` table created by the Alembic migration;
    rows are transient and deleted once the conversation becomes ready.
    """

    __tablename__ = 'pending_messages'

    id = Column(String, primary_key=True)
    conversation_id = Column(String, nullable=False, index=True)
    role = Column(String(20), nullable=False, default='user')
    content = Column(JSON, nullable=False)
    # nullable=False matches the migration, which creates this column NOT NULL;
    # the model previously defaulted to nullable=True, diverging from the schema.
    created_at = Column(
        UtcDateTime, nullable=False, server_default=func.now(), index=True
    )
class PendingMessageService(ABC):
    """Abstract service for managing pending messages.

    Implementations persist messages submitted before a conversation is ready
    and hand them back (ordered by creation time) once it is. All methods are
    async and scoped to a single conversation_id.
    """

    @abstractmethod
    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message for delivery when conversation becomes ready."""

    @abstractmethod
    async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
        """Get all pending messages for a conversation, ordered by created_at."""

    @abstractmethod
    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages for a conversation."""

    @abstractmethod
    async def delete_messages_for_conversation(self, conversation_id: str) -> int:
        """Delete all pending messages for a conversation, returning count deleted."""

    @abstractmethod
    async def update_conversation_id(
        self, old_conversation_id: str, new_conversation_id: str
    ) -> int:
        """Update conversation_id when task-id transitions to real conversation-id.

        Returns the number of messages updated.
        """
@dataclass
class SQLPendingMessageService(PendingMessageService):
    """SQL implementation of PendingMessageService.

    Backed by the ``pending_messages`` table via an injected async session.
    Every mutating method commits its own transaction.
    """

    # Async session injected per request (see SQLPendingMessageServiceInjector).
    db_session: AsyncSession

    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message for delivery when conversation becomes ready."""
        # Create the pending message (generates id and created_at defaults)
        pending_message = PendingMessage(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )
        # Count existing pending messages for position
        # NOTE(review): count-then-insert is not atomic, so two concurrent adds
        # can report the same position; appears informational only — confirm.
        count_stmt = select(func.count()).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(count_stmt)
        position = result.scalar() or 0
        # Serialize content to JSON-compatible format for storage
        content_json = [item.model_dump() for item in content]
        # Store in database
        stored_message = StoredPendingMessage(
            id=str(pending_message.id),
            conversation_id=conversation_id,
            role=role,
            content=content_json,
            created_at=pending_message.created_at,
        )
        self.db_session.add(stored_message)
        await self.db_session.commit()
        return PendingMessageResponse(
            id=pending_message.id,
            queued=True,
            position=position + 1,  # 1-based queue position
        )

    async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
        """Get all pending messages for a conversation, ordered by created_at."""
        stmt = (
            select(StoredPendingMessage)
            .where(StoredPendingMessage.conversation_id == conversation_id)
            .order_by(StoredPendingMessage.created_at.asc())
        )
        result = await self.db_session.execute(stmt)
        stored_messages = result.scalars().all()
        # Rehydrate stored JSON into typed TextContent/ImageContent objects.
        return [
            PendingMessage(
                id=msg.id,
                conversation_id=msg.conversation_id,
                role=msg.role,
                content=_content_type_adapter.validate_python(msg.content),
                created_at=msg.created_at,
            )
            for msg in stored_messages
        ]

    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages for a conversation."""
        count_stmt = select(func.count()).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(count_stmt)
        return result.scalar() or 0

    async def delete_messages_for_conversation(self, conversation_id: str) -> int:
        """Delete all pending messages for a conversation, returning count deleted."""
        # NOTE(review): loads the rows and deletes them one-by-one; a bulk
        # sqlalchemy delete() would avoid materializing rows — acceptable here
        # given the 10-message-per-conversation cap enforced by the router.
        stmt = select(StoredPendingMessage).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(stmt)
        stored_messages = result.scalars().all()
        count = len(stored_messages)
        for msg in stored_messages:
            await self.db_session.delete(msg)
        if count > 0:
            # Only commit when something actually changed.
            await self.db_session.commit()
        return count

    async def update_conversation_id(
        self, old_conversation_id: str, new_conversation_id: str
    ) -> int:
        """Update conversation_id when task-id transitions to real conversation-id."""
        stmt = select(StoredPendingMessage).where(
            StoredPendingMessage.conversation_id == old_conversation_id
        )
        result = await self.db_session.execute(stmt)
        stored_messages = result.scalars().all()
        count = len(stored_messages)
        # Reassign in place; the ORM emits the UPDATEs on commit.
        for msg in stored_messages:
            msg.conversation_id = new_conversation_id
        if count > 0:
            await self.db_session.commit()
        return count
class PendingMessageServiceInjector(
    DiscriminatedUnionMixin, Injector[PendingMessageService], ABC
):
    """Abstract injector for PendingMessageService.

    Concrete subclasses are selected via the discriminated-union machinery
    (DiscriminatedUnionMixin) and yield a service instance per request.
    """

    pass
class SQLPendingMessageServiceInjector(PendingMessageServiceInjector):
    """SQL-based injector for PendingMessageService."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[PendingMessageService, None]:
        # Imported locally rather than at module top — presumably to avoid a
        # circular import with the config module; confirm before hoisting.
        from openhands.app_server.config import get_db_session

        # The yielded service is only valid while the session context is open.
        async with get_db_session(state) as db_session:
            yield SQLPendingMessageService(db_session=db_session)

View File

@@ -5,6 +5,9 @@ from openhands.app_server.event import event_router
from openhands.app_server.event_callback import (
webhook_router,
)
from openhands.app_server.pending_messages.pending_message_router import (
router as pending_message_router,
)
from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router
from openhands.app_server.user import user_router
from openhands.app_server.web_client import web_client_router
@@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router
router = APIRouter(prefix='/api/v1')
router.include_router(event_router.router)
router.include_router(app_conversation_router.router)
router.include_router(pending_message_router)
router.include_router(sandbox_router.router)
router.include_router(sandbox_spec_router.router)
router.include_router(user_router.router)

View File

@@ -80,6 +80,7 @@ class TestLiveStatusAppConversationService:
self.mock_event_callback_service = Mock()
self.mock_event_service = Mock()
self.mock_httpx_client = Mock()
self.mock_pending_message_service = Mock()
# Create service instance
self.service = LiveStatusAppConversationService(
@@ -92,6 +93,7 @@ class TestLiveStatusAppConversationService:
sandbox_service=self.mock_sandbox_service,
sandbox_spec_service=self.mock_sandbox_spec_service,
jwt_service=self.mock_jwt_service,
pending_message_service=self.mock_pending_message_service,
sandbox_startup_timeout=30,
sandbox_startup_poll_frequency=1,
max_num_conversations_per_sandbox=20,
@@ -2329,6 +2331,7 @@ class TestPluginHandling:
self.mock_event_callback_service = Mock()
self.mock_event_service = Mock()
self.mock_httpx_client = Mock()
self.mock_pending_message_service = Mock()
# Create service instance
self.service = LiveStatusAppConversationService(
@@ -2341,6 +2344,7 @@ class TestPluginHandling:
sandbox_service=self.mock_sandbox_service,
sandbox_spec_service=self.mock_sandbox_spec_service,
jwt_service=self.mock_jwt_service,
pending_message_service=self.mock_pending_message_service,
sandbox_startup_timeout=30,
sandbox_startup_poll_frequency=1,
max_num_conversations_per_sandbox=20,

View File

@@ -0,0 +1,227 @@
"""Unit tests for the pending_message_router endpoints.
This module tests the queue_pending_message endpoint,
focusing on request validation and rate limiting.
"""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from fastapi import HTTPException, status
from openhands.agent_server.models import TextContent
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_router import (
queue_pending_message,
)
def _make_mock_service(
add_message_return=None,
count_pending_messages_return=0,
):
"""Create a mock PendingMessageService for testing."""
service = MagicMock()
service.add_message = AsyncMock(return_value=add_message_return)
service.count_pending_messages = AsyncMock(
return_value=count_pending_messages_return
)
return service
def _make_mock_request(body: dict):
"""Create a mock FastAPI Request with given JSON body."""
request = MagicMock()
request.json = AsyncMock(return_value=body)
return request
@pytest.mark.asyncio
class TestQueuePendingMessage:
    """Test suite for queue_pending_message endpoint.

    The endpoint function is invoked directly (no TestClient); the request
    and the PendingMessageService are replaced by mocks.
    """

    async def test_queues_message_successfully(self):
        """Test that a valid message is queued successfully."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        raw_content = [{'type': 'text', 'text': 'Hello, world!'}]
        expected_response = PendingMessageResponse(
            id=str(uuid4()),
            queued=True,
            position=1,
        )
        mock_service = _make_mock_service(
            add_message_return=expected_response,
            count_pending_messages_return=0,
        )
        mock_request = _make_mock_request({'content': raw_content, 'role': 'user'})

        # Act
        result = await queue_pending_message(
            conversation_id=conversation_id,
            request=mock_request,
            pending_service=mock_service,
        )

        # Assert
        assert result == expected_response
        mock_service.add_message.assert_called_once()
        call_kwargs = mock_service.add_message.call_args.kwargs
        assert call_kwargs['conversation_id'] == conversation_id
        assert call_kwargs['role'] == 'user'
        # Content should be parsed into typed objects
        assert len(call_kwargs['content']) == 1
        assert isinstance(call_kwargs['content'][0], TextContent)
        assert call_kwargs['content'][0].text == 'Hello, world!'

    async def test_uses_default_role_when_not_provided(self):
        """Test that 'user' role is used by default."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        raw_content = [{'type': 'text', 'text': 'Test message'}]
        expected_response = PendingMessageResponse(
            id=str(uuid4()),
            queued=True,
            position=1,
        )
        mock_service = _make_mock_service(
            add_message_return=expected_response,
            count_pending_messages_return=0,
        )
        # No 'role' key in the body — the endpoint should default to 'user'.
        mock_request = _make_mock_request({'content': raw_content})

        # Act
        await queue_pending_message(
            conversation_id=conversation_id,
            request=mock_request,
            pending_service=mock_service,
        )

        # Assert
        mock_service.add_message.assert_called_once()
        call_kwargs = mock_service.add_message.call_args.kwargs
        assert call_kwargs['conversation_id'] == conversation_id
        assert call_kwargs['role'] == 'user'
        assert isinstance(call_kwargs['content'][0], TextContent)

    async def test_returns_400_for_invalid_json_body(self):
        """Test that invalid JSON body returns 400 Bad Request."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        mock_service = _make_mock_service()
        # request.json() raising simulates an unparseable request body.
        mock_request = MagicMock()
        mock_request.json = AsyncMock(side_effect=Exception('Invalid JSON'))

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            await queue_pending_message(
                conversation_id=conversation_id,
                request=mock_request,
                pending_service=mock_service,
            )
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
        assert 'Invalid request body' in exc_info.value.detail

    async def test_returns_400_when_content_is_missing(self):
        """Test that missing content returns 400 Bad Request."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        mock_service = _make_mock_service()
        mock_request = _make_mock_request({'role': 'user'})

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            await queue_pending_message(
                conversation_id=conversation_id,
                request=mock_request,
                pending_service=mock_service,
            )
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
        assert 'content must be a non-empty list' in exc_info.value.detail

    async def test_returns_400_when_content_is_not_a_list(self):
        """Test that non-list content returns 400 Bad Request."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        mock_service = _make_mock_service()
        mock_request = _make_mock_request({'content': 'not a list'})

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            await queue_pending_message(
                conversation_id=conversation_id,
                request=mock_request,
                pending_service=mock_service,
            )
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
        assert 'content must be a non-empty list' in exc_info.value.detail

    async def test_returns_400_when_content_is_empty_list(self):
        """Test that empty list content returns 400 Bad Request."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        mock_service = _make_mock_service()
        mock_request = _make_mock_request({'content': []})

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            await queue_pending_message(
                conversation_id=conversation_id,
                request=mock_request,
                pending_service=mock_service,
            )
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
        assert 'content must be a non-empty list' in exc_info.value.detail

    async def test_returns_429_when_rate_limit_exceeded(self):
        """Test that exceeding rate limit returns 429 Too Many Requests."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        raw_content = [{'type': 'text', 'text': 'Test message'}]
        # 10 messages already queued — exactly at the limit.
        mock_service = _make_mock_service(count_pending_messages_return=10)
        mock_request = _make_mock_request({'content': raw_content})

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            await queue_pending_message(
                conversation_id=conversation_id,
                request=mock_request,
                pending_service=mock_service,
            )
        assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        assert 'Maximum 10 messages' in exc_info.value.detail

    async def test_allows_up_to_10_messages(self):
        """Test that 9 existing messages still allows adding one more."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        raw_content = [{'type': 'text', 'text': 'Test message'}]
        expected_response = PendingMessageResponse(
            id=str(uuid4()),
            queued=True,
            position=10,
        )
        mock_service = _make_mock_service(
            add_message_return=expected_response,
            count_pending_messages_return=9,
        )
        mock_request = _make_mock_request({'content': raw_content})

        # Act
        result = await queue_pending_message(
            conversation_id=conversation_id,
            request=mock_request,
            pending_service=mock_service,
        )

        # Assert
        assert result == expected_response
        mock_service.add_message.assert_called_once()

View File

@@ -0,0 +1,309 @@
"""Tests for SQLPendingMessageService.
This module tests the SQL implementation of PendingMessageService,
covering message queuing, retrieval, counting, deletion, and
conversation_id updates using SQLite as a mock database.
"""
from typing import AsyncGenerator
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.agent_server.models import TextContent
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
SQLPendingMessageService,
)
from openhands.app_server.utils.sql_utils import Base
@pytest.fixture
async def async_engine():
    """Create an async SQLite engine for testing.

    Yields an engine bound to a fresh in-memory database with all ORM
    tables created; the engine is disposed when the test finishes.
    """
    engine = create_async_engine(
        'sqlite+aiosqlite:///:memory:',
        # StaticPool + check_same_thread=False allow the single in-memory
        # database to be shared across async connections within a test.
        poolclass=StaticPool,
        connect_args={'check_same_thread': False},
        echo=False,
    )
    # Create all mapped tables (including pending_messages) up front.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Create an async session for testing."""
    # expire_on_commit=False keeps ORM objects readable after commit.
    async_session_maker = async_sessionmaker(
        async_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with async_session_maker() as db_session:
        yield db_session
@pytest.fixture
def service(async_session) -> SQLPendingMessageService:
    """Create a SQLPendingMessageService instance for testing."""
    # Each test gets a service bound to its own in-memory database session.
    return SQLPendingMessageService(db_session=async_session)
@pytest.fixture
def sample_content() -> list[TextContent]:
    """Create sample message content for testing."""
    # Single text part; tests needing distinct texts build their own content.
    return [TextContent(text='Hello, this is a test message')]
class TestSQLPendingMessageService:
    """Test suite for SQLPendingMessageService.

    Exercises the real SQL implementation against an in-memory SQLite
    database (see the fixtures above).
    """

    @pytest.mark.asyncio
    async def test_add_message_creates_message_with_correct_data(
        self,
        service: SQLPendingMessageService,
        sample_content: list[TextContent],
    ):
        """Test that add_message creates a message with the expected fields."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'

        # Act
        response = await service.add_message(
            conversation_id=conversation_id,
            content=sample_content,
            role='user',
        )

        # Assert
        assert isinstance(response, PendingMessageResponse)
        assert response.queued is True
        assert response.id is not None
        # Verify the message was stored correctly
        messages = await service.get_pending_messages(conversation_id)
        assert len(messages) == 1
        assert messages[0].conversation_id == conversation_id
        assert len(messages[0].content) == 1
        assert isinstance(messages[0].content[0], TextContent)
        assert messages[0].content[0].text == sample_content[0].text
        assert messages[0].role == 'user'
        assert messages[0].created_at is not None

    @pytest.mark.asyncio
    async def test_add_message_returns_correct_queue_position(
        self,
        service: SQLPendingMessageService,
        sample_content: list[TextContent],
    ):
        """Test that queue position increments correctly for each message."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'

        # Act - Add three messages
        response1 = await service.add_message(conversation_id, sample_content)
        response2 = await service.add_message(conversation_id, sample_content)
        response3 = await service.add_message(conversation_id, sample_content)

        # Assert - positions are 1-based and sequential
        assert response1.position == 1
        assert response2.position == 2
        assert response3.position == 3

    @pytest.mark.asyncio
    async def test_get_pending_messages_returns_messages_ordered_by_created_at(
        self,
        service: SQLPendingMessageService,
    ):
        """Test that messages are returned in chronological order."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        contents = [
            [TextContent(text='First message')],
            [TextContent(text='Second message')],
            [TextContent(text='Third message')],
        ]
        for content in contents:
            await service.add_message(conversation_id, content)

        # Act
        messages = await service.get_pending_messages(conversation_id)

        # Assert - insertion order is preserved (ordered by created_at)
        assert len(messages) == 3
        assert messages[0].content[0].text == 'First message'
        assert messages[1].content[0].text == 'Second message'
        assert messages[2].content[0].text == 'Third message'

    @pytest.mark.asyncio
    async def test_get_pending_messages_returns_empty_list_when_none_exist(
        self,
        service: SQLPendingMessageService,
    ):
        """Test that an empty list is returned for a conversation with no messages."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'

        # Act
        messages = await service.get_pending_messages(conversation_id)

        # Assert
        assert messages == []

    @pytest.mark.asyncio
    async def test_count_pending_messages_returns_correct_count(
        self,
        service: SQLPendingMessageService,
        sample_content: list[TextContent],
    ):
        """Test that count_pending_messages returns the correct number."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        other_conversation_id = f'task-{uuid4().hex}'
        # Add 3 messages to first conversation
        for _ in range(3):
            await service.add_message(conversation_id, sample_content)
        # Add 2 messages to second conversation
        for _ in range(2):
            await service.add_message(other_conversation_id, sample_content)

        # Act
        count1 = await service.count_pending_messages(conversation_id)
        count2 = await service.count_pending_messages(other_conversation_id)
        count_empty = await service.count_pending_messages('nonexistent')

        # Assert - counts are per-conversation
        assert count1 == 3
        assert count2 == 2
        assert count_empty == 0

    @pytest.mark.asyncio
    async def test_delete_messages_for_conversation_removes_all_messages(
        self,
        service: SQLPendingMessageService,
        sample_content: list[TextContent],
    ):
        """Test that delete_messages_for_conversation removes all messages and returns count."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'
        other_conversation_id = f'task-{uuid4().hex}'
        # Add messages to both conversations
        for _ in range(3):
            await service.add_message(conversation_id, sample_content)
        await service.add_message(other_conversation_id, sample_content)

        # Act
        deleted_count = await service.delete_messages_for_conversation(conversation_id)

        # Assert
        assert deleted_count == 3
        assert await service.count_pending_messages(conversation_id) == 0
        # Other conversation should be unaffected
        assert await service.count_pending_messages(other_conversation_id) == 1

    @pytest.mark.asyncio
    async def test_delete_messages_for_conversation_returns_zero_when_none_exist(
        self,
        service: SQLPendingMessageService,
    ):
        """Test that deleting from nonexistent conversation returns 0."""
        # Arrange
        conversation_id = f'task-{uuid4().hex}'

        # Act
        deleted_count = await service.delete_messages_for_conversation(conversation_id)

        # Assert
        assert deleted_count == 0

    @pytest.mark.asyncio
    async def test_update_conversation_id_updates_all_matching_messages(
        self,
        service: SQLPendingMessageService,
        sample_content: list[TextContent],
    ):
        """Test that update_conversation_id updates all messages with the old ID."""
        # Arrange - simulates the task-id -> real-conversation-id transition
        old_conversation_id = f'task-{uuid4().hex}'
        new_conversation_id = str(uuid4())
        unrelated_conversation_id = f'task-{uuid4().hex}'
        # Add messages to old conversation
        for _ in range(3):
            await service.add_message(old_conversation_id, sample_content)
        # Add message to unrelated conversation
        await service.add_message(unrelated_conversation_id, sample_content)

        # Act
        updated_count = await service.update_conversation_id(
            old_conversation_id, new_conversation_id
        )

        # Assert
        assert updated_count == 3
        # Verify old conversation has no messages
        assert await service.count_pending_messages(old_conversation_id) == 0
        # Verify new conversation has all messages
        messages = await service.get_pending_messages(new_conversation_id)
        assert len(messages) == 3
        for msg in messages:
            assert msg.conversation_id == new_conversation_id
        # Verify unrelated conversation is unchanged
        assert await service.count_pending_messages(unrelated_conversation_id) == 1

    @pytest.mark.asyncio
    async def test_update_conversation_id_returns_zero_when_no_match(
        self,
        service: SQLPendingMessageService,
    ):
        """Test that updating nonexistent conversation_id returns 0."""
        # Arrange
        old_conversation_id = f'task-{uuid4().hex}'
        new_conversation_id = str(uuid4())

        # Act
        updated_count = await service.update_conversation_id(
            old_conversation_id, new_conversation_id
        )

        # Assert
        assert updated_count == 0

    @pytest.mark.asyncio
    async def test_messages_are_isolated_between_conversations(
        self,
        service: SQLPendingMessageService,
    ):
        """Test that operations on one conversation don't affect others."""
        # Arrange
        conv1 = f'task-{uuid4().hex}'
        conv2 = f'task-{uuid4().hex}'
        await service.add_message(conv1, [TextContent(text='Conv1 msg')])
        await service.add_message(conv2, [TextContent(text='Conv2 msg')])

        # Act
        messages1 = await service.get_pending_messages(conv1)
        messages2 = await service.get_pending_messages(conv2)

        # Assert
        assert len(messages1) == 1
        assert len(messages2) == 1
        assert messages1[0].content[0].text == 'Conv1 msg'
        assert messages2[0].content[0].text == 'Conv2 msg'

View File

@@ -2187,6 +2187,7 @@ async def test_delete_v1_conversation_with_sub_conversations():
sandbox_service=mock_sandbox_service,
sandbox_spec_service=MagicMock(),
jwt_service=MagicMock(),
pending_message_service=MagicMock(),
sandbox_startup_timeout=120,
sandbox_startup_poll_frequency=2,
max_num_conversations_per_sandbox=20,
@@ -2311,6 +2312,7 @@ async def test_delete_v1_conversation_with_no_sub_conversations():
sandbox_service=mock_sandbox_service,
sandbox_spec_service=MagicMock(),
jwt_service=MagicMock(),
pending_message_service=MagicMock(),
sandbox_startup_timeout=120,
sandbox_startup_poll_frequency=2,
max_num_conversations_per_sandbox=20,
@@ -2465,6 +2467,7 @@ async def test_delete_v1_conversation_sub_conversation_deletion_error():
sandbox_service=mock_sandbox_service,
sandbox_spec_service=MagicMock(),
jwt_service=MagicMock(),
pending_message_service=MagicMock(),
sandbox_startup_timeout=120,
sandbox_startup_poll_frequency=2,
max_num_conversations_per_sandbox=20,