From 0ec962e96be6281a4603e851536bcd7953613f67 Mon Sep 17 00:00:00 2001 From: MkDev11 <94194147+MkDev11@users.noreply.github.com> Date: Thu, 19 Mar 2026 07:13:58 -0700 Subject: [PATCH] feat: add /clear endpoint for V1 conversations (#12786) Co-authored-by: mkdev11 Co-authored-by: openhands Co-authored-by: tofarr Co-authored-by: hieptl --- .../components/interactive-chat-box.test.tsx | 30 ++ .../use-new-conversation-command.test.tsx | 299 ++++++++++++++++++ .../__tests__/hooks/use-websocket.test.ts | 4 +- .../v1-conversation-service.api.ts | 4 + .../features/chat/chat-interface.tsx | 32 +- .../chat/components/chat-input-container.tsx | 3 + .../chat/components/chat-input-field.tsx | 10 +- .../chat/components/chat-input-row.tsx | 3 + .../features/chat/custom-chat-input.tsx | 3 + .../features/chat/interactive-chat-box.tsx | 8 +- .../mutation/use-new-conversation-command.ts | 115 +++++++ .../query/use-unified-get-git-changes.ts | 1 + frontend/src/i18n/declaration.ts | 8 + frontend/src/i18n/translation.json | 136 ++++++++ frontend/src/utils/websocket-url.ts | 13 + .../app_conversation_info_service.py | 8 + .../app_conversation_service.py | 26 +- .../live_status_app_conversation_service.py | 13 +- .../sql_app_conversation_info_service.py | 8 + openhands/app_server/config.py | 21 ++ .../sandbox/docker_sandbox_service.py | 17 +- openhands/server/middleware.py | 21 +- .../server/routes/manage_conversations.py | 18 +- .../test_sql_app_conversation_info_service.py | 48 +++ .../server/data_models/test_conversation.py | 12 +- tests/unit/server/test_middleware.py | 79 +++-- 26 files changed, 884 insertions(+), 56 deletions(-) create mode 100644 frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx create mode 100644 frontend/src/hooks/mutation/use-new-conversation-command.ts diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx index ecb6623806..884217facb 100644 --- a/frontend/__tests__/components/interactive-chat-box.test.tsx +++ b/frontend/__tests__/components/interactive-chat-box.test.tsx @@ -216,6 +216,36 @@ describe("InteractiveChatBox", () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); + it("should lock the text input field when disabled prop is true (isNewConversationPending)", () => { + mockStores(AgentState.INIT); + + renderInteractiveChatBox({ + onSubmit: onSubmitMock, + disabled: true, + }); + + const chatInput = screen.getByTestId("chat-input"); + // When disabled=true, the text field should not be editable + expect(chatInput).toHaveAttribute("contenteditable", "false"); + // Should show visual disabled state + expect(chatInput.className).toContain("cursor-not-allowed"); + expect(chatInput.className).toContain("opacity-50"); + }); + + it("should keep the text input field editable when disabled prop is false", () => { + mockStores(AgentState.INIT); + + renderInteractiveChatBox({ + onSubmit: onSubmitMock, + disabled: false, + }); + + const chatInput = screen.getByTestId("chat-input"); + expect(chatInput).toHaveAttribute("contenteditable", "true"); + expect(chatInput.className).not.toContain("cursor-not-allowed"); + expect(chatInput.className).not.toContain("opacity-50"); + }); + it("should handle image upload and message submission correctly", async () => { const user = userEvent.setup(); const onSubmit = vi.fn(); diff --git a/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx b/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx new file mode 100644 index 0000000000..07f110ff17 --- /dev/null +++ b/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx @@ -0,0 +1,299 @@ +import { renderHook, waitFor } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { describe, expect, it, vi, beforeEach } from "vitest"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; +import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command"; + +const mockNavigate = vi.fn(); + +vi.mock("react-router", () => ({ + useNavigate: () => mockNavigate, + useParams: () => ({ conversationId: "conv-123" }), +})); + +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})); + +const { mockToast } = vi.hoisted(() => { + const mockToast = Object.assign(vi.fn(), { + loading: vi.fn(), + dismiss: vi.fn(), + }); + return { mockToast }; +}); + +vi.mock("react-hot-toast", () => ({ + default: mockToast, +})); + +vi.mock("#/utils/custom-toast-handlers", () => ({ + displaySuccessToast: vi.fn(), + displayErrorToast: vi.fn(), + TOAST_OPTIONS: { position: "top-right" }, +})); + +const mockConversation = { + conversation_id: "conv-123", + sandbox_id: "sandbox-456", + title: "Test Conversation", + selected_repository: null, + selected_branch: null, + git_provider: null, + last_updated_at: new Date().toISOString(), + created_at: new Date().toISOString(), + status: "RUNNING" as const, + runtime_status: null, + url: null, + session_api_key: null, + conversation_version: "V1" as const, +}; + +vi.mock("#/hooks/query/use-active-conversation", () => ({ + useActiveConversation: () => ({ + data: mockConversation, + }), +})); + +function makeStartTask(overrides: Record = {}) { + return { + id: "task-789", + created_by_user_id: null, + status: "READY", + detail: null, + app_conversation_id: "new-conv-999", + sandbox_id: "sandbox-456", + agent_server_url: "http://agent-server.local", + request: { + sandbox_id: null, + initial_message: null, + processors: [], + llm_model: null, + selected_repository: null, + selected_branch: null, + git_provider: null, + suggested_task: null, + title: null, + trigger: null, + pr_number: [], + parent_conversation_id: null, + agent_type: "default", + }, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + ...overrides, + }; +} + +describe("useNewConversationCommand", () => { + let queryClient: QueryClient; + + beforeEach(() => { + vi.clearAllMocks(); + queryClient = new QueryClient({ + defaultOptions: { mutations: { retry: false } }, + }); + // Mock batchGetAppConversations to return V1 data with llm_model + vi.spyOn( + V1ConversationService, + "batchGetAppConversations", + ).mockResolvedValue([ + { + id: "conv-123", + title: "Test Conversation", + sandbox_id: "sandbox-456", + sandbox_status: "RUNNING", + execution_status: "IDLE", + conversation_url: null, + session_api_key: null, + selected_repository: null, + selected_branch: null, + git_provider: null, + trigger: null, + pr_number: [], + llm_model: "gpt-4o", + metrics: null, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + sub_conversation_ids: [], + public: false, + } as never, + ]); + }); + + const wrapper = ({ children }: { children: React.ReactNode }) => ( + {children} + ); + + it("calls createConversation with sandbox_id and navigates on success", async () => { + const readyTask = makeStartTask(); + const createSpy = vi + .spyOn(V1ConversationService, "createConversation") + .mockResolvedValue(readyTask as never); + const getStartTaskSpy = vi + .spyOn(V1ConversationService, "getStartTask") + .mockResolvedValue(readyTask as never); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(createSpy).toHaveBeenCalledWith( + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + "sandbox-456", + "gpt-4o", + ); + expect(getStartTaskSpy).toHaveBeenCalledWith("task-789"); + expect(mockNavigate).toHaveBeenCalledWith( + "/conversations/new-conv-999", + ); + }); + }); + + it("polls getStartTask until status is READY", async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }); + + const workingTask = makeStartTask({ + status: "WORKING", + app_conversation_id: null, + }); + const readyTask = makeStartTask({ status: "READY" }); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + workingTask as never, + ); + const getStartTaskSpy = vi + .spyOn(V1ConversationService, "getStartTask") + .mockResolvedValueOnce(workingTask as never) + .mockResolvedValueOnce(readyTask as never); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + const mutatePromise = result.current.mutateAsync(); + + await vi.advanceTimersByTimeAsync(2000); + await mutatePromise; + + await waitFor(() => { + expect(getStartTaskSpy).toHaveBeenCalledTimes(2); + expect(mockNavigate).toHaveBeenCalledWith( + "/conversations/new-conv-999", + ); + }); + + vi.useRealTimers(); + }); + + it("throws when task status is ERROR", async () => { + const errorTask = makeStartTask({ + status: "ERROR", + detail: "Sandbox crashed", + app_conversation_id: null, + }); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + errorTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + errorTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await expect(result.current.mutateAsync()).rejects.toThrow( + "Sandbox crashed", + ); + }); + + it("invalidates conversation list queries on success", async () => { + const readyTask = makeStartTask(); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + readyTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const invalidateSpy = vi.spyOn(queryClient, "invalidateQueries"); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ["user", "conversations"], + }); + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ["v1-batch-get-app-conversations"], + }); + }); + }); + + it("creates a standalone conversation (not a sub-conversation) so it appears in the list", async () => { + const readyTask = makeStartTask(); + const createSpy = vi + .spyOn(V1ConversationService, "createConversation") + .mockResolvedValue(readyTask as never); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + // parent_conversation_id should be undefined so the new conversation + // is NOT a sub-conversation and will appear in the conversation list. + expect(createSpy).toHaveBeenCalledWith( + undefined, // selectedRepository (null from mock) + undefined, // git_provider (null from mock) + undefined, // initialUserMsg + undefined, // selected_branch (null from mock) + undefined, // conversationInstructions + undefined, // suggestedTask + undefined, // trigger + undefined, // parent_conversation_id is NOT set + undefined, // agent_type + "sandbox-456", // sandbox_id IS set to reuse the sandbox + "gpt-4o", // llm_model IS inherited from the original conversation + ); + }); + }); + + it("shows a loading toast immediately and dismisses it on success", async () => { + const readyTask = makeStartTask(); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + readyTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(mockToast.loading).toHaveBeenCalledWith( + "CONVERSATION$CLEARING", + expect.objectContaining({ id: "clear-conversation" }), + ); + expect(mockToast.dismiss).toHaveBeenCalledWith("clear-conversation"); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-websocket.test.ts b/frontend/__tests__/hooks/use-websocket.test.ts index 7d42507a87..d00db6f856 100644 --- a/frontend/__tests__/hooks/use-websocket.test.ts +++ b/frontend/__tests__/hooks/use-websocket.test.ts @@ -205,7 +205,9 @@ describe("useWebSocket", () => { expect(result.current.isConnected).toBe(true); }); - expect(onCloseSpy).not.toHaveBeenCalled(); + // Reset spy after connection is established to ignore any spurious + // close events fired by the MSW mock during the handshake. + onCloseSpy.mockClear(); // Unmount to trigger close unmount(); diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index a0e99abe0f..bcdad50077 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -68,6 +68,8 @@ class V1ConversationService { trigger?: ConversationTrigger, parent_conversation_id?: string, agent_type?: "default" | "plan", + sandbox_id?: string, + llm_model?: string, ): Promise { const body: V1AppConversationStartRequest = { selected_repository: selectedRepository, @@ -78,6 +80,8 @@ class V1ConversationService { trigger, parent_conversation_id: parent_conversation_id || null, agent_type, + sandbox_id: sandbox_id || null, + llm_model: llm_model || null, }; // suggested_task implies the backend will construct the initial_message diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx index 43218149ae..3b11a2fbc4 100644 --- a/frontend/src/components/features/chat/chat-interface.tsx +++ b/frontend/src/components/features/chat/chat-interface.tsx @@ -38,6 +38,8 @@ import { useTaskPolling } from "#/hooks/query/use-task-polling"; import { useConversationWebSocket } from "#/contexts/conversation-websocket-context"; import ChatStatusIndicator from "./chat-status-indicator"; import { getStatusColor, getStatusText } from "#/utils/utils"; +import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command"; +import { I18nKey } from "#/i18n/declaration"; function getEntryPoint( hasRepository: boolean | null, @@ -80,6 +82,10 @@ export function ChatInterface() { setHitBottom, } = useScrollToBottom(scrollRef); const { data: config } = useConfig(); + const { + mutate: newConversationCommand, + isPending: isNewConversationPending, + } = useNewConversationCommand(); const { curAgentState } = useAgentState(); const { handleBuildPlanClick } = useHandleBuildPlanClick(); @@ -146,6 +152,27 @@ export function ChatInterface() { originalImages: File[], originalFiles: File[], ) => { + // Handle /new command for V1 conversations + if (content.trim() === "/new") { + if (!isV1Conversation) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_V1_ONLY)); + return; + } + if (!params.conversationId) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_NO_ID)); + return; + } + if (totalEvents === 0) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_EMPTY)); + return; + } + if (isNewConversationPending) { + return; + } + newConversationCommand(); + return; + } + // Create mutable copies of the arrays const images = [...originalImages]; const files = [...originalFiles]; @@ -338,7 +365,10 @@ export function ChatInterface() { /> )} - + {config?.app_mode !== "saas" && !isV1Conversation && ( diff --git a/frontend/src/components/features/chat/components/chat-input-container.tsx b/frontend/src/components/features/chat/components/chat-input-container.tsx index ef67069de5..ebb3924458 100644 --- a/frontend/src/components/features/chat/components/chat-input-container.tsx +++ b/frontend/src/components/features/chat/components/chat-input-container.tsx @@ -12,6 +12,7 @@ interface ChatInputContainerProps { chatContainerRef: React.RefObject; isDragOver: boolean; disabled: boolean; + isNewConversationPending?: boolean; showButton: boolean; buttonClassName: string; chatInputRef: React.RefObject; @@ -36,6 +37,7 @@ export function ChatInputContainer({ chatContainerRef, isDragOver, disabled, + isNewConversationPending = false, showButton, buttonClassName, chatInputRef, @@ -89,6 +91,7 @@ export function ChatInputContainer({ ; + disabled?: boolean; onInput: () => void; onPaste: (e: React.ClipboardEvent) => void; onKeyDown: (e: React.KeyboardEvent) => void; @@ -14,6 +16,7 @@ interface ChatInputFieldProps { export function ChatInputField({ chatInputRef, + disabled = false, onInput, onPaste, onKeyDown, @@ -36,8 +39,11 @@ export function ChatInputField({
; disabled: boolean; + isNewConversationPending?: boolean; showButton: boolean; buttonClassName: string; handleFileIconClick: (isDisabled: boolean) => void; @@ -21,6 +22,7 @@ interface ChatInputRowProps { export function ChatInputRow({ chatInputRef, disabled, + isNewConversationPending = false, showButton, buttonClassName, handleFileIconClick, @@ -41,6 +43,7 @@ export function ChatInputRow({ void; @@ -25,6 +26,7 @@ export interface CustomChatInputProps { export function CustomChatInput({ disabled = false, + isNewConversationPending = false, showButton = true, conversationStatus = null, onSubmit, @@ -147,6 +149,7 @@ export function CustomChatInput({ chatContainerRef={chatContainerRef} isDragOver={isDragOver} disabled={isDisabled} + isNewConversationPending={isNewConversationPending} showButton={showButton} buttonClassName={buttonClassName} chatInputRef={chatInputRef} diff --git a/frontend/src/components/features/chat/interactive-chat-box.tsx b/frontend/src/components/features/chat/interactive-chat-box.tsx index 74818d1d6c..cf46336887 100644 --- a/frontend/src/components/features/chat/interactive-chat-box.tsx +++ b/frontend/src/components/features/chat/interactive-chat-box.tsx @@ -13,9 +13,13 @@ import { isTaskPolling } from "#/utils/utils"; interface InteractiveChatBoxProps { onSubmit: (message: string, images: File[], files: File[]) => void; + disabled?: boolean; } -export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { +export function InteractiveChatBox({ + onSubmit, + disabled = false, +}: InteractiveChatBoxProps) { const { images, files, @@ -145,6 +149,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { // Allow users to submit messages during LOADING state - they will be // queued server-side and delivered when the conversation becomes ready const isDisabled = + disabled || curAgentState === AgentState.AWAITING_USER_CONFIRMATION || isTaskPolling(subConversationTaskStatus); @@ -152,6 +157,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
{ + const queryClient = useQueryClient(); + const navigate = useNavigate(); + const { t } = useTranslation(); + const { data: conversation } = useActiveConversation(); + + const mutation = useMutation({ + mutationFn: async () => { + if (!conversation?.conversation_id || !conversation.sandbox_id) { + throw new Error("No active conversation or sandbox"); + } + + // Fetch V1 conversation data to get llm_model (not available in legacy type) + const v1Conversations = + await V1ConversationService.batchGetAppConversations([ + conversation.conversation_id, + ]); + const llmModel = v1Conversations?.[0]?.llm_model; + + // Start a new conversation reusing the existing sandbox directly. + // We pass sandbox_id instead of parent_conversation_id so that the + // new conversation is NOT marked as a sub-conversation and will + // appear in the conversation list. + const startTask = await V1ConversationService.createConversation( + conversation.selected_repository ?? undefined, // selectedRepository + conversation.git_provider ?? undefined, // git_provider + undefined, // initialUserMsg + conversation.selected_branch ?? undefined, // selected_branch + undefined, // conversationInstructions + undefined, // suggestedTask + undefined, // trigger + undefined, // parent_conversation_id + undefined, // agent_type + conversation.sandbox_id ?? undefined, // sandbox_id - reuse the same sandbox + llmModel ?? undefined, // llm_model - preserve the LLM model + ); + + // Poll for the task to complete and get the new conversation ID + let task = await V1ConversationService.getStartTask(startTask.id); + const maxAttempts = 60; // 60 seconds timeout + let attempts = 0; + + /* eslint-disable no-await-in-loop */ + while ( + task && + !["READY", "ERROR"].includes(task.status) && + attempts < maxAttempts + ) { + // eslint-disable-next-line no-await-in-loop + await new Promise((resolve) => { + setTimeout(resolve, 1000); + }); + task = await V1ConversationService.getStartTask(startTask.id); + attempts += 1; + } + + if (!task || task.status !== "READY" || !task.app_conversation_id) { + throw new Error( + task?.detail || "Failed to create new conversation in sandbox", + ); + } + + return { + newConversationId: task.app_conversation_id, + oldConversationId: conversation.conversation_id, + }; + }, + onMutate: () => { + toast.loading(t(I18nKey.CONVERSATION$CLEARING), { + ...TOAST_OPTIONS, + id: "clear-conversation", + }); + }, + onSuccess: (data) => { + toast.dismiss("clear-conversation"); + displaySuccessToast(t(I18nKey.CONVERSATION$CLEAR_SUCCESS)); + navigate(`/conversations/${data.newConversationId}`); + + // Refresh the sidebar to show the new conversation. + queryClient.invalidateQueries({ + queryKey: ["user", "conversations"], + }); + queryClient.invalidateQueries({ + queryKey: ["v1-batch-get-app-conversations"], + }); + }, + onError: (error) => { + toast.dismiss("clear-conversation"); + let clearError = t(I18nKey.CONVERSATION$CLEAR_UNKNOWN_ERROR); + if (error instanceof Error) { + clearError = error.message; + } else if (typeof error === "string") { + clearError = error; + } + displayErrorToast( + t(I18nKey.CONVERSATION$CLEAR_FAILED, { error: clearError }), + ); + }, + }); + + return mutation; +}; diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index 801b1a067a..a1de3852f9 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -57,6 +57,7 @@ export const useUnifiedGetGitChanges = () => { retry: false, staleTime: 1000 * 60 * 5, // 5 minutes gcTime: 1000 * 60 * 15, // 15 minutes + refetchOnMount: "always", // Always refetch when mounting (e.g. navigating between conversations that share a sandbox) enabled: runtimeIsReady && !!conversationId, meta: { disableToast: true, diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index fe6f248cfa..9b355ae432 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -1151,6 +1151,14 @@ export enum I18nKey { ONBOARDING$NEXT_BUTTON = "ONBOARDING$NEXT_BUTTON", ONBOARDING$BACK_BUTTON = "ONBOARDING$BACK_BUTTON", ONBOARDING$FINISH_BUTTON = "ONBOARDING$FINISH_BUTTON", + CONVERSATION$CLEAR_V1_ONLY = "CONVERSATION$CLEAR_V1_ONLY", + CONVERSATION$CLEAR_EMPTY = "CONVERSATION$CLEAR_EMPTY", + CONVERSATION$CLEAR_NO_ID = "CONVERSATION$CLEAR_NO_ID", + CONVERSATION$CLEAR_NO_NEW_ID = "CONVERSATION$CLEAR_NO_NEW_ID", + CONVERSATION$CLEAR_UNKNOWN_ERROR = "CONVERSATION$CLEAR_UNKNOWN_ERROR", + CONVERSATION$CLEAR_FAILED = "CONVERSATION$CLEAR_FAILED", + CONVERSATION$CLEAR_SUCCESS = "CONVERSATION$CLEAR_SUCCESS", + CONVERSATION$CLEARING = "CONVERSATION$CLEARING", CTA$ENTERPRISE = "CTA$ENTERPRISE", CTA$ENTERPRISE_DEPLOY = "CTA$ENTERPRISE_DEPLOY", CTA$FEATURE_ON_PREMISES = "CTA$FEATURE_ON_PREMISES", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index f43c33b0d2..57b89cd193 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -19569,6 +19569,142 @@ "uk": "Завершити", "ca": "Finalitza" }, + "CONVERSATION$CLEAR_V1_ONLY": { + "en": "The /new command is only available for V1 conversations", + "ja": "/newコマンドはV1会話でのみ使用できます", + "zh-CN": "/new 命令仅适用于 V1 对话", + "zh-TW": "/new 指令僅適用於 V1 對話", + "ko-KR": "/new 명령은 V1 대화에서만 사용할 수 있습니다", + "no": "/new-kommandoen er kun tilgjengelig for V1-samtaler", + "it": "Il comando /new è disponibile solo per le conversazioni V1", + "pt": "O comando /new está disponível apenas para conversas V1", + "es": "El comando /new solo está disponible para conversaciones V1", + "ar": "أمر /new متاح فقط لمحادثات V1", + "fr": "La commande /new n'est disponible que pour les conversations V1", + "tr": "/new komutu yalnızca V1 konuşmalarında kullanılabilir", + "de": "Der /new-Befehl ist nur für V1-Konversationen verfügbar", + "uk": "Команда /new доступна лише для розмов V1", + "ca": "L'ordre /new només està disponible per a converses V1" + }, + "CONVERSATION$CLEAR_EMPTY": { + "en": "Nothing to clear. This conversation has no messages yet.", + "ja": "クリアするものがありません。この会話にはまだメッセージがありません。", + "zh-CN": "没有可清除的内容。此对话尚无消息。", + "zh-TW": "沒有可清除的內容。此對話尚無訊息。", + "ko-KR": "지울 내용이 없습니다. 이 대화에는 아직 메시지가 없습니다.", + "no": "Ingenting å tømme. Denne samtalen har ingen meldinger ennå.", + "it": "Niente da cancellare. Questa conversazione non ha ancora messaggi.", + "pt": "Nada para limpar. Esta conversa ainda não tem mensagens.", + "es": "Nada que borrar. Esta conversación aún no tiene mensajes.", + "ar": "لا يوجد شيء لمسحه. لا تحتوي هذه المحادثة على رسائل بعد.", + "fr": "Rien à effacer. Cette conversation n'a pas encore de messages.", + "tr": "Temizlenecek bir şey yok. Bu konuşmada henüz mesaj yok.", + "de": "Nichts zu löschen. Diese Konversation hat noch keine Nachrichten.", + "uk": "Нічого очищувати. Ця розмова ще не має повідомлень.", + "ca": "No hi ha res a esborrar. Aquesta conversa encara no té missatges." + }, + "CONVERSATION$CLEAR_NO_ID": { + "en": "No conversation ID found", + "ja": "会話IDが見つかりません", + "zh-CN": "未找到对话 ID", + "zh-TW": "找不到對話 ID", + "ko-KR": "대화 ID를 찾을 수 없습니다", + "no": "Ingen samtale-ID funnet", + "it": "Nessun ID conversazione trovato", + "pt": "Nenhum ID de conversa encontrado", + "es": "No se encontró el ID de conversación", + "ar": "لم يتم العثور على معرف المحادثة", + "fr": "Aucun identifiant de conversation trouvé", + "tr": "Konuşma kimliği bulunamadı", + "de": "Keine Konversations-ID gefunden", + "uk": "Ідентифікатор розмови не знайдено", + "ca": "No s'ha trobat l'identificador de la conversa" + }, + "CONVERSATION$CLEAR_NO_NEW_ID": { + "en": "Server did not return a new conversation ID", + "ja": "サーバーが新しい会話IDを返しませんでした", + "zh-CN": "服务器未返回新的对话 ID", + "zh-TW": "伺服器未返回新的對話 ID", + "ko-KR": "서버가 새 대화 ID를 반환하지 않았습니다", + "no": "Serveren returnerte ikke en ny samtale-ID", + "it": "Il server non ha restituito un nuovo ID conversazione", + "pt": "O servidor não retornou um novo ID de conversa", + "es": "El servidor no devolvió un nuevo ID de conversación", + "ar": "لم يقم الخادم بإرجاع معرف محادثة جديد", + "fr": "Le serveur n'a pas renvoyé un nouvel identifiant de conversation", + "tr": "Sunucu yeni bir konuşma kimliği döndürmedi", + "de": "Der Server hat keine neue Konversations-ID zurückgegeben", + "uk": "Сервер не повернув новий ідентифікатор розмови", + "ca": "El servidor no ha retornat un nou identificador de conversa" + }, + "CONVERSATION$CLEAR_UNKNOWN_ERROR": { + "en": "Unknown error", + "ja": "不明なエラー", + "zh-CN": "未知错误", + "zh-TW": "未知錯誤", + "ko-KR": "알 수 없는 오류", + "no": "Ukjent feil", + "it": "Errore sconosciuto", + "pt": "Erro desconhecido", + "es": "Error desconocido", + "ar": "خطأ غير معروف", + "fr": "Erreur inconnue", + "tr": "Bilinmeyen hata", + "de": "Unbekannter Fehler", + "uk": "Невідома помилка", + "ca": "Error desconegut" + }, + "CONVERSATION$CLEAR_FAILED": { + "en": "Failed to start new conversation: {{error}}", + "ja": "新しい会話の開始に失敗しました: {{error}}", + "zh-CN": "启动新对话失败: {{error}}", + "zh-TW": "啟動新對話失敗: {{error}}", + "ko-KR": "새 대화 시작 실패: {{error}}", + "no": "Kunne ikke starte ny samtale: {{error}}", + "it": "Impossibile avviare una nuova conversazione: {{error}}", + "pt": "Falha ao iniciar nova conversa: {{error}}", + "es": "Error al iniciar nueva conversación: {{error}}", + "ar": "فشل في بدء محادثة جديدة: {{error}}", + "fr": "Échec du démarrage d'une nouvelle conversation : {{error}}", + "tr": "Yeni konuşma başlatılamadı: {{error}}", + "de": "Neue Konversation konnte nicht gestartet werden: {{error}}", + "uk": "Не вдалося розпочати нову розмову: {{error}}", + "ca": "No s'ha pogut iniciar una nova conversa: {{error}}" + }, + "CONVERSATION$CLEAR_SUCCESS": { + "en": "Starting a new conversation in the same sandbox. These conversations share the same runtime.", + "ja": "同じサンドボックスで新しい会話を開始します。これらの会話は同じランタイムを共有します。", + "zh-CN": "正在同一沙箱中开始新对话。这些对话共享同一运行时。", + "zh-TW": "正在同一沙盒中開始新對話。這些對話共享同一執行環境。", + "ko-KR": "같은 샌드박스에서 새 대화를 시작합니다. 이 대화들은 같은 런타임을 공유합니다.", + "no": "Starter ny samtale i samme sandbox. Disse samtalene deler samme kjøretid.", + "it": "Avvio nuova conversazione nello stesso sandbox. Queste conversazioni condividono lo stesso runtime.", + "pt": "Iniciando nova conversa no mesmo sandbox. Essas conversas compartilham o mesmo runtime.", + "es": "Iniciando nueva conversación en el mismo sandbox. Estas conversaciones comparten el mismo runtime.", + "ar": "بدء محادثة جديدة في نفس صندوق الحماية. هذه المحادثات تشارك نفس وقت التشغيل.", + "fr": "Démarrage d'une nouvelle conversation dans le même bac à sable. Ces conversations partagent le même environnement d'exécution.", + "tr": "Aynı korumalı alanda yeni konuşma başlatılıyor. Bu konuşmalar aynı çalışma ortamını paylaşır.", + "de": "Starte neue Konversation in derselben Sandbox. Diese Konversationen teilen dieselbe Laufzeitumgebung.", + "uk": "Починаю нову розмову в тому самому захищеному середовищі. Ці розмови використовують одне середовище виконання.", + "ca": "S'està iniciant una nova conversa al mateix entorn aïllat. Aquestes converses comparteixen el mateix entorn d'execució." + }, + "CONVERSATION$CLEARING": { + "en": "Creating new conversation...", + "ja": "新しい会話を作成中...", + "zh-CN": "正在创建新对话...", + "zh-TW": "正在建立新對話...", + "ko-KR": "새 대화를 만드는 중...", + "no": "Oppretter ny samtale...", + "it": "Creazione nuova conversazione...", + "pt": "Criando nova conversa...", + "es": "Creando nueva conversación...", + "ar": "جارٍ إنشاء محادثة جديدة...", + "fr": "Création d'une nouvelle conversation...", + "tr": "Yeni konuşma oluşturuluyor...", + "de": "Neue Konversation wird erstellt...", + "uk": "Створення нової розмови...", + "ca": "S'està creant una nova conversa..." + }, "CTA$ENTERPRISE": { "en": "Enterprise", "ja": "エンタープライズ", diff --git a/frontend/src/utils/websocket-url.ts b/frontend/src/utils/websocket-url.ts index 0e72c24dc8..787032b2c9 100644 --- a/frontend/src/utils/websocket-url.ts +++ b/frontend/src/utils/websocket-url.ts @@ -9,6 +9,19 @@ export function extractBaseHost( if (conversationUrl && !conversationUrl.startsWith("/")) { try { const url = new URL(conversationUrl); + // If the conversation URL points to localhost but we're accessing from external, + // use the browser's hostname with the conversation URL's port + const urlHostname = url.hostname; + const browserHostname = + window.location.hostname ?? window.location.host?.split(":")[0]; + if ( + browserHostname && + (urlHostname === "localhost" || urlHostname === "127.0.0.1") && + browserHostname !== "localhost" && + browserHostname !== "127.0.0.1" + ) { + return `${browserHostname}:${url.port}`; + } return url.host; // e.g., "localhost:3000" } catch { return window.location.host; diff --git a/openhands/app_server/app_conversation/app_conversation_info_service.py b/openhands/app_server/app_conversation/app_conversation_info_service.py index bb83ab5801..e14f1dbf6e 100644 --- a/openhands/app_server/app_conversation/app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/app_conversation_info_service.py @@ -84,6 +84,14 @@ class AppConversationInfoService(ABC): List of sub-conversation IDs """ + @abstractmethod + async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int: + """Count V1 conversations that reference the given sandbox. + + Used to decide whether a sandbox can be safely deleted when a + conversation is removed (only delete if count is 0). + """ + # Mutators @abstractmethod diff --git a/openhands/app_server/app_conversation/app_conversation_service.py b/openhands/app_server/app_conversation/app_conversation_service.py index 1f955cac9c..6be1d32ddf 100644 --- a/openhands/app_server/app_conversation/app_conversation_service.py +++ b/openhands/app_server/app_conversation/app_conversation_service.py @@ -77,8 +77,20 @@ class AppConversationService(ABC): id, starting a conversation, attaching a callback, and then running the conversation. - Yields an instance of AppConversationStartTask as updates occur, which can be used to determine - the progress of the task. + This method returns an async iterator that yields the same + AppConversationStartTask repeatedly as status updates occur. Callers + should iterate until the task reaches a terminal status:: + + async for task in service.start_app_conversation(request): + if task.status in ( + AppConversationStartTaskStatus.READY, + AppConversationStartTaskStatus.ERROR, + ): + break + + Status progression: WORKING → WAITING_FOR_SANDBOX → PREPARING_REPOSITORY + → RUNNING_SETUP_SCRIPT → SETTING_UP_GIT_HOOKS → SETTING_UP_SKILLS + → STARTING_CONVERSATION → READY (or ERROR at any point). """ # This is an abstract method - concrete implementations should provide real values from openhands.app_server.app_conversation.app_conversation_models import ( @@ -111,15 +123,21 @@ class AppConversationService(ABC): """ @abstractmethod - async def delete_app_conversation(self, conversation_id: UUID) -> bool: + async def delete_app_conversation( + self, conversation_id: UUID, skip_agent_server_delete: bool = False + ) -> bool: """Delete a V1 conversation and all its associated data. Args: conversation_id: The UUID of the conversation to delete. + skip_agent_server_delete: If True, skip the agent server DELETE call. + This should be set when the sandbox is shared with other + conversations (e.g. created via /new) to avoid destabilizing + the shared runtime. This method should: 1. Delete the conversation from the database - 2. Call the agent server to delete the conversation + 2. Call the agent server to delete the conversation (unless skipped) 3. Clean up any related data Returns True if the conversation was deleted successfully, False otherwise. diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 703899ec83..b85e1de48f 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -1740,13 +1740,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase): conversations = await self._build_app_conversations([info]) return conversations[0] - async def delete_app_conversation(self, conversation_id: UUID) -> bool: + async def delete_app_conversation( + self, conversation_id: UUID, skip_agent_server_delete: bool = False + ) -> bool: """Delete a V1 conversation and all its associated data. This method will also cascade delete all sub-conversations of the parent. Args: conversation_id: The UUID of the conversation to delete. + skip_agent_server_delete: If True, skip the agent server DELETE call. + This should be set when the sandbox is shared with other + conversations (e.g. created via /new) to avoid destabilizing + the shared runtime. """ # Check if we have the required SQL implementation for transactional deletion if not isinstance( @@ -1772,8 +1778,9 @@ class LiveStatusAppConversationService(AppConversationServiceBase): await self._delete_sub_conversations(conversation_id) # Now delete the parent conversation - # Delete from agent server if sandbox is running - await self._delete_from_agent_server(app_conversation) + # Delete from agent server if sandbox is running (skip if sandbox is shared) + if not skip_agent_server_delete: + await self._delete_from_agent_server(app_conversation) # Delete from database using the conversation info from app_conversation # AppConversation extends AppConversationInfo, so we can use it directly diff --git a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py index c7c9e1935e..80b77957ba 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py @@ -278,6 +278,14 @@ class SQLAppConversationInfoService(AppConversationInfoService): rows = result_set.scalars().all() return [UUID(row.conversation_id) for row in rows] + async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int: + query = await self._secure_select() + query = query.where(StoredConversationMetadata.sandbox_id == sandbox_id) + count_query = select(func.count()).select_from(query.subquery()) + result = await self.db_session.execute(count_query) + count = result.scalar() + return count or 0 + async def get_app_conversation_info( self, conversation_id: UUID ) -> AppConversationInfo | None: diff --git a/openhands/app_server/config.py b/openhands/app_server/config.py index 4b7f78e389..96168143ed 100644 --- a/openhands/app_server/config.py +++ b/openhands/app_server/config.py @@ -87,6 +87,19 @@ def get_default_web_url() -> str | None: return f'https://{web_host}' +def get_default_permitted_cors_origins() -> list[str]: + """Get permitted CORS origins, falling back to legacy PERMITTED_CORS_ORIGINS env var. + + The preferred configuration is via OH_PERMITTED_CORS_ORIGINS_0, _1, etc. + (handled by the pydantic from_env parser). This fallback supports the legacy + comma-separated PERMITTED_CORS_ORIGINS environment variable. + """ + legacy = os.getenv('PERMITTED_CORS_ORIGINS', '') + if legacy: + return [o.strip() for o in legacy.split(',') if o.strip()] + return [] + + def get_openhands_provider_base_url() -> str | None: """Return the base URL for the OpenHands provider, if configured.""" return os.getenv('OPENHANDS_PROVIDER_BASE_URL') or None @@ -106,6 +119,14 @@ class AppServerConfig(OpenHandsModel): default_factory=get_default_web_url, description='The URL where OpenHands is running (e.g., http://localhost:3000)', ) + permitted_cors_origins: list[str] = Field( + default_factory=get_default_permitted_cors_origins, + description=( + 'Additional permitted CORS origins for both the app server and agent ' + 'server containers. Configure via OH_PERMITTED_CORS_ORIGINS_0, _1, etc. ' + 'Falls back to legacy PERMITTED_CORS_ORIGINS env var.' + ), + ) openhands_provider_base_url: str | None = Field( default_factory=get_openhands_provider_base_url, description='Base URL for the OpenHands provider', diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index f5a302fa73..cccd873cb6 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -27,7 +27,6 @@ from openhands.app_server.sandbox.sandbox_models import ( SandboxStatus, ) from openhands.app_server.sandbox.sandbox_service import ( - ALLOW_CORS_ORIGINS_VARIABLE, SESSION_API_KEY_VARIABLE, WEBHOOK_CALLBACK_VARIABLE, SandboxService, @@ -91,6 +90,7 @@ class DockerSandboxService(SandboxService): httpx_client: httpx.AsyncClient max_num_sandboxes: int web_url: str | None = None + permitted_cors_origins: list[str] = field(default_factory=list) extra_hosts: dict[str, str] = field(default_factory=dict) docker_client: docker.DockerClient = field(default_factory=get_docker_client) startup_grace_seconds: int = STARTUP_GRACE_SECONDS @@ -386,8 +386,18 @@ class DockerSandboxService(SandboxService): # Set CORS origins for remote browser access when web_url is configured. # This allows the agent-server container to accept requests from the # frontend when running OpenHands on a remote machine. + # Each origin gets its own indexed env var (OH_ALLOW_CORS_ORIGINS_0, _1, etc.) + cors_origins: list[str] = [] if self.web_url: - env_vars[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url + cors_origins.append(self.web_url) + cors_origins.extend(self.permitted_cors_origins) + # Deduplicate while preserving order + seen: set[str] = set() + for origin in cors_origins: + if origin not in seen: + seen.add(origin) + idx = len(seen) - 1 + env_vars[f'OH_ALLOW_CORS_ORIGINS_{idx}'] = origin # Prepare port mappings and add port environment variables # When using host network, container ports are directly accessible on the host @@ -621,7 +631,7 @@ class DockerSandboxServiceInjector(SandboxServiceInjector): get_sandbox_spec_service, ) - # Get web_url from global config for CORS support + # Get web_url and permitted_cors_origins from global config config = get_global_config() web_url = config.web_url @@ -640,6 +650,7 @@ class DockerSandboxServiceInjector(SandboxServiceInjector): httpx_client=httpx_client, max_num_sandboxes=self.max_num_sandboxes, web_url=web_url, + permitted_cors_origins=config.permitted_cors_origins, extra_hosts=self.extra_hosts, startup_grace_seconds=self.startup_grace_seconds, use_host_network=self.use_host_network, diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 902a881df8..b1e9c5649e 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -7,7 +7,7 @@ # Tag: Legacy-V0 # This module belongs to the old V0 web server. The V1 application server lives under openhands/app_server/. import asyncio -import os +import logging from collections import defaultdict from datetime import datetime, timedelta from urllib.parse import urlparse @@ -20,6 +20,8 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import Response from starlette.types import ASGIApp +from openhands.app_server.config import get_global_config + class LocalhostCORSMiddleware(CORSMiddleware): """Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, @@ -27,13 +29,8 @@ class LocalhostCORSMiddleware(CORSMiddleware): """ def __init__(self, app: ASGIApp) -> None: - allow_origins_str = os.getenv('PERMITTED_CORS_ORIGINS') - if allow_origins_str: - allow_origins = tuple( - origin.strip() for origin in allow_origins_str.split(',') - ) - else: - allow_origins = () + config = get_global_config() + allow_origins = tuple(config.permitted_cors_origins) super().__init__( app, allow_origins=allow_origins, @@ -51,6 +48,14 @@ class LocalhostCORSMiddleware(CORSMiddleware): if hostname in ['localhost', '127.0.0.1']: return True + # Allow any origin when no specific origins are configured (development mode) + # WARNING: This disables CORS protection. Use explicit CORS origins in production. + logging.getLogger(__name__).warning( + f'No CORS origins configured, allowing origin: {origin}. ' + 'Set OH_PERMITTED_CORS_ORIGINS for production environments.' + ) + return True + # For missing origin or other origins, use the parent class's logic result: bool = super().is_allowed_origin(origin) return result diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index fa73aa4d52..5789e9784c 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -603,16 +603,28 @@ async def _try_delete_v1_conversation( ) ) if app_conversation_info: + # Check if the sandbox is shared with other conversations + # (e.g. multiple conversations can share a sandbox via /new). + # If shared, skip the agent server DELETE call to avoid + # destabilizing the runtime for the remaining conversations. + sandbox_id = app_conversation_info.sandbox_id + sandbox_is_shared = False + if sandbox_id: + conversation_count = await app_conversation_info_service.count_conversations_by_sandbox_id( + sandbox_id + ) + sandbox_is_shared = conversation_count > 1 + # This is a V1 conversation, delete it using the app conversation service - # Pass the conversation ID for secure deletion result = await app_conversation_service.delete_app_conversation( - app_conversation_info.id + app_conversation_info.id, + skip_agent_server_delete=sandbox_is_shared, ) # Manually commit so that the conversation will vanish from the list await db_session.commit() - # Delete the sandbox in the background + # Delete the sandbox in the background (checks remaining conversations first) asyncio.create_task( _finalize_delete_and_close_connections( sandbox_service, diff --git a/tests/unit/app_server/test_sql_app_conversation_info_service.py b/tests/unit/app_server/test_sql_app_conversation_info_service.py index a491fa93af..48e9693641 100644 --- a/tests/unit/app_server/test_sql_app_conversation_info_service.py +++ b/tests/unit/app_server/test_sql_app_conversation_info_service.py @@ -286,6 +286,54 @@ class TestSQLAppConversationInfoService: results = await service.batch_get_app_conversation_info([]) assert results == [] + @pytest.mark.asyncio + async def test_count_conversations_by_sandbox_id( + self, + service: SQLAppConversationInfoService, + ): + """Test count by sandbox_id: only delete sandbox when no conversation uses it.""" + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + shared_sandbox = 'shared_sandbox_1' + other_sandbox = 'other_sandbox' + for i in range(3): + info = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id=shared_sandbox, + selected_repository='https://github.com/test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title=f'Conversation {i}', + trigger=ConversationTrigger.GUI, + pr_number=[], + llm_model='gpt-4', + metrics=None, + created_at=base_time, + updated_at=base_time, + ) + await service.save_app_conversation_info(info) + for i in range(2): + info = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id=other_sandbox, + selected_repository='https://github.com/test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title=f'Other {i}', + trigger=ConversationTrigger.GUI, + pr_number=[], + llm_model='gpt-4', + metrics=None, + created_at=base_time, + updated_at=base_time, + ) + await service.save_app_conversation_info(info) + + assert await service.count_conversations_by_sandbox_id(shared_sandbox) == 3 + assert await service.count_conversations_by_sandbox_id(other_sandbox) == 2 + assert await service.count_conversations_by_sandbox_id('no_such_sandbox') == 0 + @pytest.mark.asyncio async def test_search_conversation_info_no_filters( self, diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 99dbdfaacc..3c84afd0c6 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -1038,6 +1038,9 @@ async def test_delete_v1_conversation_success(): return_value=mock_app_conversation_info ) mock_service.delete_app_conversation = AsyncMock(return_value=True) + mock_info_service.count_conversations_by_sandbox_id = AsyncMock( + return_value=1 + ) # Call delete_conversation with V1 conversation ID result = await delete_conversation( @@ -1059,7 +1062,8 @@ async def test_delete_v1_conversation_success(): # Verify that delete_app_conversation was called with the conversation ID mock_service.delete_app_conversation.assert_called_once_with( - conversation_uuid + conversation_uuid, + skip_agent_server_delete=False, ) @@ -1357,6 +1361,9 @@ async def test_delete_v1_conversation_with_agent_server(): return_value=mock_app_conversation_info ) mock_service.delete_app_conversation = AsyncMock(return_value=True) + mock_info_service.count_conversations_by_sandbox_id = AsyncMock( + return_value=1 + ) # Call delete_conversation with V1 conversation ID result = await delete_conversation( @@ -1378,7 +1385,8 @@ async def test_delete_v1_conversation_with_agent_server(): # Verify that delete_app_conversation was called with the conversation ID mock_service.delete_app_conversation.assert_called_once_with( - conversation_uuid + conversation_uuid, + skip_agent_server_delete=False, ) diff --git a/tests/unit/server/test_middleware.py b/tests/unit/server/test_middleware.py index 2bdf2275fc..cdc922b5e9 100644 --- a/tests/unit/server/test_middleware.py +++ b/tests/unit/server/test_middleware.py @@ -1,5 +1,4 @@ -import os -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI @@ -21,34 +20,46 @@ def app(): return app -def test_localhost_cors_middleware_init_with_env_var(): - """Test that the middleware correctly parses PERMITTED_CORS_ORIGINS environment variable.""" - with patch.dict( - os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com,https://test.com'} +def test_localhost_cors_middleware_init_with_config(): + """Test that the middleware correctly reads permitted_cors_origins from global config.""" + mock_config = MagicMock() + mock_config.permitted_cors_origins = [ + 'https://example.com', + 'https://test.com', + ] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config ): app = FastAPI() middleware = LocalhostCORSMiddleware(app) - # Check that the origins were correctly parsed from the environment variable + # Check that the origins were correctly read from the config assert 'https://example.com' in middleware.allow_origins assert 'https://test.com' in middleware.allow_origins assert len(middleware.allow_origins) == 2 -def test_localhost_cors_middleware_init_without_env_var(): - """Test that the middleware works correctly without PERMITTED_CORS_ORIGINS environment variable.""" - with patch.dict(os.environ, {}, clear=True): +def test_localhost_cors_middleware_init_without_config(): + """Test that the middleware works correctly without permitted_cors_origins configured.""" + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app = FastAPI() middleware = LocalhostCORSMiddleware(app) - # Check that allow_origins is empty when no environment variable is set + # Check that allow_origins is empty when no origins are configured assert middleware.allow_origins == () def test_localhost_cors_middleware_is_allowed_origin_localhost(app): """Test that localhost origins are allowed regardless of port when no specific origins are configured.""" - # Test without setting PERMITTED_CORS_ORIGINS to trigger localhost behavior - with patch.dict(os.environ, {}, clear=True): + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -76,8 +87,11 @@ def test_localhost_cors_middleware_is_allowed_origin_localhost(app): def test_localhost_cors_middleware_is_allowed_origin_non_localhost(app): """Test that non-localhost origins follow the standard CORS rules.""" - # Set up the middleware with specific allowed origins - with patch.dict(os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com'}): + mock_config = MagicMock() + mock_config.permitted_cors_origins = ['https://example.com'] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -95,7 +109,11 @@ def test_localhost_cors_middleware_is_allowed_origin_non_localhost(app): def test_localhost_cors_middleware_missing_origin(app): """Test behavior when Origin header is missing.""" - with patch.dict(os.environ, {}, clear=True): + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -113,17 +131,22 @@ def test_localhost_cors_middleware_inheritance(): def test_localhost_cors_middleware_cors_parameters(): """Test that CORS parameters are set correctly in the middleware.""" - # We need to inspect the initialization parameters rather than attributes - # since CORSMiddleware doesn't expose these as attributes - with patch('fastapi.middleware.cors.CORSMiddleware.__init__') as mock_init: - mock_init.return_value = None - app = FastAPI() - LocalhostCORSMiddleware(app) + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): + # We need to inspect the initialization parameters rather than attributes + # since CORSMiddleware doesn't expose these as attributes + with patch('fastapi.middleware.cors.CORSMiddleware.__init__') as mock_init: + mock_init.return_value = None + app = FastAPI() + LocalhostCORSMiddleware(app) - # Check that the parent class was initialized with the correct parameters - mock_init.assert_called_once() - _, kwargs = mock_init.call_args + # Check that the parent class was initialized with the correct parameters + mock_init.assert_called_once() + _, kwargs = mock_init.call_args - assert kwargs['allow_credentials'] is True - assert kwargs['allow_methods'] == ['*'] - assert kwargs['allow_headers'] == ['*'] + assert kwargs['allow_credentials'] is True + assert kwargs['allow_methods'] == ['*'] + assert kwargs['allow_headers'] == ['*']