mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380)
This commit is contained in:
@@ -0,0 +1,39 @@
|
|||||||
|
"""Add pending_messages table for server-side message queuing
|
||||||
|
|
||||||
|
Revision ID: 101
|
||||||
|
Revises: 100
|
||||||
|
Create Date: 2025-03-15 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '101'
|
||||||
|
down_revision: Union[str, None] = '100'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create pending_messages table for storing messages before conversation is ready.
|
||||||
|
|
||||||
|
Messages are stored temporarily until the conversation becomes ready, then
|
||||||
|
delivered and deleted regardless of success or failure.
|
||||||
|
"""
|
||||||
|
op.create_table(
|
||||||
|
'pending_messages',
|
||||||
|
sa.Column('id', sa.String(), primary_key=True),
|
||||||
|
sa.Column('conversation_id', sa.String(), nullable=False, index=True),
|
||||||
|
sa.Column('role', sa.String(20), nullable=False, server_default='user'),
|
||||||
|
sa.Column('content', sa.JSON, nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove pending_messages table."""
|
||||||
|
op.drop_table('pending_messages')
|
||||||
172
enterprise/server/utils/saas_pending_message_injector.py
Normal file
172
enterprise/server/utils/saas_pending_message_injector.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Enterprise injector for PendingMessageService with SAAS filtering."""
|
||||||
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlalchemy import select
|
||||||
|
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||||
|
from storage.user import User
|
||||||
|
|
||||||
|
from openhands.agent_server.models import ImageContent, TextContent
|
||||||
|
from openhands.app_server.errors import AuthError
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
PendingMessageService,
|
||||||
|
PendingMessageServiceInjector,
|
||||||
|
SQLPendingMessageService,
|
||||||
|
)
|
||||||
|
from openhands.app_server.services.injector import InjectorState
|
||||||
|
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||||
|
from openhands.app_server.user.user_context import UserContext
|
||||||
|
|
||||||
|
|
||||||
|
class SaasSQLPendingMessageService(SQLPendingMessageService):
|
||||||
|
"""Extended SQLPendingMessageService with user and organization-based filtering.
|
||||||
|
|
||||||
|
This enterprise version ensures that:
|
||||||
|
- Users can only queue messages for conversations they own
|
||||||
|
- Organization isolation is enforced for multi-tenant deployments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_session, user_context: UserContext):
|
||||||
|
super().__init__(db_session=db_session)
|
||||||
|
self.user_context = user_context
|
||||||
|
|
||||||
|
async def _get_current_user(self) -> User | None:
|
||||||
|
"""Get the current user using the existing db_session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object or None if no user_id is available
|
||||||
|
"""
|
||||||
|
user_id_str = await self.user_context.get_user_id()
|
||||||
|
if not user_id_str:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id_uuid = UUID(user_id_str)
|
||||||
|
result = await self.db_session.execute(
|
||||||
|
select(User).where(User.id == user_id_uuid)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def _validate_conversation_ownership(self, conversation_id: str) -> None:
|
||||||
|
"""Validate that the current user owns the conversation.
|
||||||
|
|
||||||
|
This ensures multi-tenant isolation by checking:
|
||||||
|
- The conversation belongs to the current user
|
||||||
|
- The conversation belongs to the user's current organization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation ID to validate (can be task-id or UUID)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If user doesn't own the conversation or authentication fails
|
||||||
|
"""
|
||||||
|
# For internal operations (e.g., processing pending messages during startup)
|
||||||
|
# we need a mode that bypasses filtering. The ADMIN context enables this.
|
||||||
|
if self.user_context == ADMIN:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id_str = await self.user_context.get_user_id()
|
||||||
|
if not user_id_str:
|
||||||
|
raise AuthError('User authentication required')
|
||||||
|
|
||||||
|
user_id_uuid = UUID(user_id_str)
|
||||||
|
|
||||||
|
# Check conversation ownership via SAAS metadata
|
||||||
|
query = select(StoredConversationMetadataSaas).where(
|
||||||
|
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(query)
|
||||||
|
saas_metadata = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# If no SAAS metadata exists, the conversation might be a new task-id
|
||||||
|
# that hasn't been linked to a conversation yet. Allow access in this case
|
||||||
|
# as the message will be validated when the conversation is created.
|
||||||
|
if saas_metadata is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Verify user ownership
|
||||||
|
if saas_metadata.user_id != user_id_uuid:
|
||||||
|
raise AuthError('You do not have access to this conversation')
|
||||||
|
|
||||||
|
# Verify organization ownership if applicable
|
||||||
|
user = await self._get_current_user()
|
||||||
|
if user and user.current_org_id is not None:
|
||||||
|
if saas_metadata.org_id != user.current_org_id:
|
||||||
|
raise AuthError('Conversation belongs to a different organization')
|
||||||
|
|
||||||
|
async def add_message(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
content: list[TextContent | ImageContent],
|
||||||
|
role: str = 'user',
|
||||||
|
) -> PendingMessageResponse:
|
||||||
|
"""Queue a message with ownership validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation ID to queue the message for
|
||||||
|
content: Message content
|
||||||
|
role: Message role (default: 'user')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PendingMessageResponse with the queued message info
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If user doesn't own the conversation
|
||||||
|
"""
|
||||||
|
await self._validate_conversation_ownership(conversation_id)
|
||||||
|
return await super().add_message(conversation_id, content, role)
|
||||||
|
|
||||||
|
async def get_pending_messages(self, conversation_id: str):
|
||||||
|
"""Get pending messages with ownership validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation ID to get messages for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of pending messages
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If user doesn't own the conversation
|
||||||
|
"""
|
||||||
|
await self._validate_conversation_ownership(conversation_id)
|
||||||
|
return await super().get_pending_messages(conversation_id)
|
||||||
|
|
||||||
|
async def count_pending_messages(self, conversation_id: str) -> int:
|
||||||
|
"""Count pending messages with ownership validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation ID to count messages for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of pending messages
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If user doesn't own the conversation
|
||||||
|
"""
|
||||||
|
await self._validate_conversation_ownership(conversation_id)
|
||||||
|
return await super().count_pending_messages(conversation_id)
|
||||||
|
|
||||||
|
|
||||||
|
class SaasPendingMessageServiceInjector(PendingMessageServiceInjector):
|
||||||
|
"""Enterprise injector for PendingMessageService with SAAS filtering."""
|
||||||
|
|
||||||
|
async def inject(
|
||||||
|
self, state: InjectorState, request: Request | None = None
|
||||||
|
) -> AsyncGenerator[PendingMessageService, None]:
|
||||||
|
from openhands.app_server.config import (
|
||||||
|
get_db_session,
|
||||||
|
get_user_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
get_user_context(state, request) as user_context,
|
||||||
|
get_db_session(state, request) as db_session,
|
||||||
|
):
|
||||||
|
service = SaasSQLPendingMessageService(
|
||||||
|
db_session=db_session, user_context=user_context
|
||||||
|
)
|
||||||
|
yield service
|
||||||
@@ -198,9 +198,9 @@ describe("InteractiveChatBox", () => {
|
|||||||
expect(onSubmitMock).toHaveBeenCalledWith("Hello, world!", [], []);
|
expect(onSubmitMock).toHaveBeenCalledWith("Hello, world!", [], []);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should disable the submit button when agent is loading", async () => {
|
it("should disable the submit button when awaiting user confirmation", async () => {
|
||||||
const user = userEvent.setup();
|
const user = userEvent.setup();
|
||||||
mockStores(AgentState.LOADING);
|
mockStores(AgentState.AWAITING_USER_CONFIRMATION);
|
||||||
|
|
||||||
renderInteractiveChatBox({
|
renderInteractiveChatBox({
|
||||||
onSubmit: onSubmitMock,
|
onSubmit: onSubmitMock,
|
||||||
|
|||||||
@@ -229,4 +229,231 @@ describe("conversation localStorage utilities", () => {
|
|||||||
expect(parsed.subConversationTaskId).toBeNull();
|
expect(parsed.subConversationTaskId).toBeNull();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("draftMessage persistence", () => {
|
||||||
|
describe("getConversationState", () => {
|
||||||
|
it("returns default draftMessage as null when no state exists", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-1";
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const state = getConversationState(conversationId);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(state.draftMessage).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("retrieves draftMessage from localStorage when it exists", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-2";
|
||||||
|
const draftText = "This is my saved draft message";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
|
||||||
|
|
||||||
|
localStorage.setItem(
|
||||||
|
consolidatedKey,
|
||||||
|
JSON.stringify({
|
||||||
|
draftMessage: draftText,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const state = getConversationState(conversationId);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(state.draftMessage).toBe(draftText);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns null draftMessage for task conversation IDs (not persisted)", () => {
|
||||||
|
// Arrange
|
||||||
|
const taskId = "task-uuid-123";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`;
|
||||||
|
|
||||||
|
// Even if somehow there's data in localStorage for a task ID
|
||||||
|
localStorage.setItem(
|
||||||
|
consolidatedKey,
|
||||||
|
JSON.stringify({
|
||||||
|
draftMessage: "Should not be returned",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const state = getConversationState(taskId);
|
||||||
|
|
||||||
|
// Assert - should return default state, not the stored value
|
||||||
|
expect(state.draftMessage).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("setConversationState", () => {
|
||||||
|
it("persists draftMessage to localStorage", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-3";
|
||||||
|
const draftText = "New draft message to save";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(conversationId, {
|
||||||
|
draftMessage: draftText,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stored = localStorage.getItem(consolidatedKey);
|
||||||
|
expect(stored).not.toBeNull();
|
||||||
|
const parsed = JSON.parse(stored!);
|
||||||
|
expect(parsed.draftMessage).toBe(draftText);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not persist draftMessage for task conversation IDs", () => {
|
||||||
|
// Arrange
|
||||||
|
const taskId = "task-draft-xyz";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(taskId, {
|
||||||
|
draftMessage: "Draft for task ID",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - nothing should be stored
|
||||||
|
expect(localStorage.getItem(consolidatedKey)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("merges draftMessage with existing state without overwriting other fields", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-4";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
|
||||||
|
|
||||||
|
localStorage.setItem(
|
||||||
|
consolidatedKey,
|
||||||
|
JSON.stringify({
|
||||||
|
selectedTab: "terminal",
|
||||||
|
rightPanelShown: false,
|
||||||
|
unpinnedTabs: ["tab-1", "tab-2"],
|
||||||
|
conversationMode: "plan",
|
||||||
|
subConversationTaskId: "task-123",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(conversationId, {
|
||||||
|
draftMessage: "Updated draft",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stored = localStorage.getItem(consolidatedKey);
|
||||||
|
const parsed = JSON.parse(stored!);
|
||||||
|
|
||||||
|
expect(parsed.draftMessage).toBe("Updated draft");
|
||||||
|
expect(parsed.selectedTab).toBe("terminal");
|
||||||
|
expect(parsed.rightPanelShown).toBe(false);
|
||||||
|
expect(parsed.unpinnedTabs).toEqual(["tab-1", "tab-2"]);
|
||||||
|
expect(parsed.conversationMode).toBe("plan");
|
||||||
|
expect(parsed.subConversationTaskId).toBe("task-123");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clears draftMessage when set to null", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-5";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
|
||||||
|
|
||||||
|
localStorage.setItem(
|
||||||
|
consolidatedKey,
|
||||||
|
JSON.stringify({
|
||||||
|
draftMessage: "Existing draft",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(conversationId, {
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stored = localStorage.getItem(consolidatedKey);
|
||||||
|
const parsed = JSON.parse(stored!);
|
||||||
|
expect(parsed.draftMessage).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clears draftMessage when set to empty string (stored as empty string)", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-draft-6";
|
||||||
|
const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`;
|
||||||
|
|
||||||
|
localStorage.setItem(
|
||||||
|
consolidatedKey,
|
||||||
|
JSON.stringify({
|
||||||
|
draftMessage: "Existing draft",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(conversationId, {
|
||||||
|
draftMessage: "",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stored = localStorage.getItem(consolidatedKey);
|
||||||
|
const parsed = JSON.parse(stored!);
|
||||||
|
expect(parsed.draftMessage).toBe("");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("conversation-specific draft isolation", () => {
|
||||||
|
it("stores drafts separately for different conversations", () => {
|
||||||
|
// Arrange
|
||||||
|
const convA = "conv-A";
|
||||||
|
const convB = "conv-B";
|
||||||
|
const draftA = "Draft for conversation A";
|
||||||
|
const draftB = "Draft for conversation B";
|
||||||
|
|
||||||
|
// Act
|
||||||
|
setConversationState(convA, { draftMessage: draftA });
|
||||||
|
setConversationState(convB, { draftMessage: draftB });
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stateA = getConversationState(convA);
|
||||||
|
const stateB = getConversationState(convB);
|
||||||
|
|
||||||
|
expect(stateA.draftMessage).toBe(draftA);
|
||||||
|
expect(stateB.draftMessage).toBe(draftB);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("updating one conversation draft does not affect another", () => {
|
||||||
|
// Arrange
|
||||||
|
const convA = "conv-isolated-A";
|
||||||
|
const convB = "conv-isolated-B";
|
||||||
|
|
||||||
|
setConversationState(convA, { draftMessage: "Original draft A" });
|
||||||
|
setConversationState(convB, { draftMessage: "Original draft B" });
|
||||||
|
|
||||||
|
// Act - update only conversation A
|
||||||
|
setConversationState(convA, { draftMessage: "Updated draft A" });
|
||||||
|
|
||||||
|
// Assert - conversation B should be unchanged
|
||||||
|
const stateA = getConversationState(convA);
|
||||||
|
const stateB = getConversationState(convB);
|
||||||
|
|
||||||
|
expect(stateA.draftMessage).toBe("Updated draft A");
|
||||||
|
expect(stateB.draftMessage).toBe("Original draft B");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clearing one conversation draft does not affect another", () => {
|
||||||
|
// Arrange
|
||||||
|
const convA = "conv-clear-A";
|
||||||
|
const convB = "conv-clear-B";
|
||||||
|
|
||||||
|
setConversationState(convA, { draftMessage: "Draft A" });
|
||||||
|
setConversationState(convB, { draftMessage: "Draft B" });
|
||||||
|
|
||||||
|
// Act - clear draft for conversation A
|
||||||
|
setConversationState(convA, { draftMessage: null });
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
const stateA = getConversationState(convA);
|
||||||
|
const stateB = getConversationState(convB);
|
||||||
|
|
||||||
|
expect(stateA.draftMessage).toBeNull();
|
||||||
|
expect(stateB.draftMessage).toBe("Draft B");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import React from "react";
|
||||||
import {
|
import {
|
||||||
describe,
|
describe,
|
||||||
it,
|
it,
|
||||||
@@ -8,7 +9,7 @@ import {
|
|||||||
afterEach,
|
afterEach,
|
||||||
vi,
|
vi,
|
||||||
} from "vitest";
|
} from "vitest";
|
||||||
import { screen, waitFor, render, cleanup } from "@testing-library/react";
|
import { screen, waitFor, render, cleanup, act } from "@testing-library/react";
|
||||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||||
import { http, HttpResponse } from "msw";
|
import { http, HttpResponse } from "msw";
|
||||||
import { MemoryRouter, Route, Routes } from "react-router";
|
import { MemoryRouter, Route, Routes } from "react-router";
|
||||||
@@ -682,8 +683,242 @@ describe("Conversation WebSocket Handler", () => {
|
|||||||
|
|
||||||
// 7. Message Sending Tests
|
// 7. Message Sending Tests
|
||||||
describe("Message Sending", () => {
|
describe("Message Sending", () => {
|
||||||
it.todo("should send user actions through WebSocket when connected");
|
it("should send user actions through WebSocket when connected", async () => {
|
||||||
it.todo("should handle send attempts when disconnected");
|
// Arrange
|
||||||
|
const conversationId = "test-conversation-send";
|
||||||
|
let receivedMessage: unknown = null;
|
||||||
|
|
||||||
|
// Set up MSW to capture sent messages
|
||||||
|
mswServer.use(
|
||||||
|
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||||
|
server.connect();
|
||||||
|
|
||||||
|
// Capture messages sent from client
|
||||||
|
client.addEventListener("message", (event) => {
|
||||||
|
receivedMessage = JSON.parse(event.data as string);
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create ref to store sendMessage function
|
||||||
|
let sendMessageFn: typeof useConversationWebSocket extends () => infer R
|
||||||
|
? R extends { sendMessage: infer S }
|
||||||
|
? S
|
||||||
|
: null
|
||||||
|
: null = null;
|
||||||
|
|
||||||
|
function TestComponent() {
|
||||||
|
const context = useConversationWebSocket();
|
||||||
|
|
||||||
|
React.useEffect(() => {
|
||||||
|
if (context?.sendMessage) {
|
||||||
|
sendMessageFn = context.sendMessage;
|
||||||
|
}
|
||||||
|
}, [context?.sendMessage]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div data-testid="connection-state">
|
||||||
|
{context?.connectionState || "NOT_AVAILABLE"}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderWithWebSocketContext(
|
||||||
|
<TestComponent />,
|
||||||
|
conversationId,
|
||||||
|
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for connection
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||||
|
"OPEN",
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send a message
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(sendMessageFn).not.toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await sendMessageFn!({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Hello from test" }],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - message should have been received by mock server
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(receivedMessage).toEqual({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Hello from test" }],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should not throw error when sendMessage is called with WebSocket connected", async () => {
|
||||||
|
// This test verifies that sendMessage doesn't throw an error
|
||||||
|
// when the WebSocket is connected.
|
||||||
|
const conversationId = "test-conversation-no-throw";
|
||||||
|
let sendError: Error | null = null;
|
||||||
|
|
||||||
|
// Set up MSW to connect and receive messages
|
||||||
|
mswServer.use(
|
||||||
|
wsLink.addEventListener("connection", ({ server }) => {
|
||||||
|
server.connect();
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create ref to store sendMessage function
|
||||||
|
let sendMessageFn: typeof useConversationWebSocket extends () => infer R
|
||||||
|
? R extends { sendMessage: infer S }
|
||||||
|
? S
|
||||||
|
: null
|
||||||
|
: null = null;
|
||||||
|
|
||||||
|
function TestComponent() {
|
||||||
|
const context = useConversationWebSocket();
|
||||||
|
|
||||||
|
React.useEffect(() => {
|
||||||
|
if (context?.sendMessage) {
|
||||||
|
sendMessageFn = context.sendMessage;
|
||||||
|
}
|
||||||
|
}, [context?.sendMessage]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div data-testid="connection-state">
|
||||||
|
{context?.connectionState || "NOT_AVAILABLE"}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderWithWebSocketContext(
|
||||||
|
<TestComponent />,
|
||||||
|
conversationId,
|
||||||
|
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for connection
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||||
|
"OPEN",
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for the context to be available
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(sendMessageFn).not.toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Try to send a message
|
||||||
|
await act(async () => {
|
||||||
|
try {
|
||||||
|
await sendMessageFn!({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Test message" }],
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
sendError = error as Error;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - should NOT throw an error
|
||||||
|
expect(sendError).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should send multiple messages through WebSocket in order", async () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "test-conversation-multi";
|
||||||
|
const receivedMessages: unknown[] = [];
|
||||||
|
|
||||||
|
// Set up MSW to capture sent messages
|
||||||
|
mswServer.use(
|
||||||
|
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||||
|
server.connect();
|
||||||
|
|
||||||
|
// Capture messages sent from client
|
||||||
|
client.addEventListener("message", (event) => {
|
||||||
|
receivedMessages.push(JSON.parse(event.data as string));
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create ref to store sendMessage function
|
||||||
|
let sendMessageFn: typeof useConversationWebSocket extends () => infer R
|
||||||
|
? R extends { sendMessage: infer S }
|
||||||
|
? S
|
||||||
|
: null
|
||||||
|
: null = null;
|
||||||
|
|
||||||
|
function TestComponent() {
|
||||||
|
const context = useConversationWebSocket();
|
||||||
|
|
||||||
|
React.useEffect(() => {
|
||||||
|
if (context?.sendMessage) {
|
||||||
|
sendMessageFn = context.sendMessage;
|
||||||
|
}
|
||||||
|
}, [context?.sendMessage]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div data-testid="connection-state">
|
||||||
|
{context?.connectionState || "NOT_AVAILABLE"}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderWithWebSocketContext(
|
||||||
|
<TestComponent />,
|
||||||
|
conversationId,
|
||||||
|
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for connection
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||||
|
"OPEN",
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(sendMessageFn).not.toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send multiple messages
|
||||||
|
await act(async () => {
|
||||||
|
await sendMessageFn!({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Message 1" }],
|
||||||
|
});
|
||||||
|
await sendMessageFn!({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Message 2" }],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - both messages should have been received in order
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(receivedMessages.length).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(receivedMessages[0]).toEqual({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Message 1" }],
|
||||||
|
});
|
||||||
|
expect(receivedMessages[1]).toEqual({
|
||||||
|
role: "user",
|
||||||
|
content: [{ type: "text", text: "Message 2" }],
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// 8. History Loading State Tests
|
// 8. History Loading State Tests
|
||||||
|
|||||||
594
frontend/__tests__/hooks/use-draft-persistence.test.tsx
Normal file
594
frontend/__tests__/hooks/use-draft-persistence.test.tsx
Normal file
@@ -0,0 +1,594 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||||
|
import { renderHook, act } from "@testing-library/react";
|
||||||
|
import { useDraftPersistence } from "#/hooks/chat/use-draft-persistence";
|
||||||
|
import * as conversationLocalStorage from "#/utils/conversation-local-storage";
|
||||||
|
|
||||||
|
// Mock the entire module
|
||||||
|
vi.mock("#/utils/conversation-local-storage", () => ({
|
||||||
|
useConversationLocalStorageState: vi.fn(),
|
||||||
|
getConversationState: vi.fn(),
|
||||||
|
setConversationState: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock the getTextContent utility
|
||||||
|
vi.mock("#/components/features/chat/utils/chat-input.utils", () => ({
|
||||||
|
getTextContent: vi.fn((el: HTMLDivElement | null) => el?.textContent || ""),
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe("useDraftPersistence", () => {
|
||||||
|
let mockSetDraftMessage: (message: string | null) => void;
|
||||||
|
|
||||||
|
// Create a mock ref to contentEditable div
|
||||||
|
const createMockChatInputRef = (initialContent = "") => {
|
||||||
|
const div = document.createElement("div");
|
||||||
|
div.setAttribute("contenteditable", "true");
|
||||||
|
div.textContent = initialContent;
|
||||||
|
return { current: div };
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
vi.useFakeTimers();
|
||||||
|
localStorage.clear();
|
||||||
|
|
||||||
|
mockSetDraftMessage = vi.fn<(message: string | null) => void>();
|
||||||
|
|
||||||
|
// Default mock for useConversationLocalStorageState
|
||||||
|
vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
|
||||||
|
state: {
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
},
|
||||||
|
setSelectedTab: vi.fn(),
|
||||||
|
setRightPanelShown: vi.fn(),
|
||||||
|
setUnpinnedTabs: vi.fn(),
|
||||||
|
setConversationMode: vi.fn(),
|
||||||
|
setDraftMessage: mockSetDraftMessage,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Default mock for getConversationState
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers();
|
||||||
|
vi.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("draft restoration on mount", () => {
|
||||||
|
it("restores draft from localStorage when mounting with existing draft", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-restore-1";
|
||||||
|
const savedDraft = "Previously saved draft message";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: savedDraft,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderHook(() => useDraftPersistence(conversationId, chatInputRef));
|
||||||
|
|
||||||
|
// Assert - draft should be restored to the DOM element
|
||||||
|
expect(chatInputRef.current?.textContent).toBe(savedDraft);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clears input on mount then restores draft if exists", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-restore-2";
|
||||||
|
const existingContent = "Stale content from previous conversation";
|
||||||
|
const savedDraft = "Saved draft";
|
||||||
|
const chatInputRef = createMockChatInputRef(existingContent);
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: savedDraft,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderHook(() => useDraftPersistence(conversationId, chatInputRef));
|
||||||
|
|
||||||
|
// Assert - input cleared then draft restored
|
||||||
|
expect(chatInputRef.current?.textContent).toBe(savedDraft);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clears input when no draft exists for conversation", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-no-draft";
|
||||||
|
const chatInputRef = createMockChatInputRef("Some stale content");
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Act
|
||||||
|
renderHook(() => useDraftPersistence(conversationId, chatInputRef));
|
||||||
|
|
||||||
|
// Assert - content should be cleared since there's no draft
|
||||||
|
expect(chatInputRef.current?.textContent).toBe("");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("debounced saving", () => {
|
||||||
|
it("saves draft after debounce period", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-debounce-1";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - simulate user typing
|
||||||
|
chatInputRef.current!.textContent = "New draft content";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - should not save immediately
|
||||||
|
expect(mockSetDraftMessage).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Fast forward past debounce period (500ms)
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - should save after debounce
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledWith("New draft content");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("cancels pending save when new input arrives before debounce", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-debounce-2";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - first input
|
||||||
|
chatInputRef.current!.textContent = "First";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait 200ms (less than debounce)
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Second input before debounce completes
|
||||||
|
chatInputRef.current!.textContent = "First Second";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Complete the second debounce
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - should only save the final value once
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledWith("First Second");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not save if content matches existing draft", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-no-change";
|
||||||
|
const existingDraft = "Existing draft";
|
||||||
|
const chatInputRef = createMockChatInputRef(existingDraft);
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
|
||||||
|
state: {
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: existingDraft,
|
||||||
|
},
|
||||||
|
setSelectedTab: vi.fn(),
|
||||||
|
setRightPanelShown: vi.fn(),
|
||||||
|
setUnpinnedTabs: vi.fn(),
|
||||||
|
setConversationMode: vi.fn(),
|
||||||
|
setDraftMessage: mockSetDraftMessage,
|
||||||
|
});
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: existingDraft,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - try to save same content
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - should not save since content is the same
|
||||||
|
expect(mockSetDraftMessage).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("clearDraft", () => {
|
||||||
|
it("clears the draft from localStorage", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-clear-1";
|
||||||
|
const chatInputRef = createMockChatInputRef("Some content");
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
act(() => {
|
||||||
|
result.current.clearDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledWith(null);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("cancels any pending debounced save when clearing", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-clear-2";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Start a save
|
||||||
|
chatInputRef.current!.textContent = "Pending draft";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Clear before debounce completes
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(200);
|
||||||
|
result.current.clearDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Complete the original debounce period
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - only the clear should have been called (the pending save should be cancelled)
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockSetDraftMessage).toHaveBeenCalledWith(null);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("conversation switching", () => {
|
||||||
|
it("clears input when switching to a new conversation without a draft", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef("Draft from conv A");
|
||||||
|
|
||||||
|
// First conversation has a draft
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState)
|
||||||
|
.mockReturnValueOnce({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: "Draft from conv A",
|
||||||
|
})
|
||||||
|
.mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "conv-A" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - switch to conversation B
|
||||||
|
rerender({ conversationId: "conv-B" });
|
||||||
|
|
||||||
|
// Assert - input should be cleared (no draft for conv-B)
|
||||||
|
expect(chatInputRef.current?.textContent).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("restores draft when switching to a conversation with an existing draft", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
const draftForConvB = "Saved draft for conversation B";
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState)
|
||||||
|
.mockReturnValueOnce({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
})
|
||||||
|
.mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: draftForConvB,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "conv-A" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - switch to conversation B
|
||||||
|
rerender({ conversationId: "conv-B" });
|
||||||
|
|
||||||
|
// Assert - draft for conv-B should be restored
|
||||||
|
expect(chatInputRef.current?.textContent).toBe(draftForConvB);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("cancels pending save when switching conversations", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
const { result, rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "conv-A" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Start typing in conv-A
|
||||||
|
chatInputRef.current!.textContent = "Draft for conv-A";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Switch conversation before debounce completes
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(200);
|
||||||
|
});
|
||||||
|
rerender({ conversationId: "conv-B" });
|
||||||
|
|
||||||
|
// Complete the debounce period
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - the save should NOT have happened because conversation changed
|
||||||
|
expect(mockSetDraftMessage).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("task ID to real conversation ID transition", () => {
|
||||||
|
it("transfers draft from task ID to real conversation ID during transition", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef("Draft typed during init");
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "task-abc-123" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Simulate user typing during task initialization
|
||||||
|
chatInputRef.current!.textContent = "Draft typed during init";
|
||||||
|
|
||||||
|
// Act - transition to real conversation ID
|
||||||
|
rerender({ conversationId: "conv-real-123" });
|
||||||
|
|
||||||
|
// Assert - draft should be saved to the new real conversation ID
|
||||||
|
expect(conversationLocalStorage.setConversationState).toHaveBeenCalledWith(
|
||||||
|
"conv-real-123",
|
||||||
|
{ draftMessage: "Draft typed during init" },
|
||||||
|
);
|
||||||
|
|
||||||
|
// And the draft should remain visible in the input
|
||||||
|
expect(chatInputRef.current?.textContent).toBe("Draft typed during init");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not transfer empty draft during task-to-real transition", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef("");
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "task-abc-123" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - transition to real conversation ID with empty input
|
||||||
|
rerender({ conversationId: "conv-real-123" });
|
||||||
|
|
||||||
|
// Assert - no draft should be saved (input is cleared, checked by hook)
|
||||||
|
// The setConversationState should not be called with draftMessage
|
||||||
|
expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not transfer draft for non-task ID transitions", () => {
|
||||||
|
// Arrange
|
||||||
|
const chatInputRef = createMockChatInputRef("Some draft");
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { rerender } = renderHook(
|
||||||
|
({ conversationId }) =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
{ initialProps: { conversationId: "conv-A" } },
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act - normal conversation switch (not task-to-real)
|
||||||
|
rerender({ conversationId: "conv-B" });
|
||||||
|
|
||||||
|
// Assert - should not use setConversationState directly
|
||||||
|
// (the normal path uses setDraftMessage from the hook)
|
||||||
|
expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("hasDraft and isRestored state", () => {
|
||||||
|
it("returns hasDraft true when draft exists in hook state", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-has-draft";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({
|
||||||
|
state: {
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: "Existing draft",
|
||||||
|
},
|
||||||
|
setSelectedTab: vi.fn(),
|
||||||
|
setRightPanelShown: vi.fn(),
|
||||||
|
setUnpinnedTabs: vi.fn(),
|
||||||
|
setConversationMode: vi.fn(),
|
||||||
|
setDraftMessage: mockSetDraftMessage,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(result.current.hasDraft).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns hasDraft false when no draft exists", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-no-draft";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(result.current.hasDraft).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sets isRestored to true after restoration completes", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-restored";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({
|
||||||
|
selectedTab: "editor",
|
||||||
|
rightPanelShown: true,
|
||||||
|
unpinnedTabs: [],
|
||||||
|
conversationMode: "code",
|
||||||
|
subConversationTaskId: null,
|
||||||
|
draftMessage: "Draft to restore",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(result.current.isRestored).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("cleanup on unmount", () => {
|
||||||
|
it("clears pending timeout on unmount", () => {
|
||||||
|
// Arrange
|
||||||
|
const conversationId = "conv-unmount";
|
||||||
|
const chatInputRef = createMockChatInputRef();
|
||||||
|
|
||||||
|
const { result, unmount } = renderHook(() =>
|
||||||
|
useDraftPersistence(conversationId, chatInputRef),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Start a save
|
||||||
|
chatInputRef.current!.textContent = "Draft";
|
||||||
|
act(() => {
|
||||||
|
result.current.saveDraft();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Unmount before debounce completes
|
||||||
|
unmount();
|
||||||
|
|
||||||
|
// Complete the debounce period
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(500);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert - save should not have been called after unmount
|
||||||
|
expect(mockSetDraftMessage).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -88,6 +88,7 @@ describe("useHandlePlanClick", () => {
|
|||||||
unpinnedTabs: [],
|
unpinnedTabs: [],
|
||||||
subConversationTaskId: null,
|
subConversationTaskId: null,
|
||||||
conversationMode: "code",
|
conversationMode: "code",
|
||||||
|
draftMessage: null,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -117,6 +118,7 @@ describe("useHandlePlanClick", () => {
|
|||||||
unpinnedTabs: [],
|
unpinnedTabs: [],
|
||||||
subConversationTaskId: storedTaskId,
|
subConversationTaskId: storedTaskId,
|
||||||
conversationMode: "code",
|
conversationMode: "code",
|
||||||
|
draftMessage: null,
|
||||||
});
|
});
|
||||||
|
|
||||||
renderHook(() => useHandlePlanClick());
|
renderHook(() => useHandlePlanClick());
|
||||||
@@ -155,6 +157,7 @@ describe("useHandlePlanClick", () => {
|
|||||||
unpinnedTabs: [],
|
unpinnedTabs: [],
|
||||||
subConversationTaskId: storedTaskId,
|
subConversationTaskId: storedTaskId,
|
||||||
conversationMode: "code",
|
conversationMode: "code",
|
||||||
|
draftMessage: null,
|
||||||
});
|
});
|
||||||
|
|
||||||
renderHook(() => useHandlePlanClick());
|
renderHook(() => useHandlePlanClick());
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Pending Message Service
|
||||||
|
*
|
||||||
|
* This service handles server-side message queuing for V1 conversations.
|
||||||
|
* Messages can be queued when the WebSocket is not connected and will be
|
||||||
|
* delivered automatically when the conversation becomes ready.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { openHands } from "../open-hands-axios";
|
||||||
|
import type {
|
||||||
|
PendingMessageResponse,
|
||||||
|
QueuePendingMessageRequest,
|
||||||
|
} from "./pending-message-service.types";
|
||||||
|
|
||||||
|
class PendingMessageService {
|
||||||
|
/**
|
||||||
|
* Queue a message for delivery when conversation becomes ready.
|
||||||
|
*
|
||||||
|
* This endpoint allows users to submit messages even when the conversation's
|
||||||
|
* WebSocket connection is not yet established. Messages are stored server-side
|
||||||
|
* and delivered automatically when the conversation transitions to READY status.
|
||||||
|
*
|
||||||
|
* @param conversationId The conversation ID (can be task ID before conversation is ready)
|
||||||
|
* @param message The message to queue
|
||||||
|
* @returns PendingMessageResponse with the message ID and queue position
|
||||||
|
* @throws Error if too many pending messages (limit: 10 per conversation)
|
||||||
|
*/
|
||||||
|
static async queueMessage(
|
||||||
|
conversationId: string,
|
||||||
|
message: QueuePendingMessageRequest,
|
||||||
|
): Promise<PendingMessageResponse> {
|
||||||
|
const { data } = await openHands.post<PendingMessageResponse>(
|
||||||
|
`/api/v1/conversations/${conversationId}/pending-messages`,
|
||||||
|
message,
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default PendingMessageService;
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
/**
|
||||||
|
* Types for the pending message service
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { V1MessageContent } from "../conversation-service/v1-conversation-service.types";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Response when queueing a pending message
|
||||||
|
*/
|
||||||
|
export interface PendingMessageResponse {
|
||||||
|
id: string;
|
||||||
|
queued: boolean;
|
||||||
|
position: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Request to queue a pending message
|
||||||
|
*/
|
||||||
|
export interface QueuePendingMessageRequest {
|
||||||
|
role?: "user";
|
||||||
|
content: V1MessageContent[];
|
||||||
|
}
|
||||||
@@ -190,8 +190,14 @@ export function ChatInterface() {
|
|||||||
const prompt =
|
const prompt =
|
||||||
uploadedFiles.length > 0 ? `${content}\n\n${filePrompt}` : content;
|
uploadedFiles.length > 0 ? `${content}\n\n${filePrompt}` : content;
|
||||||
|
|
||||||
send(createChatMessage(prompt, imageUrls, uploadedFiles, timestamp));
|
const result = await send(
|
||||||
setOptimisticUserMessage(content);
|
createChatMessage(prompt, imageUrls, uploadedFiles, timestamp),
|
||||||
|
);
|
||||||
|
// Only show optimistic UI if message was sent immediately via WebSocket
|
||||||
|
// If queued for later delivery, the message will appear when actually delivered
|
||||||
|
if (!result.queued) {
|
||||||
|
setOptimisticUserMessage(content);
|
||||||
|
}
|
||||||
setMessageToSend("");
|
setMessageToSend("");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ export function CustomChatInput({
|
|||||||
messageToSend,
|
messageToSend,
|
||||||
checkIsContentEmpty,
|
checkIsContentEmpty,
|
||||||
clearEmptyContentHandler,
|
clearEmptyContentHandler,
|
||||||
|
saveDraft,
|
||||||
} = useChatInputLogic();
|
} = useChatInputLogic();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@@ -158,6 +159,7 @@ export function CustomChatInput({
|
|||||||
onInput={() => {
|
onInput={() => {
|
||||||
handleInput();
|
handleInput();
|
||||||
updateSlashMenu();
|
updateSlashMenu();
|
||||||
|
saveDraft();
|
||||||
}}
|
}}
|
||||||
onPaste={handlePaste}
|
onPaste={handlePaste}
|
||||||
onKeyDown={(e) => {
|
onKeyDown={(e) => {
|
||||||
|
|||||||
@@ -142,8 +142,9 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
|||||||
handleSubmit(suggestion);
|
handleSubmit(suggestion);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Allow users to submit messages during LOADING state - they will be
|
||||||
|
// queued server-side and delivered when the conversation becomes ready
|
||||||
const isDisabled =
|
const isDisabled =
|
||||||
curAgentState === AgentState.LOADING ||
|
|
||||||
curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
|
curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
|
||||||
isTaskPolling(subConversationTaskStatus);
|
isTaskPolling(subConversationTaskStatus);
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ import type {
|
|||||||
V1SendMessageRequest,
|
V1SendMessageRequest,
|
||||||
} from "#/api/conversation-service/v1-conversation-service.types";
|
} from "#/api/conversation-service/v1-conversation-service.types";
|
||||||
import EventService from "#/api/event-service/event-service.api";
|
import EventService from "#/api/event-service/event-service.api";
|
||||||
|
import PendingMessageService from "#/api/pending-message-service/pending-message-service.api";
|
||||||
import { useConversationStore } from "#/stores/conversation-store";
|
import { useConversationStore } from "#/stores/conversation-store";
|
||||||
import { isBudgetOrCreditError, trackError } from "#/utils/error-handler";
|
import { isBudgetOrCreditError, trackError } from "#/utils/error-handler";
|
||||||
import { useTracking } from "#/hooks/use-tracking";
|
import { useTracking } from "#/hooks/use-tracking";
|
||||||
@@ -47,6 +48,7 @@ import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-
|
|||||||
import useMetricsStore from "#/stores/metrics-store";
|
import useMetricsStore from "#/stores/metrics-store";
|
||||||
import { I18nKey } from "#/i18n/declaration";
|
import { I18nKey } from "#/i18n/declaration";
|
||||||
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
||||||
|
import { setConversationState } from "#/utils/conversation-local-storage";
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||||
export type V1_WebSocketConnectionState =
|
export type V1_WebSocketConnectionState =
|
||||||
@@ -55,9 +57,13 @@ export type V1_WebSocketConnectionState =
|
|||||||
| "CLOSED"
|
| "CLOSED"
|
||||||
| "CLOSING";
|
| "CLOSING";
|
||||||
|
|
||||||
|
interface SendMessageResult {
|
||||||
|
queued: boolean; // true if message was queued for later delivery, false if sent immediately
|
||||||
|
}
|
||||||
|
|
||||||
interface ConversationWebSocketContextType {
|
interface ConversationWebSocketContextType {
|
||||||
connectionState: V1_WebSocketConnectionState;
|
connectionState: V1_WebSocketConnectionState;
|
||||||
sendMessage: (message: V1SendMessageRequest) => Promise<void>;
|
sendMessage: (message: V1SendMessageRequest) => Promise<SendMessageResult>;
|
||||||
isLoadingHistory: boolean;
|
isLoadingHistory: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,6 +403,10 @@ export function ConversationWebSocketProvider({
|
|||||||
// Clear optimistic user message when a user message is confirmed
|
// Clear optimistic user message when a user message is confirmed
|
||||||
if (isUserMessageEvent(event)) {
|
if (isUserMessageEvent(event)) {
|
||||||
removeOptimisticUserMessage();
|
removeOptimisticUserMessage();
|
||||||
|
// Clear draft from localStorage - message was successfully delivered
|
||||||
|
if (conversationId) {
|
||||||
|
setConversationState(conversationId, { draftMessage: null });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle cache invalidation for ActionEvent
|
// Handle cache invalidation for ActionEvent
|
||||||
@@ -556,6 +566,11 @@ export function ConversationWebSocketProvider({
|
|||||||
// Clear optimistic user message when a user message is confirmed
|
// Clear optimistic user message when a user message is confirmed
|
||||||
if (isUserMessageEvent(event)) {
|
if (isUserMessageEvent(event)) {
|
||||||
removeOptimisticUserMessage();
|
removeOptimisticUserMessage();
|
||||||
|
// Clear draft from localStorage - message was successfully delivered
|
||||||
|
// Use main conversationId since user types in main conversation input
|
||||||
|
if (conversationId) {
|
||||||
|
setConversationState(conversationId, { draftMessage: null });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle cache invalidation for ActionEvent
|
// Handle cache invalidation for ActionEvent
|
||||||
@@ -810,21 +825,44 @@ export function ConversationWebSocketProvider({
|
|||||||
);
|
);
|
||||||
|
|
||||||
// V1 send message function via WebSocket
|
// V1 send message function via WebSocket
|
||||||
|
// Falls back to REST API queue when WebSocket is not connected
|
||||||
const sendMessage = useCallback(
|
const sendMessage = useCallback(
|
||||||
async (message: V1SendMessageRequest) => {
|
async (message: V1SendMessageRequest): Promise<SendMessageResult> => {
|
||||||
const currentMode = useConversationStore.getState().conversationMode;
|
const currentMode = useConversationStore.getState().conversationMode;
|
||||||
const currentSocket =
|
const currentSocket =
|
||||||
currentMode === "plan" ? planningAgentSocket : mainSocket;
|
currentMode === "plan" ? planningAgentSocket : mainSocket;
|
||||||
|
|
||||||
if (!currentSocket || currentSocket.readyState !== WebSocket.OPEN) {
|
if (!currentSocket || currentSocket.readyState !== WebSocket.OPEN) {
|
||||||
const error = "WebSocket is not connected";
|
// WebSocket not connected - queue message via REST API
|
||||||
setErrorMessage(error);
|
// Message will be delivered automatically when conversation becomes ready
|
||||||
throw new Error(error);
|
if (!conversationId) {
|
||||||
|
const error = new Error("No conversation ID available");
|
||||||
|
setErrorMessage(error.message);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await PendingMessageService.queueMessage(conversationId, {
|
||||||
|
role: "user",
|
||||||
|
content: message.content,
|
||||||
|
});
|
||||||
|
// Message queued successfully - it will be delivered when ready
|
||||||
|
// Return queued: true so caller knows not to show optimistic UI
|
||||||
|
return { queued: true };
|
||||||
|
} catch (error) {
|
||||||
|
const errorMessage =
|
||||||
|
error instanceof Error
|
||||||
|
? error.message
|
||||||
|
: "Failed to queue message for delivery";
|
||||||
|
setErrorMessage(errorMessage);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Send message through WebSocket as JSON
|
// Send message through WebSocket as JSON
|
||||||
currentSocket.send(JSON.stringify(message));
|
currentSocket.send(JSON.stringify(message));
|
||||||
|
return { queued: false };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage =
|
const errorMessage =
|
||||||
error instanceof Error ? error.message : "Failed to send message";
|
error instanceof Error ? error.message : "Failed to send message";
|
||||||
@@ -832,7 +870,7 @@ export function ConversationWebSocketProvider({
|
|||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[mainSocket, planningAgentSocket, setErrorMessage],
|
[mainSocket, planningAgentSocket, setErrorMessage, conversationId],
|
||||||
);
|
);
|
||||||
|
|
||||||
// Track main socket state changes
|
// Track main socket state changes
|
||||||
|
|||||||
@@ -5,12 +5,15 @@ import {
|
|||||||
getTextContent,
|
getTextContent,
|
||||||
} from "#/components/features/chat/utils/chat-input.utils";
|
} from "#/components/features/chat/utils/chat-input.utils";
|
||||||
import { useConversationStore } from "#/stores/conversation-store";
|
import { useConversationStore } from "#/stores/conversation-store";
|
||||||
|
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||||
|
import { useDraftPersistence } from "./use-draft-persistence";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hook for managing chat input content logic
|
* Hook for managing chat input content logic
|
||||||
*/
|
*/
|
||||||
export const useChatInputLogic = () => {
|
export const useChatInputLogic = () => {
|
||||||
const chatInputRef = useRef<HTMLDivElement | null>(null);
|
const chatInputRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
const { conversationId } = useConversationId();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
messageToSend,
|
messageToSend,
|
||||||
@@ -19,6 +22,12 @@ export const useChatInputLogic = () => {
|
|||||||
setIsRightPanelShown,
|
setIsRightPanelShown,
|
||||||
} = useConversationStore();
|
} = useConversationStore();
|
||||||
|
|
||||||
|
// Draft persistence - saves to localStorage, restores on mount
|
||||||
|
const { saveDraft, clearDraft } = useDraftPersistence(
|
||||||
|
conversationId,
|
||||||
|
chatInputRef,
|
||||||
|
);
|
||||||
|
|
||||||
// Save current input value when drawer state changes
|
// Save current input value when drawer state changes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (chatInputRef.current) {
|
if (chatInputRef.current) {
|
||||||
@@ -51,5 +60,7 @@ export const useChatInputLogic = () => {
|
|||||||
checkIsContentEmpty,
|
checkIsContentEmpty,
|
||||||
clearEmptyContentHandler,
|
clearEmptyContentHandler,
|
||||||
getCurrentMessage,
|
getCurrentMessage,
|
||||||
|
saveDraft,
|
||||||
|
clearDraft,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
179
frontend/src/hooks/chat/use-draft-persistence.ts
Normal file
179
frontend/src/hooks/chat/use-draft-persistence.ts
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import { useEffect, useRef, useCallback, useState } from "react";
|
||||||
|
import {
|
||||||
|
useConversationLocalStorageState,
|
||||||
|
getConversationState,
|
||||||
|
setConversationState,
|
||||||
|
} from "#/utils/conversation-local-storage";
|
||||||
|
import { getTextContent } from "#/components/features/chat/utils/chat-input.utils";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a conversation ID is a temporary task ID.
|
||||||
|
* Task IDs have the format "task-{uuid}" and are used during V1 conversation initialization.
|
||||||
|
*/
|
||||||
|
const isTaskId = (id: string): boolean => id.startsWith("task-");
|
||||||
|
|
||||||
|
const DRAFT_SAVE_DEBOUNCE_MS = 500;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook for persisting draft messages to localStorage.
|
||||||
|
* Handles debounced saving on input, restoration on mount, and clearing on confirmed delivery.
|
||||||
|
*/
|
||||||
|
export const useDraftPersistence = (
|
||||||
|
conversationId: string,
|
||||||
|
chatInputRef: React.RefObject<HTMLDivElement | null>,
|
||||||
|
) => {
|
||||||
|
const { state, setDraftMessage } =
|
||||||
|
useConversationLocalStorageState(conversationId);
|
||||||
|
const saveTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||||
|
const hasRestoredRef = useRef(false);
|
||||||
|
const [isRestored, setIsRestored] = useState(false);
|
||||||
|
|
||||||
|
// Track current conversationId to prevent saving draft to wrong conversation
|
||||||
|
const currentConversationIdRef = useRef(conversationId);
|
||||||
|
// Track if this is the first mount to handle initial cleanup
|
||||||
|
const isFirstMountRef = useRef(true);
|
||||||
|
|
||||||
|
// IMPORTANT: This effect must run FIRST when conversation changes.
|
||||||
|
// It handles three concerns:
|
||||||
|
// 1. Cleanup: Cancel pending saves from previous conversation
|
||||||
|
// 2. Task-to-real transition: Preserve draft typed during initialization
|
||||||
|
// 3. DOM reset: Clear stale content before restoration effect runs
|
||||||
|
useEffect(() => {
|
||||||
|
const previousConversationId = currentConversationIdRef.current;
|
||||||
|
const isInitialMount = isFirstMountRef.current;
|
||||||
|
currentConversationIdRef.current = conversationId;
|
||||||
|
isFirstMountRef.current = false;
|
||||||
|
|
||||||
|
// --- 1. Cancel pending saves from previous conversation ---
|
||||||
|
// Prevents draft from being saved to wrong conversation if user switched quickly
|
||||||
|
if (saveTimeoutRef.current) {
|
||||||
|
clearTimeout(saveTimeoutRef.current);
|
||||||
|
saveTimeoutRef.current = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const element = chatInputRef.current;
|
||||||
|
|
||||||
|
// --- 2. Handle task-to-real ID transition (preserve draft during initialization) ---
|
||||||
|
// When a new V1 conversation initializes, it starts with a temporary "task-xxx" ID
|
||||||
|
// that transitions to a real conversation ID once ready. Task IDs don't persist
|
||||||
|
// to localStorage, so any draft typed during this phase would be lost.
|
||||||
|
// We detect this transition and transfer the draft to the new real ID.
|
||||||
|
if (!isInitialMount && previousConversationId !== conversationId) {
|
||||||
|
const wasTaskId = isTaskId(previousConversationId);
|
||||||
|
const isNowRealId = !isTaskId(conversationId);
|
||||||
|
|
||||||
|
if (wasTaskId && isNowRealId && element) {
|
||||||
|
const currentText = getTextContent(element).trim();
|
||||||
|
if (currentText) {
|
||||||
|
// Transfer draft to the new (real) conversation ID
|
||||||
|
setConversationState(conversationId, { draftMessage: currentText });
|
||||||
|
// Keep draft visible in DOM and mark as restored to prevent overwrite
|
||||||
|
hasRestoredRef.current = true;
|
||||||
|
setIsRestored(true);
|
||||||
|
return; // Skip normal cleanup - draft is already in correct state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 3. Clear stale DOM content (will be restored by next effect if draft exists) ---
|
||||||
|
// This prevents stale drafts from appearing in new conversations due to:
|
||||||
|
// - Browser form restoration on back/forward navigation
|
||||||
|
// - React DOM recycling between conversation switches
|
||||||
|
// The restoration effect will then populate with the correct saved draft
|
||||||
|
if (element) {
|
||||||
|
element.textContent = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset restoration flag so the restoration effect will run for new conversation
|
||||||
|
hasRestoredRef.current = false;
|
||||||
|
setIsRestored(false);
|
||||||
|
}, [conversationId, chatInputRef]);
|
||||||
|
|
||||||
|
// Restore draft from localStorage - reads directly to avoid state sync timing issues
|
||||||
|
useEffect(() => {
|
||||||
|
if (hasRestoredRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const element = chatInputRef.current;
|
||||||
|
if (!element) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read directly from localStorage to avoid stale state from useConversationLocalStorageState
|
||||||
|
// The hook's state may not have synced yet after conversationId change
|
||||||
|
const { draftMessage } = getConversationState(conversationId);
|
||||||
|
|
||||||
|
// Only restore if there's a saved draft and the input is empty
|
||||||
|
if (draftMessage && getTextContent(element).trim() === "") {
|
||||||
|
element.textContent = draftMessage;
|
||||||
|
// Move cursor to end
|
||||||
|
const selection = window.getSelection();
|
||||||
|
const range = document.createRange();
|
||||||
|
range.selectNodeContents(element);
|
||||||
|
range.collapse(false);
|
||||||
|
selection?.removeAllRanges();
|
||||||
|
selection?.addRange(range);
|
||||||
|
}
|
||||||
|
|
||||||
|
hasRestoredRef.current = true;
|
||||||
|
setIsRestored(true);
|
||||||
|
}, [chatInputRef, conversationId]);
|
||||||
|
|
||||||
|
// Debounced save function - called from onInput handler
|
||||||
|
const saveDraft = useCallback(() => {
|
||||||
|
// Clear any pending save
|
||||||
|
if (saveTimeoutRef.current) {
|
||||||
|
clearTimeout(saveTimeoutRef.current);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture the conversationId at the time of input
|
||||||
|
const capturedConversationId = conversationId;
|
||||||
|
|
||||||
|
saveTimeoutRef.current = setTimeout(() => {
|
||||||
|
// Verify we're still on the same conversation before saving
|
||||||
|
// This prevents saving draft to wrong conversation if user switched quickly
|
||||||
|
if (capturedConversationId !== currentConversationIdRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const element = chatInputRef.current;
|
||||||
|
if (!element) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const text = getTextContent(element).trim();
|
||||||
|
// Only save if content has changed
|
||||||
|
if (text !== (state.draftMessage || "")) {
|
||||||
|
setDraftMessage(text || null);
|
||||||
|
}
|
||||||
|
}, DRAFT_SAVE_DEBOUNCE_MS);
|
||||||
|
}, [chatInputRef, state.draftMessage, setDraftMessage, conversationId]);
|
||||||
|
|
||||||
|
// Clear draft - called after message delivery is confirmed
|
||||||
|
const clearDraft = useCallback(() => {
|
||||||
|
// Cancel any pending save
|
||||||
|
if (saveTimeoutRef.current) {
|
||||||
|
clearTimeout(saveTimeoutRef.current);
|
||||||
|
saveTimeoutRef.current = null;
|
||||||
|
}
|
||||||
|
setDraftMessage(null);
|
||||||
|
}, [setDraftMessage]);
|
||||||
|
|
||||||
|
// Cleanup timeout on unmount
|
||||||
|
useEffect(
|
||||||
|
() => () => {
|
||||||
|
if (saveTimeoutRef.current) {
|
||||||
|
clearTimeout(saveTimeoutRef.current);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
saveDraft,
|
||||||
|
clearDraft,
|
||||||
|
isRestored,
|
||||||
|
hasDraft: !!state.draftMessage,
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -5,6 +5,10 @@ import { useConversationWebSocket } from "#/contexts/conversation-websocket-cont
|
|||||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||||
import { V1MessageContent } from "#/api/conversation-service/v1-conversation-service.types";
|
import { V1MessageContent } from "#/api/conversation-service/v1-conversation-service.types";
|
||||||
|
|
||||||
|
interface SendResult {
|
||||||
|
queued: boolean; // true if message was queued for later delivery
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unified hook for sending messages that works with both V0 and V1 conversations
|
* Unified hook for sending messages that works with both V0 and V1 conversations
|
||||||
* - For V0 conversations: Uses Socket.IO WebSocket via useWsClient
|
* - For V0 conversations: Uses Socket.IO WebSocket via useWsClient
|
||||||
@@ -26,7 +30,7 @@ export function useSendMessage() {
|
|||||||
conversation?.conversation_version === "V1";
|
conversation?.conversation_version === "V1";
|
||||||
|
|
||||||
const send = useCallback(
|
const send = useCallback(
|
||||||
async (event: Record<string, unknown>) => {
|
async (event: Record<string, unknown>): Promise<SendResult> => {
|
||||||
if (isV1Conversation && v1Context) {
|
if (isV1Conversation && v1Context) {
|
||||||
// V1: Convert V0 event format to V1 message format
|
// V1: Convert V0 event format to V1 message format
|
||||||
const { action, args } = event as {
|
const { action, args } = event as {
|
||||||
@@ -57,19 +61,20 @@ export function useSendMessage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send via V1 WebSocket context (uses correct host/port)
|
// Send via V1 WebSocket context (uses correct host/port)
|
||||||
await v1Context.sendMessage({
|
const result = await v1Context.sendMessage({
|
||||||
role: "user",
|
role: "user",
|
||||||
content,
|
content,
|
||||||
});
|
});
|
||||||
} else {
|
return result;
|
||||||
// For non-message events, fall back to V0 send
|
|
||||||
// (e.g., agent state changes, other control events)
|
|
||||||
v0Send(event);
|
|
||||||
}
|
}
|
||||||
} else {
|
// For non-message events, fall back to V0 send
|
||||||
// V0: Use Socket.IO
|
// (e.g., agent state changes, other control events)
|
||||||
v0Send(event);
|
v0Send(event);
|
||||||
|
return { queued: false };
|
||||||
}
|
}
|
||||||
|
// V0: Use Socket.IO
|
||||||
|
v0Send(event);
|
||||||
|
return { queued: false };
|
||||||
},
|
},
|
||||||
[isV1Conversation, v1Context, v0Send, conversationId],
|
[isV1Conversation, v1Context, v0Send, conversationId],
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ export interface ConversationState {
|
|||||||
unpinnedTabs: string[];
|
unpinnedTabs: string[];
|
||||||
conversationMode: ConversationMode;
|
conversationMode: ConversationMode;
|
||||||
subConversationTaskId: string | null;
|
subConversationTaskId: string | null;
|
||||||
|
draftMessage: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_CONVERSATION_STATE: ConversationState = {
|
const DEFAULT_CONVERSATION_STATE: ConversationState = {
|
||||||
@@ -31,6 +32,7 @@ const DEFAULT_CONVERSATION_STATE: ConversationState = {
|
|||||||
unpinnedTabs: [],
|
unpinnedTabs: [],
|
||||||
conversationMode: "code",
|
conversationMode: "code",
|
||||||
subConversationTaskId: null,
|
subConversationTaskId: null,
|
||||||
|
draftMessage: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -121,6 +123,7 @@ export function useConversationLocalStorageState(conversationId: string): {
|
|||||||
setRightPanelShown: (shown: boolean) => void;
|
setRightPanelShown: (shown: boolean) => void;
|
||||||
setUnpinnedTabs: (tabs: string[]) => void;
|
setUnpinnedTabs: (tabs: string[]) => void;
|
||||||
setConversationMode: (mode: ConversationMode) => void;
|
setConversationMode: (mode: ConversationMode) => void;
|
||||||
|
setDraftMessage: (message: string | null) => void;
|
||||||
} {
|
} {
|
||||||
const [state, setState] = useState<ConversationState>(() =>
|
const [state, setState] = useState<ConversationState>(() =>
|
||||||
getConversationState(conversationId),
|
getConversationState(conversationId),
|
||||||
@@ -178,5 +181,6 @@ export function useConversationLocalStorageState(conversationId: string): {
|
|||||||
setRightPanelShown: (shown) => updateState({ rightPanelShown: shown }),
|
setRightPanelShown: (shown) => updateState({ rightPanelShown: shown }),
|
||||||
setUnpinnedTabs: (tabs) => updateState({ unpinnedTabs: tabs }),
|
setUnpinnedTabs: (tabs) => updateState({ unpinnedTabs: tabs }),
|
||||||
setConversationMode: (mode) => updateState({ conversationMode: mode }),
|
setConversationMode: (mode) => updateState({ conversationMode: mode }),
|
||||||
|
setDraftMessage: (message) => updateState({ draftMessage: message }),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import (
|
|||||||
from openhands.app_server.event_callback.set_title_callback_processor import (
|
from openhands.app_server.event_callback.set_title_callback_processor import (
|
||||||
SetTitleCallbackProcessor,
|
SetTitleCallbackProcessor,
|
||||||
)
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
PendingMessageService,
|
||||||
|
)
|
||||||
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
|
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
|
||||||
from openhands.app_server.sandbox.sandbox_models import (
|
from openhands.app_server.sandbox.sandbox_models import (
|
||||||
AGENT_SERVER,
|
AGENT_SERVER,
|
||||||
@@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
sandbox_service: SandboxService
|
sandbox_service: SandboxService
|
||||||
sandbox_spec_service: SandboxSpecService
|
sandbox_spec_service: SandboxSpecService
|
||||||
jwt_service: JwtService
|
jwt_service: JwtService
|
||||||
|
pending_message_service: PendingMessageService
|
||||||
sandbox_startup_timeout: int
|
sandbox_startup_timeout: int
|
||||||
sandbox_startup_poll_frequency: int
|
sandbox_startup_poll_frequency: int
|
||||||
max_num_conversations_per_sandbox: int
|
max_num_conversations_per_sandbox: int
|
||||||
@@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
task.app_conversation_id = info.id
|
task.app_conversation_id = info.id
|
||||||
yield task
|
yield task
|
||||||
|
|
||||||
|
# Process any pending messages queued while waiting for conversation
|
||||||
|
if sandbox.session_api_key:
|
||||||
|
await self._process_pending_messages(
|
||||||
|
task_id=task.id,
|
||||||
|
conversation_id=info.id,
|
||||||
|
agent_server_url=agent_server_url,
|
||||||
|
session_api_key=sandbox.session_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_logger.exception('Error starting conversation', stack_info=True)
|
_logger.exception('Error starting conversation', stack_info=True)
|
||||||
task.status = AppConversationStartTaskStatus.ERROR
|
task.status = AppConversationStartTaskStatus.ERROR
|
||||||
@@ -1424,6 +1437,89 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
plugins=plugins,
|
plugins=plugins,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _process_pending_messages(
|
||||||
|
self,
|
||||||
|
task_id: UUID,
|
||||||
|
conversation_id: UUID,
|
||||||
|
agent_server_url: str,
|
||||||
|
session_api_key: str,
|
||||||
|
) -> None:
|
||||||
|
"""Process pending messages queued before conversation was ready.
|
||||||
|
|
||||||
|
Messages are delivered concurrently to the agent server. After processing,
|
||||||
|
all messages are deleted from the database regardless of success or failure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The start task ID (may have been used as conversation_id initially)
|
||||||
|
conversation_id: The real conversation ID
|
||||||
|
agent_server_url: URL of the agent server
|
||||||
|
session_api_key: API key for authenticating with agent server
|
||||||
|
"""
|
||||||
|
# Convert UUIDs to strings for the pending message service
|
||||||
|
# The frontend uses task-{uuid.hex} format (no hyphens), matching OpenHandsUUID serialization
|
||||||
|
task_id_str = f'task-{task_id.hex}'
|
||||||
|
# conversation_id uses standard format (with hyphens) for agent server API compatibility
|
||||||
|
conversation_id_str = str(conversation_id)
|
||||||
|
|
||||||
|
_logger.info(f'task_id={task_id_str} conversation_id={conversation_id_str}')
|
||||||
|
|
||||||
|
# First, update any messages that were queued with the task_id
|
||||||
|
updated_count = await self.pending_message_service.update_conversation_id(
|
||||||
|
old_conversation_id=task_id_str,
|
||||||
|
new_conversation_id=conversation_id_str,
|
||||||
|
)
|
||||||
|
_logger.info(f'updated_count={updated_count} ')
|
||||||
|
if updated_count > 0:
|
||||||
|
_logger.info(
|
||||||
|
f'Updated {updated_count} pending messages from task_id={task_id_str} '
|
||||||
|
f'to conversation_id={conversation_id_str}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all pending messages for this conversation
|
||||||
|
pending_messages = await self.pending_message_service.get_pending_messages(
|
||||||
|
conversation_id_str
|
||||||
|
)
|
||||||
|
|
||||||
|
if not pending_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
_logger.info(
|
||||||
|
f'Processing {len(pending_messages)} pending messages for '
|
||||||
|
f'conversation {conversation_id_str}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process messages sequentially to preserve order
|
||||||
|
for msg in pending_messages:
|
||||||
|
try:
|
||||||
|
# Serialize content objects to JSON-compatible dicts
|
||||||
|
content_json = [item.model_dump() for item in msg.content]
|
||||||
|
# Use the events endpoint which handles message sending
|
||||||
|
response = await self.httpx_client.post(
|
||||||
|
f'{agent_server_url}/api/conversations/{conversation_id_str}/events',
|
||||||
|
json={
|
||||||
|
'role': msg.role,
|
||||||
|
'content': content_json,
|
||||||
|
'run': True,
|
||||||
|
},
|
||||||
|
headers={'X-Session-API-Key': session_api_key},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
_logger.debug(f'Delivered pending message {msg.id}')
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f'Failed to deliver pending message {msg.id}: {e}')
|
||||||
|
|
||||||
|
# Delete all pending messages after processing (regardless of success/failure)
|
||||||
|
deleted_count = (
|
||||||
|
await self.pending_message_service.delete_messages_for_conversation(
|
||||||
|
conversation_id_str
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_logger.info(
|
||||||
|
f'Finished processing pending messages for conversation {conversation_id_str}. '
|
||||||
|
f'Deleted {deleted_count} messages.'
|
||||||
|
)
|
||||||
|
|
||||||
async def update_agent_server_conversation_title(
|
async def update_agent_server_conversation_title(
|
||||||
self,
|
self,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
@@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
|||||||
get_global_config,
|
get_global_config,
|
||||||
get_httpx_client,
|
get_httpx_client,
|
||||||
get_jwt_service,
|
get_jwt_service,
|
||||||
|
get_pending_message_service,
|
||||||
get_sandbox_service,
|
get_sandbox_service,
|
||||||
get_sandbox_spec_service,
|
get_sandbox_spec_service,
|
||||||
get_user_context,
|
get_user_context,
|
||||||
@@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
|||||||
get_event_service(state, request) as event_service,
|
get_event_service(state, request) as event_service,
|
||||||
get_jwt_service(state, request) as jwt_service,
|
get_jwt_service(state, request) as jwt_service,
|
||||||
get_httpx_client(state, request) as httpx_client,
|
get_httpx_client(state, request) as httpx_client,
|
||||||
|
get_pending_message_service(state, request) as pending_message_service,
|
||||||
):
|
):
|
||||||
access_token_hard_timeout = None
|
access_token_hard_timeout = None
|
||||||
if self.access_token_hard_timeout:
|
if self.access_token_hard_timeout:
|
||||||
@@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
|||||||
event_callback_service=event_callback_service,
|
event_callback_service=event_callback_service,
|
||||||
event_service=event_service,
|
event_service=event_service,
|
||||||
jwt_service=jwt_service,
|
jwt_service=jwt_service,
|
||||||
|
pending_message_service=pending_message_service,
|
||||||
sandbox_startup_timeout=self.sandbox_startup_timeout,
|
sandbox_startup_timeout=self.sandbox_startup_timeout,
|
||||||
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
|
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
|
||||||
max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox,
|
max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox,
|
||||||
|
|||||||
39
openhands/app_server/app_lifespan/alembic/versions/007.py
Normal file
39
openhands/app_server/app_lifespan/alembic/versions/007.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Add pending_messages table for server-side message queuing
|
||||||
|
|
||||||
|
Revision ID: 007
|
||||||
|
Revises: 006
|
||||||
|
Create Date: 2025-03-15 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '007'
|
||||||
|
down_revision: Union[str, None] = '006'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create pending_messages table for storing messages before conversation is ready.
|
||||||
|
|
||||||
|
Messages are stored temporarily until the conversation becomes ready, then
|
||||||
|
delivered and deleted regardless of success or failure.
|
||||||
|
"""
|
||||||
|
op.create_table(
|
||||||
|
'pending_messages',
|
||||||
|
sa.Column('id', sa.String(), primary_key=True),
|
||||||
|
sa.Column('conversation_id', sa.String(), nullable=False, index=True),
|
||||||
|
sa.Column('role', sa.String(20), nullable=False, server_default='user'),
|
||||||
|
sa.Column('content', sa.JSON, nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove pending_messages table."""
|
||||||
|
op.drop_table('pending_messages')
|
||||||
@@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import (
|
|||||||
EventCallbackService,
|
EventCallbackService,
|
||||||
EventCallbackServiceInjector,
|
EventCallbackServiceInjector,
|
||||||
)
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
PendingMessageService,
|
||||||
|
PendingMessageServiceInjector,
|
||||||
|
)
|
||||||
from openhands.app_server.sandbox.sandbox_service import (
|
from openhands.app_server.sandbox.sandbox_service import (
|
||||||
SandboxService,
|
SandboxService,
|
||||||
SandboxServiceInjector,
|
SandboxServiceInjector,
|
||||||
@@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel):
|
|||||||
app_conversation_info: AppConversationInfoServiceInjector | None = None
|
app_conversation_info: AppConversationInfoServiceInjector | None = None
|
||||||
app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None
|
app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None
|
||||||
app_conversation: AppConversationServiceInjector | None = None
|
app_conversation: AppConversationServiceInjector | None = None
|
||||||
|
pending_message: PendingMessageServiceInjector | None = None
|
||||||
user: UserContextInjector | None = None
|
user: UserContextInjector | None = None
|
||||||
jwt: JwtServiceInjector | None = None
|
jwt: JwtServiceInjector | None = None
|
||||||
httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector)
|
httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector)
|
||||||
@@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig:
|
|||||||
tavily_api_key=tavily_api_key
|
tavily_api_key=tavily_api_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.pending_message is None:
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
SQLPendingMessageServiceInjector,
|
||||||
|
)
|
||||||
|
|
||||||
|
config.pending_message = SQLPendingMessageServiceInjector()
|
||||||
|
|
||||||
if config.user is None:
|
if config.user is None:
|
||||||
config.user = AuthUserContextInjector()
|
config.user = AuthUserContextInjector()
|
||||||
|
|
||||||
@@ -358,6 +370,14 @@ def get_app_conversation_service(
|
|||||||
return injector.context(state, request)
|
return injector.context(state, request)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pending_message_service(
|
||||||
|
state: InjectorState, request: Request | None = None
|
||||||
|
) -> AsyncContextManager[PendingMessageService]:
|
||||||
|
injector = get_global_config().pending_message
|
||||||
|
assert injector is not None
|
||||||
|
return injector.context(state, request)
|
||||||
|
|
||||||
|
|
||||||
def get_user_context(
|
def get_user_context(
|
||||||
state: InjectorState, request: Request | None = None
|
state: InjectorState, request: Request | None = None
|
||||||
) -> AsyncContextManager[UserContext]:
|
) -> AsyncContextManager[UserContext]:
|
||||||
@@ -433,6 +453,12 @@ def depends_app_conversation_service():
|
|||||||
return Depends(injector.depends)
|
return Depends(injector.depends)
|
||||||
|
|
||||||
|
|
||||||
|
def depends_pending_message_service():
|
||||||
|
injector = get_global_config().pending_message
|
||||||
|
assert injector is not None
|
||||||
|
return Depends(injector.depends)
|
||||||
|
|
||||||
|
|
||||||
def depends_user_context():
|
def depends_user_context():
|
||||||
injector = get_global_config().user
|
injector = get_global_config().user
|
||||||
assert injector is not None
|
assert injector is not None
|
||||||
|
|||||||
21
openhands/app_server/pending_messages/__init__.py
Normal file
21
openhands/app_server/pending_messages/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Pending messages module for server-side message queuing."""
|
||||||
|
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessage,
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
PendingMessageService,
|
||||||
|
PendingMessageServiceInjector,
|
||||||
|
SQLPendingMessageService,
|
||||||
|
SQLPendingMessageServiceInjector,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'PendingMessage',
|
||||||
|
'PendingMessageResponse',
|
||||||
|
'PendingMessageService',
|
||||||
|
'PendingMessageServiceInjector',
|
||||||
|
'SQLPendingMessageService',
|
||||||
|
'SQLPendingMessageServiceInjector',
|
||||||
|
]
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""Models for pending message queue functionality."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from openhands.agent_server.models import ImageContent, TextContent
|
||||||
|
from openhands.agent_server.utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMessage(BaseModel):
|
||||||
|
"""A message queued for delivery when conversation becomes ready.
|
||||||
|
|
||||||
|
Pending messages are stored in the database and delivered to the agent_server
|
||||||
|
when the conversation transitions to READY status. Messages are deleted after
|
||||||
|
processing, regardless of success or failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
conversation_id: str # Can be task-{uuid} or real conversation UUID
|
||||||
|
role: str = 'user'
|
||||||
|
content: list[TextContent | ImageContent]
|
||||||
|
created_at: datetime = Field(default_factory=utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMessageResponse(BaseModel):
|
||||||
|
"""Response when queueing a pending message."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
queued: bool
|
||||||
|
position: int = Field(description='Position in the queue (1-based)')
|
||||||
104
openhands/app_server/pending_messages/pending_message_router.py
Normal file
104
openhands/app_server/pending_messages/pending_message_router.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""REST API router for pending messages."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, status
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
|
from openhands.agent_server.models import ImageContent, TextContent
|
||||||
|
from openhands.app_server.config import depends_pending_message_service
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
PendingMessageService,
|
||||||
|
)
|
||||||
|
from openhands.server.dependencies import get_dependencies
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type adapter for validating content from request
|
||||||
|
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
|
||||||
|
|
||||||
|
# Create router with authentication dependencies
|
||||||
|
router = APIRouter(
|
||||||
|
prefix='/conversations/{conversation_id}/pending-messages',
|
||||||
|
tags=['Pending Messages'],
|
||||||
|
dependencies=get_dependencies(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dependency at module level
|
||||||
|
pending_message_service_dependency = depends_pending_message_service()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
'', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED
|
||||||
|
)
|
||||||
|
async def queue_pending_message(
|
||||||
|
conversation_id: str,
|
||||||
|
request: Request,
|
||||||
|
pending_service: PendingMessageService = pending_message_service_dependency,
|
||||||
|
) -> PendingMessageResponse:
|
||||||
|
"""Queue a message for delivery when conversation becomes ready.
|
||||||
|
|
||||||
|
This endpoint allows users to submit messages even when the conversation's
|
||||||
|
WebSocket connection is not yet established. Messages are stored server-side
|
||||||
|
and delivered automatically when the conversation transitions to READY status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation ID (can be task ID before conversation is ready)
|
||||||
|
request: The FastAPI request containing message content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PendingMessageResponse with the message ID and queue position
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 400: If the request body is invalid
|
||||||
|
HTTPException 429: If too many pending messages are queued (limit: 10)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail='Invalid request body',
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_content = body.get('content')
|
||||||
|
role = body.get('role', 'user')
|
||||||
|
|
||||||
|
if not raw_content or not isinstance(raw_content, list):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail='content must be a non-empty list',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate and parse content into typed objects
|
||||||
|
try:
|
||||||
|
content = _content_type_adapter.validate_python(raw_content)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f'Invalid content format: {e}',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rate limit: max 10 pending messages per conversation
|
||||||
|
pending_count = await pending_service.count_pending_messages(conversation_id)
|
||||||
|
if pending_count >= 10:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail='Too many pending messages. Maximum 10 messages per conversation.',
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await pending_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
content=content,
|
||||||
|
role=role,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f'Queued pending message {response.id} for conversation {conversation_id} '
|
||||||
|
f'(position: {response.position})'
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
200
openhands/app_server/pending_messages/pending_message_service.py
Normal file
200
openhands/app_server/pending_messages/pending_message_service.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""Service for managing pending messages in SQL database."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
from sqlalchemy import JSON, Column, String, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from openhands.agent_server.models import ImageContent, TextContent
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessage,
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.services.injector import Injector, InjectorState
|
||||||
|
from openhands.app_server.utils.sql_utils import Base, UtcDateTime
|
||||||
|
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||||
|
|
||||||
|
# Type adapter for deserializing content from JSON
|
||||||
|
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
|
||||||
|
|
||||||
|
|
||||||
|
class StoredPendingMessage(Base): # type: ignore
|
||||||
|
"""SQLAlchemy model for pending messages."""
|
||||||
|
|
||||||
|
__tablename__ = 'pending_messages'
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
conversation_id = Column(String, nullable=False, index=True)
|
||||||
|
role = Column(String(20), nullable=False, default='user')
|
||||||
|
content = Column(JSON, nullable=False)
|
||||||
|
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMessageService(ABC):
|
||||||
|
"""Abstract service for managing pending messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def add_message(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
content: list[TextContent | ImageContent],
|
||||||
|
role: str = 'user',
|
||||||
|
) -> PendingMessageResponse:
|
||||||
|
"""Queue a message for delivery when conversation becomes ready."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
|
||||||
|
"""Get all pending messages for a conversation, ordered by created_at."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def count_pending_messages(self, conversation_id: str) -> int:
|
||||||
|
"""Count pending messages for a conversation."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_messages_for_conversation(self, conversation_id: str) -> int:
|
||||||
|
"""Delete all pending messages for a conversation, returning count deleted."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_conversation_id(
|
||||||
|
self, old_conversation_id: str, new_conversation_id: str
|
||||||
|
) -> int:
|
||||||
|
"""Update conversation_id when task-id transitions to real conversation-id.
|
||||||
|
|
||||||
|
Returns the number of messages updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SQLPendingMessageService(PendingMessageService):
|
||||||
|
"""SQL implementation of PendingMessageService."""
|
||||||
|
|
||||||
|
db_session: AsyncSession
|
||||||
|
|
||||||
|
async def add_message(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
content: list[TextContent | ImageContent],
|
||||||
|
role: str = 'user',
|
||||||
|
) -> PendingMessageResponse:
|
||||||
|
"""Queue a message for delivery when conversation becomes ready."""
|
||||||
|
# Create the pending message
|
||||||
|
pending_message = PendingMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count existing pending messages for position
|
||||||
|
count_stmt = select(func.count()).where(
|
||||||
|
StoredPendingMessage.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(count_stmt)
|
||||||
|
position = result.scalar() or 0
|
||||||
|
|
||||||
|
# Serialize content to JSON-compatible format for storage
|
||||||
|
content_json = [item.model_dump() for item in content]
|
||||||
|
|
||||||
|
# Store in database
|
||||||
|
stored_message = StoredPendingMessage(
|
||||||
|
id=str(pending_message.id),
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role=role,
|
||||||
|
content=content_json,
|
||||||
|
created_at=pending_message.created_at,
|
||||||
|
)
|
||||||
|
self.db_session.add(stored_message)
|
||||||
|
await self.db_session.commit()
|
||||||
|
|
||||||
|
return PendingMessageResponse(
|
||||||
|
id=pending_message.id,
|
||||||
|
queued=True,
|
||||||
|
position=position + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
|
||||||
|
"""Get all pending messages for a conversation, ordered by created_at."""
|
||||||
|
stmt = (
|
||||||
|
select(StoredPendingMessage)
|
||||||
|
.where(StoredPendingMessage.conversation_id == conversation_id)
|
||||||
|
.order_by(StoredPendingMessage.created_at.asc())
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(stmt)
|
||||||
|
stored_messages = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
PendingMessage(
|
||||||
|
id=msg.id,
|
||||||
|
conversation_id=msg.conversation_id,
|
||||||
|
role=msg.role,
|
||||||
|
content=_content_type_adapter.validate_python(msg.content),
|
||||||
|
created_at=msg.created_at,
|
||||||
|
)
|
||||||
|
for msg in stored_messages
|
||||||
|
]
|
||||||
|
|
||||||
|
async def count_pending_messages(self, conversation_id: str) -> int:
|
||||||
|
"""Count pending messages for a conversation."""
|
||||||
|
count_stmt = select(func.count()).where(
|
||||||
|
StoredPendingMessage.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(count_stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def delete_messages_for_conversation(self, conversation_id: str) -> int:
|
||||||
|
"""Delete all pending messages for a conversation, returning count deleted."""
|
||||||
|
stmt = select(StoredPendingMessage).where(
|
||||||
|
StoredPendingMessage.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(stmt)
|
||||||
|
stored_messages = result.scalars().all()
|
||||||
|
|
||||||
|
count = len(stored_messages)
|
||||||
|
for msg in stored_messages:
|
||||||
|
await self.db_session.delete(msg)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
await self.db_session.commit()
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def update_conversation_id(
|
||||||
|
self, old_conversation_id: str, new_conversation_id: str
|
||||||
|
) -> int:
|
||||||
|
"""Update conversation_id when task-id transitions to real conversation-id."""
|
||||||
|
stmt = select(StoredPendingMessage).where(
|
||||||
|
StoredPendingMessage.conversation_id == old_conversation_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(stmt)
|
||||||
|
stored_messages = result.scalars().all()
|
||||||
|
|
||||||
|
count = len(stored_messages)
|
||||||
|
for msg in stored_messages:
|
||||||
|
msg.conversation_id = new_conversation_id
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
await self.db_session.commit()
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMessageServiceInjector(
|
||||||
|
DiscriminatedUnionMixin, Injector[PendingMessageService], ABC
|
||||||
|
):
|
||||||
|
"""Abstract injector for PendingMessageService."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SQLPendingMessageServiceInjector(PendingMessageServiceInjector):
|
||||||
|
"""SQL-based injector for PendingMessageService."""
|
||||||
|
|
||||||
|
async def inject(
|
||||||
|
self, state: InjectorState, request: Request | None = None
|
||||||
|
) -> AsyncGenerator[PendingMessageService, None]:
|
||||||
|
from openhands.app_server.config import get_db_session
|
||||||
|
|
||||||
|
async with get_db_session(state) as db_session:
|
||||||
|
yield SQLPendingMessageService(db_session=db_session)
|
||||||
@@ -5,6 +5,9 @@ from openhands.app_server.event import event_router
|
|||||||
from openhands.app_server.event_callback import (
|
from openhands.app_server.event_callback import (
|
||||||
webhook_router,
|
webhook_router,
|
||||||
)
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_router import (
|
||||||
|
router as pending_message_router,
|
||||||
|
)
|
||||||
from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router
|
from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router
|
||||||
from openhands.app_server.user import user_router
|
from openhands.app_server.user import user_router
|
||||||
from openhands.app_server.web_client import web_client_router
|
from openhands.app_server.web_client import web_client_router
|
||||||
@@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router
|
|||||||
router = APIRouter(prefix='/api/v1')
|
router = APIRouter(prefix='/api/v1')
|
||||||
router.include_router(event_router.router)
|
router.include_router(event_router.router)
|
||||||
router.include_router(app_conversation_router.router)
|
router.include_router(app_conversation_router.router)
|
||||||
|
router.include_router(pending_message_router)
|
||||||
router.include_router(sandbox_router.router)
|
router.include_router(sandbox_router.router)
|
||||||
router.include_router(sandbox_spec_router.router)
|
router.include_router(sandbox_spec_router.router)
|
||||||
router.include_router(user_router.router)
|
router.include_router(user_router.router)
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class TestLiveStatusAppConversationService:
|
|||||||
self.mock_event_callback_service = Mock()
|
self.mock_event_callback_service = Mock()
|
||||||
self.mock_event_service = Mock()
|
self.mock_event_service = Mock()
|
||||||
self.mock_httpx_client = Mock()
|
self.mock_httpx_client = Mock()
|
||||||
|
self.mock_pending_message_service = Mock()
|
||||||
|
|
||||||
# Create service instance
|
# Create service instance
|
||||||
self.service = LiveStatusAppConversationService(
|
self.service = LiveStatusAppConversationService(
|
||||||
@@ -92,6 +93,7 @@ class TestLiveStatusAppConversationService:
|
|||||||
sandbox_service=self.mock_sandbox_service,
|
sandbox_service=self.mock_sandbox_service,
|
||||||
sandbox_spec_service=self.mock_sandbox_spec_service,
|
sandbox_spec_service=self.mock_sandbox_spec_service,
|
||||||
jwt_service=self.mock_jwt_service,
|
jwt_service=self.mock_jwt_service,
|
||||||
|
pending_message_service=self.mock_pending_message_service,
|
||||||
sandbox_startup_timeout=30,
|
sandbox_startup_timeout=30,
|
||||||
sandbox_startup_poll_frequency=1,
|
sandbox_startup_poll_frequency=1,
|
||||||
max_num_conversations_per_sandbox=20,
|
max_num_conversations_per_sandbox=20,
|
||||||
@@ -2329,6 +2331,7 @@ class TestPluginHandling:
|
|||||||
self.mock_event_callback_service = Mock()
|
self.mock_event_callback_service = Mock()
|
||||||
self.mock_event_service = Mock()
|
self.mock_event_service = Mock()
|
||||||
self.mock_httpx_client = Mock()
|
self.mock_httpx_client = Mock()
|
||||||
|
self.mock_pending_message_service = Mock()
|
||||||
|
|
||||||
# Create service instance
|
# Create service instance
|
||||||
self.service = LiveStatusAppConversationService(
|
self.service = LiveStatusAppConversationService(
|
||||||
@@ -2341,6 +2344,7 @@ class TestPluginHandling:
|
|||||||
sandbox_service=self.mock_sandbox_service,
|
sandbox_service=self.mock_sandbox_service,
|
||||||
sandbox_spec_service=self.mock_sandbox_spec_service,
|
sandbox_spec_service=self.mock_sandbox_spec_service,
|
||||||
jwt_service=self.mock_jwt_service,
|
jwt_service=self.mock_jwt_service,
|
||||||
|
pending_message_service=self.mock_pending_message_service,
|
||||||
sandbox_startup_timeout=30,
|
sandbox_startup_timeout=30,
|
||||||
sandbox_startup_poll_frequency=1,
|
sandbox_startup_poll_frequency=1,
|
||||||
max_num_conversations_per_sandbox=20,
|
max_num_conversations_per_sandbox=20,
|
||||||
|
|||||||
227
tests/unit/app_server/test_pending_message_router.py
Normal file
227
tests/unit/app_server/test_pending_message_router.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""Unit tests for the pending_message_router endpoints.
|
||||||
|
|
||||||
|
This module tests the queue_pending_message endpoint,
|
||||||
|
focusing on request validation and rate limiting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
from openhands.agent_server.models import TextContent
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_router import (
|
||||||
|
queue_pending_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_service(
|
||||||
|
add_message_return=None,
|
||||||
|
count_pending_messages_return=0,
|
||||||
|
):
|
||||||
|
"""Create a mock PendingMessageService for testing."""
|
||||||
|
service = MagicMock()
|
||||||
|
service.add_message = AsyncMock(return_value=add_message_return)
|
||||||
|
service.count_pending_messages = AsyncMock(
|
||||||
|
return_value=count_pending_messages_return
|
||||||
|
)
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_request(body: dict):
|
||||||
|
"""Create a mock FastAPI Request with given JSON body."""
|
||||||
|
request = MagicMock()
|
||||||
|
request.json = AsyncMock(return_value=body)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestQueuePendingMessage:
|
||||||
|
"""Test suite for queue_pending_message endpoint."""
|
||||||
|
|
||||||
|
async def test_queues_message_successfully(self):
|
||||||
|
"""Test that a valid message is queued successfully."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
raw_content = [{'type': 'text', 'text': 'Hello, world!'}]
|
||||||
|
expected_response = PendingMessageResponse(
|
||||||
|
id=str(uuid4()),
|
||||||
|
queued=True,
|
||||||
|
position=1,
|
||||||
|
)
|
||||||
|
mock_service = _make_mock_service(
|
||||||
|
add_message_return=expected_response,
|
||||||
|
count_pending_messages_return=0,
|
||||||
|
)
|
||||||
|
mock_request = _make_mock_request({'content': raw_content, 'role': 'user'})
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected_response
|
||||||
|
mock_service.add_message.assert_called_once()
|
||||||
|
call_kwargs = mock_service.add_message.call_args.kwargs
|
||||||
|
assert call_kwargs['conversation_id'] == conversation_id
|
||||||
|
assert call_kwargs['role'] == 'user'
|
||||||
|
# Content should be parsed into typed objects
|
||||||
|
assert len(call_kwargs['content']) == 1
|
||||||
|
assert isinstance(call_kwargs['content'][0], TextContent)
|
||||||
|
assert call_kwargs['content'][0].text == 'Hello, world!'
|
||||||
|
|
||||||
|
async def test_uses_default_role_when_not_provided(self):
|
||||||
|
"""Test that 'user' role is used by default."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
raw_content = [{'type': 'text', 'text': 'Test message'}]
|
||||||
|
expected_response = PendingMessageResponse(
|
||||||
|
id=str(uuid4()),
|
||||||
|
queued=True,
|
||||||
|
position=1,
|
||||||
|
)
|
||||||
|
mock_service = _make_mock_service(
|
||||||
|
add_message_return=expected_response,
|
||||||
|
count_pending_messages_return=0,
|
||||||
|
)
|
||||||
|
mock_request = _make_mock_request({'content': raw_content})
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_service.add_message.assert_called_once()
|
||||||
|
call_kwargs = mock_service.add_message.call_args.kwargs
|
||||||
|
assert call_kwargs['conversation_id'] == conversation_id
|
||||||
|
assert call_kwargs['role'] == 'user'
|
||||||
|
assert isinstance(call_kwargs['content'][0], TextContent)
|
||||||
|
|
||||||
|
async def test_returns_400_for_invalid_json_body(self):
|
||||||
|
"""Test that invalid JSON body returns 400 Bad Request."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
mock_service = _make_mock_service()
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.json = AsyncMock(side_effect=Exception('Invalid JSON'))
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert 'Invalid request body' in exc_info.value.detail
|
||||||
|
|
||||||
|
async def test_returns_400_when_content_is_missing(self):
|
||||||
|
"""Test that missing content returns 400 Bad Request."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
mock_service = _make_mock_service()
|
||||||
|
mock_request = _make_mock_request({'role': 'user'})
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert 'content must be a non-empty list' in exc_info.value.detail
|
||||||
|
|
||||||
|
async def test_returns_400_when_content_is_not_a_list(self):
|
||||||
|
"""Test that non-list content returns 400 Bad Request."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
mock_service = _make_mock_service()
|
||||||
|
mock_request = _make_mock_request({'content': 'not a list'})
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert 'content must be a non-empty list' in exc_info.value.detail
|
||||||
|
|
||||||
|
async def test_returns_400_when_content_is_empty_list(self):
|
||||||
|
"""Test that empty list content returns 400 Bad Request."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
mock_service = _make_mock_service()
|
||||||
|
mock_request = _make_mock_request({'content': []})
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert 'content must be a non-empty list' in exc_info.value.detail
|
||||||
|
|
||||||
|
async def test_returns_429_when_rate_limit_exceeded(self):
|
||||||
|
"""Test that exceeding rate limit returns 429 Too Many Requests."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
raw_content = [{'type': 'text', 'text': 'Test message'}]
|
||||||
|
mock_service = _make_mock_service(count_pending_messages_return=10)
|
||||||
|
mock_request = _make_mock_request({'content': raw_content})
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
assert 'Maximum 10 messages' in exc_info.value.detail
|
||||||
|
|
||||||
|
async def test_allows_up_to_10_messages(self):
|
||||||
|
"""Test that 9 existing messages still allows adding one more."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
raw_content = [{'type': 'text', 'text': 'Test message'}]
|
||||||
|
expected_response = PendingMessageResponse(
|
||||||
|
id=str(uuid4()),
|
||||||
|
queued=True,
|
||||||
|
position=10,
|
||||||
|
)
|
||||||
|
mock_service = _make_mock_service(
|
||||||
|
add_message_return=expected_response,
|
||||||
|
count_pending_messages_return=9,
|
||||||
|
)
|
||||||
|
mock_request = _make_mock_request({'content': raw_content})
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await queue_pending_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
request=mock_request,
|
||||||
|
pending_service=mock_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected_response
|
||||||
|
mock_service.add_message.assert_called_once()
|
||||||
309
tests/unit/app_server/test_pending_message_service.py
Normal file
309
tests/unit/app_server/test_pending_message_service.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
"""Tests for SQLPendingMessageService.
|
||||||
|
|
||||||
|
This module tests the SQL implementation of PendingMessageService,
|
||||||
|
covering message queuing, retrieval, counting, deletion, and
|
||||||
|
conversation_id updates using SQLite as a mock database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from openhands.agent_server.models import TextContent
|
||||||
|
from openhands.app_server.pending_messages.pending_message_models import (
|
||||||
|
PendingMessageResponse,
|
||||||
|
)
|
||||||
|
from openhands.app_server.pending_messages.pending_message_service import (
|
||||||
|
SQLPendingMessageService,
|
||||||
|
)
|
||||||
|
from openhands.app_server.utils.sql_utils import Base
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_engine():
|
||||||
|
"""Create an async SQLite engine for testing."""
|
||||||
|
engine = create_async_engine(
|
||||||
|
'sqlite+aiosqlite:///:memory:',
|
||||||
|
poolclass=StaticPool,
|
||||||
|
connect_args={'check_same_thread': False},
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create an async session for testing."""
|
||||||
|
async_session_maker = async_sessionmaker(
|
||||||
|
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service(async_session) -> SQLPendingMessageService:
|
||||||
|
"""Create a SQLPendingMessageService instance for testing."""
|
||||||
|
return SQLPendingMessageService(db_session=async_session)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_content() -> list[TextContent]:
|
||||||
|
"""Create sample message content for testing."""
|
||||||
|
return [TextContent(text='Hello, this is a test message')]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLPendingMessageService:
|
||||||
|
"""Test suite for SQLPendingMessageService."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_message_creates_message_with_correct_data(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
sample_content: list[TextContent],
|
||||||
|
):
|
||||||
|
"""Test that add_message creates a message with the expected fields."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = await service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
content=sample_content,
|
||||||
|
role='user',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(response, PendingMessageResponse)
|
||||||
|
assert response.queued is True
|
||||||
|
assert response.id is not None
|
||||||
|
|
||||||
|
# Verify the message was stored correctly
|
||||||
|
messages = await service.get_pending_messages(conversation_id)
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0].conversation_id == conversation_id
|
||||||
|
assert len(messages[0].content) == 1
|
||||||
|
assert isinstance(messages[0].content[0], TextContent)
|
||||||
|
assert messages[0].content[0].text == sample_content[0].text
|
||||||
|
assert messages[0].role == 'user'
|
||||||
|
assert messages[0].created_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_message_returns_correct_queue_position(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
sample_content: list[TextContent],
|
||||||
|
):
|
||||||
|
"""Test that queue position increments correctly for each message."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Act - Add three messages
|
||||||
|
response1 = await service.add_message(conversation_id, sample_content)
|
||||||
|
response2 = await service.add_message(conversation_id, sample_content)
|
||||||
|
response3 = await service.add_message(conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response1.position == 1
|
||||||
|
assert response2.position == 2
|
||||||
|
assert response3.position == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_messages_returns_messages_ordered_by_created_at(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
):
|
||||||
|
"""Test that messages are returned in chronological order."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
contents = [
|
||||||
|
[TextContent(text='First message')],
|
||||||
|
[TextContent(text='Second message')],
|
||||||
|
[TextContent(text='Third message')],
|
||||||
|
]
|
||||||
|
|
||||||
|
for content in contents:
|
||||||
|
await service.add_message(conversation_id, content)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
messages = await service.get_pending_messages(conversation_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(messages) == 3
|
||||||
|
assert messages[0].content[0].text == 'First message'
|
||||||
|
assert messages[1].content[0].text == 'Second message'
|
||||||
|
assert messages[2].content[0].text == 'Third message'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_messages_returns_empty_list_when_none_exist(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
):
|
||||||
|
"""Test that an empty list is returned for a conversation with no messages."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
messages = await service.get_pending_messages(conversation_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert messages == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_pending_messages_returns_correct_count(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
sample_content: list[TextContent],
|
||||||
|
):
|
||||||
|
"""Test that count_pending_messages returns the correct number."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
other_conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Add 3 messages to first conversation
|
||||||
|
for _ in range(3):
|
||||||
|
await service.add_message(conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Add 2 messages to second conversation
|
||||||
|
for _ in range(2):
|
||||||
|
await service.add_message(other_conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
count1 = await service.count_pending_messages(conversation_id)
|
||||||
|
count2 = await service.count_pending_messages(other_conversation_id)
|
||||||
|
count_empty = await service.count_pending_messages('nonexistent')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert count1 == 3
|
||||||
|
assert count2 == 2
|
||||||
|
assert count_empty == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_messages_for_conversation_removes_all_messages(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
sample_content: list[TextContent],
|
||||||
|
):
|
||||||
|
"""Test that delete_messages_for_conversation removes all messages and returns count."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
other_conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Add messages to both conversations
|
||||||
|
for _ in range(3):
|
||||||
|
await service.add_message(conversation_id, sample_content)
|
||||||
|
await service.add_message(other_conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
deleted_count = await service.delete_messages_for_conversation(conversation_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert deleted_count == 3
|
||||||
|
assert await service.count_pending_messages(conversation_id) == 0
|
||||||
|
# Other conversation should be unaffected
|
||||||
|
assert await service.count_pending_messages(other_conversation_id) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_messages_for_conversation_returns_zero_when_none_exist(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
):
|
||||||
|
"""Test that deleting from nonexistent conversation returns 0."""
|
||||||
|
# Arrange
|
||||||
|
conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
deleted_count = await service.delete_messages_for_conversation(conversation_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert deleted_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_conversation_id_updates_all_matching_messages(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
sample_content: list[TextContent],
|
||||||
|
):
|
||||||
|
"""Test that update_conversation_id updates all messages with the old ID."""
|
||||||
|
# Arrange
|
||||||
|
old_conversation_id = f'task-{uuid4().hex}'
|
||||||
|
new_conversation_id = str(uuid4())
|
||||||
|
unrelated_conversation_id = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
# Add messages to old conversation
|
||||||
|
for _ in range(3):
|
||||||
|
await service.add_message(old_conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Add message to unrelated conversation
|
||||||
|
await service.add_message(unrelated_conversation_id, sample_content)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
updated_count = await service.update_conversation_id(
|
||||||
|
old_conversation_id, new_conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert updated_count == 3
|
||||||
|
|
||||||
|
# Verify old conversation has no messages
|
||||||
|
assert await service.count_pending_messages(old_conversation_id) == 0
|
||||||
|
|
||||||
|
# Verify new conversation has all messages
|
||||||
|
messages = await service.get_pending_messages(new_conversation_id)
|
||||||
|
assert len(messages) == 3
|
||||||
|
for msg in messages:
|
||||||
|
assert msg.conversation_id == new_conversation_id
|
||||||
|
|
||||||
|
# Verify unrelated conversation is unchanged
|
||||||
|
assert await service.count_pending_messages(unrelated_conversation_id) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_conversation_id_returns_zero_when_no_match(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
):
|
||||||
|
"""Test that updating nonexistent conversation_id returns 0."""
|
||||||
|
# Arrange
|
||||||
|
old_conversation_id = f'task-{uuid4().hex}'
|
||||||
|
new_conversation_id = str(uuid4())
|
||||||
|
|
||||||
|
# Act
|
||||||
|
updated_count = await service.update_conversation_id(
|
||||||
|
old_conversation_id, new_conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert updated_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_messages_are_isolated_between_conversations(
|
||||||
|
self,
|
||||||
|
service: SQLPendingMessageService,
|
||||||
|
):
|
||||||
|
"""Test that operations on one conversation don't affect others."""
|
||||||
|
# Arrange
|
||||||
|
conv1 = f'task-{uuid4().hex}'
|
||||||
|
conv2 = f'task-{uuid4().hex}'
|
||||||
|
|
||||||
|
await service.add_message(conv1, [TextContent(text='Conv1 msg')])
|
||||||
|
await service.add_message(conv2, [TextContent(text='Conv2 msg')])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
messages1 = await service.get_pending_messages(conv1)
|
||||||
|
messages2 = await service.get_pending_messages(conv2)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(messages1) == 1
|
||||||
|
assert len(messages2) == 1
|
||||||
|
assert messages1[0].content[0].text == 'Conv1 msg'
|
||||||
|
assert messages2[0].content[0].text == 'Conv2 msg'
|
||||||
@@ -2187,6 +2187,7 @@ async def test_delete_v1_conversation_with_sub_conversations():
|
|||||||
sandbox_service=mock_sandbox_service,
|
sandbox_service=mock_sandbox_service,
|
||||||
sandbox_spec_service=MagicMock(),
|
sandbox_spec_service=MagicMock(),
|
||||||
jwt_service=MagicMock(),
|
jwt_service=MagicMock(),
|
||||||
|
pending_message_service=MagicMock(),
|
||||||
sandbox_startup_timeout=120,
|
sandbox_startup_timeout=120,
|
||||||
sandbox_startup_poll_frequency=2,
|
sandbox_startup_poll_frequency=2,
|
||||||
max_num_conversations_per_sandbox=20,
|
max_num_conversations_per_sandbox=20,
|
||||||
@@ -2311,6 +2312,7 @@ async def test_delete_v1_conversation_with_no_sub_conversations():
|
|||||||
sandbox_service=mock_sandbox_service,
|
sandbox_service=mock_sandbox_service,
|
||||||
sandbox_spec_service=MagicMock(),
|
sandbox_spec_service=MagicMock(),
|
||||||
jwt_service=MagicMock(),
|
jwt_service=MagicMock(),
|
||||||
|
pending_message_service=MagicMock(),
|
||||||
sandbox_startup_timeout=120,
|
sandbox_startup_timeout=120,
|
||||||
sandbox_startup_poll_frequency=2,
|
sandbox_startup_poll_frequency=2,
|
||||||
max_num_conversations_per_sandbox=20,
|
max_num_conversations_per_sandbox=20,
|
||||||
@@ -2465,6 +2467,7 @@ async def test_delete_v1_conversation_sub_conversation_deletion_error():
|
|||||||
sandbox_service=mock_sandbox_service,
|
sandbox_service=mock_sandbox_service,
|
||||||
sandbox_spec_service=MagicMock(),
|
sandbox_spec_service=MagicMock(),
|
||||||
jwt_service=MagicMock(),
|
jwt_service=MagicMock(),
|
||||||
|
pending_message_service=MagicMock(),
|
||||||
sandbox_startup_timeout=120,
|
sandbox_startup_timeout=120,
|
||||||
sandbox_startup_poll_frequency=2,
|
sandbox_startup_poll_frequency=2,
|
||||||
max_num_conversations_per_sandbox=20,
|
max_num_conversations_per_sandbox=20,
|
||||||
|
|||||||
Reference in New Issue
Block a user